{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e7b4324e",
   "metadata": {},
   "source": [
    "# Time Series Clustering using DTW KMeans\n",
    "\n",
    "In this notebook, we train a k-means clustering model that measures the similarity between time series with Dynamic Time Warping (DTW) distances. Unlike the Euclidean distance, DTW can align series whose patterns are similar but shifted or stretched in time.\n",
    "\n",
    "We use the [tslearn library](https://tslearn.readthedocs.io/en/stable/index.html) for clustering. The data used in this analysis is publicly available from the UCI Machine Learning Repository as the [Online Retail II Data Set](https://archive.ics.uci.edu/ml/datasets/Online+Retail+II).\n",
    "\n",
    "We cleaned and preprocessed this dataset in the optional notebook *01. Optional - Data Cleaning and Preparation*. You can skip that notebook and run this one directly on the preprocessed data included in the repository at `./data/df_pivoted.zip`.\n",
    "\n",
    "Tested with Python 3 and pandas 1.0.5.\n",
    "\n",
    "*References*\n",
    " * Dua, D. and Graff, C. (2019). UCI Machine Learning Repository [http://archive.ics.uci.edu/ml]. Irvine, CA: University of California, School of Information and Computer Science.\n",
    " * Direct link to UCI dataset: https://archive.ics.uci.edu/ml/machine-learning-databases/00502/online_retail_II.xlsx\n",
    " * tslearn github repo: https://github.com/tslearn-team/tslearn"
   ]
  },
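  {
   "cell_type": "markdown",
   "id": "5d0e1a01",
   "metadata": {},
   "source": [
    "To make the DTW distance concrete, here is a minimal sketch of the classic dynamic-programming recurrence in plain NumPy. It is an illustration only, not tslearn's optimized implementation. It applies no warping-window constraint and, as an assumption, sums squared point-wise differences along the cheapest alignment path and takes the square root at the end. The helper name `dtw_distance` is hypothetical.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def dtw_distance(x, y):\n",
    "    # Hypothetical helper: unconstrained DTW between two 1-D series.\n",
    "    x = np.asarray(x, dtype=float)\n",
    "    y = np.asarray(y, dtype=float)\n",
    "    n, m = len(x), len(y)\n",
    "    # acc[i, j] = cheapest cumulative cost of aligning x[:i] with y[:j]\n",
    "    acc = np.full((n + 1, m + 1), np.inf)\n",
    "    acc[0, 0] = 0.0\n",
    "    for i in range(1, n + 1):\n",
    "        for j in range(1, m + 1):\n",
    "            cost = (x[i - 1] - y[j - 1]) ** 2\n",
    "            # extend the cheapest of the three neighboring alignments (diagonal, up, left)\n",
    "            acc[i, j] = cost + min(acc[i - 1, j - 1], acc[i - 1, j], acc[i, j - 1])\n",
    "    return float(np.sqrt(acc[n, m]))\n",
    "\n",
    "\n",
    "a = [0, 0, 1, 2, 1, 0]\n",
    "b = [0, 1, 2, 1, 0, 0]  # same bump, one step earlier\n",
    "print(dtw_distance(a, b))                 # 0.0: DTW aligns the shifted bump exactly\n",
    "print(np.linalg.norm(np.subtract(a, b)))  # 2.0: point-wise Euclidean distance penalizes the shift\n",
    "```\n",
    "\n",
    "The double loop costs `O(n*m)` operations per pair of series. This helps explain why clustering thousands of series of length 374 with DTW takes a long time, even with tslearn's optimized routines."
   ]
  },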
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "fcbf3263",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "6a6a7fa3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# optional - suppress warnings\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7721c77",
   "metadata": {},
   "source": [
    "### Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a25451d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_pivoted = pd.read_csv('./data/df_pivoted.zip', low_memory=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "48ae4f68",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(4605, 374) Index(['2009-12-01 00:00:00', '2009-12-02 00:00:00', '2009-12-03 00:00:00',\n",
      "       '2009-12-04 00:00:00', '2009-12-05 00:00:00', '2009-12-06 00:00:00',\n",
      "       '2009-12-07 00:00:00', '2009-12-08 00:00:00', '2009-12-09 00:00:00',\n",
      "       '2009-12-10 00:00:00',\n",
      "       ...\n",
      "       '2010-11-30 00:00:00', '2010-12-01 00:00:00', '2010-12-02 00:00:00',\n",
      "       '2010-12-03 00:00:00', '2010-12-04 00:00:00', '2010-12-05 00:00:00',\n",
      "       '2010-12-06 00:00:00', '2010-12-07 00:00:00', '2010-12-08 00:00:00',\n",
      "       '2010-12-09 00:00:00'],\n",
      "      dtype='object', length=374)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>2009-12-01 00:00:00</th>\n",
       "      <th>2009-12-02 00:00:00</th>\n",
       "      <th>2009-12-03 00:00:00</th>\n",
       "      <th>2009-12-04 00:00:00</th>\n",
       "      <th>2009-12-05 00:00:00</th>\n",
       "      <th>2009-12-06 00:00:00</th>\n",
       "      <th>2009-12-07 00:00:00</th>\n",
       "      <th>2009-12-08 00:00:00</th>\n",
       "      <th>2009-12-09 00:00:00</th>\n",
       "      <th>2009-12-10 00:00:00</th>\n",
       "      <th>...</th>\n",
       "      <th>2010-11-30 00:00:00</th>\n",
       "      <th>2010-12-01 00:00:00</th>\n",
       "      <th>2010-12-02 00:00:00</th>\n",
       "      <th>2010-12-03 00:00:00</th>\n",
       "      <th>2010-12-04 00:00:00</th>\n",
       "      <th>2010-12-05 00:00:00</th>\n",
       "      <th>2010-12-06 00:00:00</th>\n",
       "      <th>2010-12-07 00:00:00</th>\n",
       "      <th>2010-12-08 00:00:00</th>\n",
       "      <th>2010-12-09 00:00:00</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>StockCode</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>10002</th>\n",
       "      <td>12.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>73.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>49.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>12.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>...</td>\n",
       "      <td>12.0</td>\n",
       "      <td>60.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>25.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>44.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10080</th>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10109</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10120</th>\n",
       "      <td>60.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>30.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>6.0</td>\n",
       "      <td>...</td>\n",
       "      <td>10.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>12.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10125</th>\n",
       "      <td>5.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>46.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>20.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>22.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>40.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 374 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "           2009-12-01 00:00:00  2009-12-02 00:00:00  2009-12-03 00:00:00  \\\n",
       "StockCode                                                                  \n",
       "10002                     12.0                  0.0                  7.0   \n",
       "10080                      0.0                  1.0                  0.0   \n",
       "10109                      0.0                  0.0                  4.0   \n",
       "10120                     60.0                 10.0                  0.0   \n",
       "10125                      5.0                  0.0                  0.0   \n",
       "\n",
       "           2009-12-04 00:00:00  2009-12-05 00:00:00  2009-12-06 00:00:00  \\\n",
       "StockCode                                                                  \n",
       "10002                     73.0                  0.0                 49.0   \n",
       "10080                      3.0                  0.0                  0.0   \n",
       "10109                      0.0                  0.0                  0.0   \n",
       "10120                     30.0                  0.0                  0.0   \n",
       "10125                     46.0                  0.0                  8.0   \n",
       "\n",
       "           2009-12-07 00:00:00  2009-12-08 00:00:00  2009-12-09 00:00:00  \\\n",
       "StockCode                                                                  \n",
       "10002                      2.0                 12.0                  0.0   \n",
       "10080                      0.0                  0.0                  0.0   \n",
       "10109                      0.0                  0.0                  0.0   \n",
       "10120                      0.0                  1.0                  1.0   \n",
       "10125                     20.0                  1.0                 22.0   \n",
       "\n",
       "           2009-12-10 00:00:00  ...  2010-11-30 00:00:00  2010-12-01 00:00:00  \\\n",
       "StockCode                       ...                                             \n",
       "10002                      1.0  ...                 12.0                 60.0   \n",
       "10080                      0.0  ...                  0.0                  0.0   \n",
       "10109                      0.0  ...                  0.0                  0.0   \n",
       "10120                      6.0  ...                 10.0                  0.0   \n",
       "10125                      0.0  ...                  0.0                  2.0   \n",
       "\n",
       "           2010-12-02 00:00:00  2010-12-03 00:00:00  2010-12-04 00:00:00  \\\n",
       "StockCode                                                                  \n",
       "10002                      1.0                  8.0                  0.0   \n",
       "10080                      0.0                  0.0                  0.0   \n",
       "10109                      0.0                  0.0                  0.0   \n",
       "10120                      0.0                  3.0                  0.0   \n",
       "10125                      0.0                  0.0                  0.0   \n",
       "\n",
       "           2010-12-05 00:00:00  2010-12-06 00:00:00  2010-12-07 00:00:00  \\\n",
       "StockCode                                                                  \n",
       "10002                      1.0                 25.0                  8.0   \n",
       "10080                      0.0                  0.0                  0.0   \n",
       "10109                      0.0                  0.0                  0.0   \n",
       "10120                      0.0                  0.0                  0.0   \n",
       "10125                      0.0                  3.0                  0.0   \n",
       "\n",
       "           2010-12-08 00:00:00  2010-12-09 00:00:00  \n",
       "StockCode                                            \n",
       "10002                     13.0                 44.0  \n",
       "10080                      0.0                  0.0  \n",
       "10109                      0.0                  0.0  \n",
       "10120                     12.0                  0.0  \n",
       "10125                     40.0                  0.0  \n",
       "\n",
       "[5 rows x 374 columns]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# index by StockCode so that each row holds one item's daily time series (the layout tslearn expects)\n",
    "df_pivoted.set_index('StockCode', inplace=True)\n",
    "\n",
    "print(df_pivoted.shape, df_pivoted.columns)\n",
    "\n",
    "df_pivoted.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e522195a",
   "metadata": {},
   "source": [
    "## DTW KMeans Clustering"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d96ebce3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Looking in indexes: https://pypi.org/simple, https://pip.repos.neuron.amazonaws.com\n",
      "Requirement already satisfied: tslearn in /home/ec2-user/anaconda3/envs/python3/lib/python3.8/site-packages (0.5.2)\n",
      "Requirement already satisfied: scipy in /home/ec2-user/anaconda3/envs/python3/lib/python3.8/site-packages (from tslearn) (1.5.3)\n",
      "Requirement already satisfied: joblib in /home/ec2-user/anaconda3/envs/python3/lib/python3.8/site-packages (from tslearn) (1.1.0)\n",
      "Requirement already satisfied: numba in /home/ec2-user/anaconda3/envs/python3/lib/python3.8/site-packages (from tslearn) (0.54.1)\n",
      "Requirement already satisfied: numpy in /home/ec2-user/anaconda3/envs/python3/lib/python3.8/site-packages (from tslearn) (1.20.3)\n",
      "Requirement already satisfied: Cython in /home/ec2-user/anaconda3/envs/python3/lib/python3.8/site-packages (from tslearn) (0.29.24)\n",
      "Requirement already satisfied: scikit-learn in /home/ec2-user/anaconda3/envs/python3/lib/python3.8/site-packages (from tslearn) (1.0.1)\n",
      "Requirement already satisfied: setuptools in /home/ec2-user/anaconda3/envs/python3/lib/python3.8/site-packages (from numba->tslearn) (59.4.0)\n",
      "Requirement already satisfied: llvmlite<0.38,>=0.37.0rc1 in /home/ec2-user/anaconda3/envs/python3/lib/python3.8/site-packages (from numba->tslearn) (0.37.0)\n",
      "Requirement already satisfied: threadpoolctl>=2.0.0 in /home/ec2-user/anaconda3/envs/python3/lib/python3.8/site-packages (from scikit-learn->tslearn) (3.0.0)\n",
      "\u001b[33mWARNING: You are using pip version 22.0.4; however, version 22.3.1 is available.\n",
      "You should consider upgrading via the '/home/ec2-user/anaconda3/envs/python3/bin/python -m pip install --upgrade pip' command.\u001b[0m\u001b[33m\n",
      "\u001b[0mNote: you may need to restart the kernel to use updated packages.\n"
     ]
    }
   ],
   "source": [
    "%pip install tslearn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "04aaa397",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tslearn.utils import to_time_series_dataset\n",
    "from tslearn.preprocessing import TimeSeriesScalerMeanVariance\n",
    "from tslearn.clustering import TimeSeriesKMeans, silhouette_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "2ce08d97",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(4605, 374, 1) (4605, 374, 1)\n",
      "CPU times: user 334 ms, sys: 34.4 ms, total: 368 ms\n",
      "Wall time: 355 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "# convert the (n_items, n_days) dataframe to a tslearn dataset of shape (n_items, n_days, 1)\n",
    "X = to_time_series_dataset(df_pivoted.values)\n",
    "\n",
    "# normalize each time series individually to zero mean and unit variance\n",
    "X_train = TimeSeriesScalerMeanVariance().fit_transform(X)\n",
    "\n",
    "print(X.shape, X_train.shape)"
   ]
  },
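  {
   "cell_type": "markdown",
   "id": "5d0e1a02",
   "metadata": {},
   "source": [
    "`TimeSeriesScalerMeanVariance` rescales every time series on its own, using that series' mean and standard deviation, rather than scaling each date across items. The sketch below illustrates the same idea in NumPy on the first four days of item 10002 and on an all-zero series. The helper name `znormalize_per_series` is hypothetical. Mapping a constant series to all zeros is an assumption about how zero variance should be handled, not a statement about tslearn's behavior.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def znormalize_per_series(values):\n",
    "    # Hypothetical helper: (n_series, n_timesteps) -> (n_series, n_timesteps, 1),\n",
    "    # z-normalized independently for each series.\n",
    "    X = np.asarray(values, dtype=float)[:, :, np.newaxis]\n",
    "    mean = X.mean(axis=1, keepdims=True)  # one mean per series\n",
    "    std = X.std(axis=1, keepdims=True)    # one standard deviation per series\n",
    "    std[std == 0] = 1.0                   # assumption: constant series become all zeros\n",
    "    return (X - mean) / std\n",
    "\n",
    "\n",
    "demo = np.array([[12.0, 0.0, 7.0, 73.0],  # item 10002, first four days\n",
    "                 [0.0, 0.0, 0.0, 0.0]])   # a series with no sales in this window\n",
    "Z = znormalize_per_series(demo)\n",
    "print(Z.shape)                                                # (2, 4, 1)\n",
    "print(np.isclose(Z[0].mean(), 0), np.isclose(Z[0].std(), 1))  # True True\n",
    "print(np.all(Z[1] == 0))                                      # True\n",
    "```"
   ]
  },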
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "362fc965",
   "metadata": {},
   "outputs": [],
   "source": [
    "# create required directory structure\n",
    "dir_paths = ['./tsl', './tsl/models', './tsl/plots']\n",
    "\n",
    "for dir_path in dir_paths:\n",
    "    os.makedirs(dir_path, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ee9e7783",
   "metadata": {},
   "source": [
    "### Perform clustering\n",
    "\n",
    "With 12 cores, a single clustering run can take roughly half an hour. Your mileage may vary depending on your machine's configuration. Setting `n_jobs=-1` lets the training use all available cores on your machine.\n",
    "\n",
    "An alternative distance metric is Soft-DTW (`metric=\"softdtw\"` in `TimeSeriesKMeans`), which may produce better-separated clusters at a higher compute cost. See the `tslearn` [k-means example](https://tslearn.readthedocs.io/en/stable/auto_examples/clustering/plot_kmeans.html#sphx-glr-auto-examples-clustering-plot-kmeans-py) for more details.\n",
    "\n",
    "*Choosing the number of clusters.* Adding more clusters generally decreases the inertia (the within-cluster sum of squares, WCSS), but each cluster then contains fewer time series and therefore carries less information. We want a small number of clusters that still has a relatively low inertia. The elbow heuristic captures this trade-off: plot the inertia against the number of clusters and pick the point after which the curve stops dropping steeply. Clustering is also not advised for datasets with fewer than a thousand time series, since splitting such a small dataset may leave too little data in each cluster to train deep learning forecasting models well.\n",
    "\n",
    "The next cell fits one model for each of 1 to 4 clusters and could take around 140 minutes with 16 vCPUs (ml.c5.4xlarge). The recorded run below took about 85 minutes of wall-clock time. If you want, you can skip it and move on to the following cell, which trains the final model."
   ]
  },
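  {
   "cell_type": "markdown",
   "id": "5d0e1a03",
   "metadata": {},
   "source": [
    "Reading the elbow off the plot is a judgment call. One simple way to automate it is shown here as a hypothetical helper, `elbow_point`. This is not something this notebook or tslearn does. The helper rescales both axes and picks the number of clusters whose point lies farthest from the straight line joining the first and last points of the curve. The inertia values in the example are made up for illustration.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def elbow_point(k_values, inertias):\n",
    "    # Hypothetical helper: return the k whose (k, inertia) point lies farthest from\n",
    "    # the straight line joining the first and last points of the elbow curve.\n",
    "    k = np.asarray(k_values, dtype=float)\n",
    "    w = np.asarray(inertias, dtype=float)\n",
    "    # rescale both axes to [0, 1] so neither dominates (assumes the values are not all equal)\n",
    "    k_n = (k - k.min()) / (k.max() - k.min())\n",
    "    w_n = (w - w.min()) / (w.max() - w.min())\n",
    "    start = np.array([k_n[0], w_n[0]])\n",
    "    end = np.array([k_n[-1], w_n[-1]])\n",
    "    direction = (end - start) / np.linalg.norm(end - start)\n",
    "    rel = np.column_stack([k_n, w_n]) - start\n",
    "    # perpendicular distance of each point to the line (2-D cross product)\n",
    "    dist = np.abs(rel[:, 0] * direction[1] - rel[:, 1] * direction[0])\n",
    "    return int(k[np.argmax(dist)])\n",
    "\n",
    "\n",
    "# made-up inertia values for illustration, not results from this notebook\n",
    "print(elbow_point([1, 2, 3, 4, 5, 6], [100, 40, 30, 25, 22, 20]))  # 2\n",
    "```\n",
    "\n",
    "Once the elbow cell below has filled `wcss`, `elbow_point(range(1, 5), wcss)` would return a candidate number of clusters to compare against the plot. With only four points, treat it as a rough guide."
   ]
  },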
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67167ab6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 2h 18min 17s, sys: 41min 20s, total: 2h 59min 37s\n",
      "Wall time: 1h 24min 44s\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Text(0, 0.5, 'WCSSS')"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 1440x720 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "# algorithm configuration\n",
    "algo = \"DTW_kmeans\"\n",
    "metric = \"dtw\"\n",
    "\n",
    "wcss = []\n",
    "\n",
    "k_values = range(1, 5)  # try 1 to 4 clusters\n",
    "for k in k_values:\n",
    "    model = TimeSeriesKMeans(n_clusters=k, metric=metric, n_jobs=-1, random_state=0)\n",
    "    model.fit(X_train)\n",
    "    wcss.append(model.inertia_)\n",
    "\n",
    "plt.figure(figsize=(20, 10))\n",
    "plt.grid()\n",
    "plt.plot(k_values, wcss, marker='o', linestyle='--')\n",
    "plt.xlabel('number of clusters')\n",
    "plt.ylabel('WCSS (inertia)');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "9aadd724",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 45min 55s, sys: 13min 14s, total: 59min 10s\n",
      "Wall time: 30min 32s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "# algorithm configuration\n",
    "algo = \"DTW_kmeans\"\n",
    "metric = \"dtw\"\n",
    "\n",
    "# cluster configuration\n",
    "N_CLUSTERS = 2\n",
    "\n",
    "model= TimeSeriesKMeans(n_clusters=N_CLUSTERS,\n",
    "                        metric=metric,\n",
    "                        n_jobs=-1,\n",
    "                        random_state=0)\n",
    "\n",
    "y_pred = model.fit_predict(X_train)\n",
    "\n",
    "model.to_pickle(f\"./tsl/models/{algo}.pkl\")"
   ]
  },
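  {
   "cell_type": "markdown",
   "id": "5d0e1a04",
   "metadata": {},
   "source": [
    "`fit_predict` repeats two steps until convergence (or until `max_iter` is reached). First, it assigns each series to its nearest centroid. Then it recomputes each centroid from its members. With `metric=\"dtw\"`, tslearn computes the centroids with DTW Barycenter Averaging (DBA). That update is too involved for a short example, so the sketch below illustrates only the assignment step, using a minimal unconstrained DTW. The helpers `dtw_distance` and `assign_to_centroids` and the toy series are hypothetical.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def dtw_distance(x, y):\n",
    "    # Hypothetical helper: minimal unconstrained DTW with squared point-wise costs.\n",
    "    x = np.asarray(x, dtype=float)\n",
    "    y = np.asarray(y, dtype=float)\n",
    "    acc = np.full((len(x) + 1, len(y) + 1), np.inf)\n",
    "    acc[0, 0] = 0.0\n",
    "    for i in range(1, len(x) + 1):\n",
    "        for j in range(1, len(y) + 1):\n",
    "            cost = (x[i - 1] - y[j - 1]) ** 2\n",
    "            acc[i, j] = cost + min(acc[i - 1, j - 1], acc[i - 1, j], acc[i, j - 1])\n",
    "    return float(np.sqrt(acc[-1, -1]))\n",
    "\n",
    "\n",
    "def assign_to_centroids(series, centroids):\n",
    "    # Hypothetical helper: k-means assignment step, nearest centroid (by DTW) per series.\n",
    "    dists = np.array([[dtw_distance(s, c) for c in centroids] for s in series])\n",
    "    return dists.argmin(axis=1), dists\n",
    "\n",
    "\n",
    "series = [[0, 0, 1, 2, 1, 0],\n",
    "          [0, 1, 2, 1, 0, 0],\n",
    "          [3, 3, 3, 0, 0, 0]]\n",
    "centroids = [[0, 1, 2, 1, 0, 0],\n",
    "             [3, 3, 0, 0, 0, 0]]\n",
    "labels, dists = assign_to_centroids(series, centroids)\n",
    "print(labels)  # [0 0 1]\n",
    "```"
   ]
  },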
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "7d6e86f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# backup clustering results\n",
    "np.save(f\"./data/tls_{algo}_cluster_labels\", y_pred)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9287ee9b",
   "metadata": {},
   "source": [
    "Let us plot the different clusters to visually inspect the homogeneity of the cluster composition."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "64bd629f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 7.23 s, sys: 364 ms, total: 7.59 s\n",
      "Wall time: 6.88 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "for yi in range(N_CLUSTERS):\n",
    "    X_sub = X_train[y_pred == yi]  # normalized series assigned to cluster yi\n",
    "    ts_cnt = X_sub.shape[0]  # number of series in this cluster\n",
    "    fig = plt.figure()\n",
    "    plt.title(f\"{algo} | Cluster ID: {yi} | TS Count: {ts_cnt}\")\n",
    "    for xx in X_sub:\n",
    "        plt.plot(xx.ravel(), color='xkcd:sky blue', alpha=0.025)\n",
    "    fig.savefig(f\"./tsl/plots/{algo}_cls_lbl_{yi}.png\", dpi=150)\n",
    "    plt.show()\n",
    "    plt.close()"
   ]
  },
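  {
   "cell_type": "markdown",
   "id": "a036937f",
   "metadata": {},
   "source": [
    "## Generate TTS for different clusters\n",
    "\n",
    "We can now convert the data into the target time series (TTS) format expected by Amazon Forecast, with one row per item and timestamp. We then split the TTS into clusters based on the labels of the individual items."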
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "8facf29a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1722270, 3) item_id          object\n",
      "timestamp        object\n",
      "target_value    float64\n",
      "dtype: object\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>item_id</th>\n",
       "      <th>timestamp</th>\n",
       "      <th>target_value</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>10002</td>\n",
       "      <td>2009-12-01</td>\n",
       "      <td>12.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>10080</td>\n",
       "      <td>2009-12-01</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>10109</td>\n",
       "      <td>2009-12-01</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>10120</td>\n",
       "      <td>2009-12-01</td>\n",
       "      <td>60.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>10125</td>\n",
       "      <td>2009-12-01</td>\n",
       "      <td>5.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  item_id   timestamp  target_value\n",
       "0   10002  2009-12-01          12.0\n",
       "1   10080  2009-12-01           0.0\n",
       "2   10109  2009-12-01           0.0\n",
       "3   10120  2009-12-01          60.0\n",
       "4   10125  2009-12-01           5.0"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# convert data to TTS format expected by Forecast service\n",
    "df_pivoted.reset_index(inplace=True)\n",
    "\n",
    "df_tts = pd.melt(df_pivoted, id_vars=['StockCode'])\n",
    "df_tts.columns = ['item_id', 'timestamp', 'target_value']\n",
    "df_tts['timestamp'] = df_tts['timestamp'].str[:10]  # keep only the date part\n",
    "\n",
    "print(df_tts.shape, df_tts.dtypes)\n",
    "\n",
    "df_tts.head()"
   ]
  },
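  {
   "cell_type": "markdown",
   "id": "5d0e1a05",
   "metadata": {},
   "source": [
    "The reshape above is a standard wide-to-long `pd.melt`. The sketch below applies it to a toy frame built from the first two items and first two dates shown earlier. It yields one row per item and date, ordered date by date, which is why the first rows of `df_tts` all share the timestamp 2009-12-01. The toy frame is for illustration only.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "# toy wide frame built from the first two items and dates shown above\n",
    "wide = pd.DataFrame({'StockCode': ['10002', '10080'],\n",
    "                     '2009-12-01 00:00:00': [12.0, 0.0],\n",
    "                     '2009-12-02 00:00:00': [0.0, 1.0]})\n",
    "\n",
    "tts = pd.melt(wide, id_vars=['StockCode'])  # columns: StockCode, variable, value\n",
    "tts.columns = ['item_id', 'timestamp', 'target_value']\n",
    "tts['timestamp'] = tts['timestamp'].str[:10]  # keep only the date part\n",
    "print(tts)\n",
    "```"
   ]
  },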
  {
   "cell_type": "markdown",
   "id": "c437a885",
   "metadata": {},
   "source": [
    "### Train - Hold-out Split\n",
    "\n",
    "A hold-out set lets us verify model performance on unseen data. With this dataset, we want to forecast one week ahead (forecast horizon = 1 week), so we hold out the last week of data (2010-12-03 to 2010-12-09) from the TTS."
   ]
  },
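  {
   "cell_type": "markdown",
   "id": "5d0e1a06",
   "metadata": {},
   "source": [
    "The cut-off date in the split cells is hard-coded. As an alternative, the hypothetical helper `holdout_split` derives it from the last date in the TTS and the forecast horizon, so the same code would keep working if the data were extended. It is shown on a toy TTS with two items and ten days of made-up values. On this notebook's data (last date 2010-12-09, 7-day horizon) the same rule gives the cut-off 2010-12-03.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "\n",
    "def holdout_split(df, horizon_days=7, time_col='timestamp'):\n",
    "    # Hypothetical helper: the last `horizon_days` dates of a long-format TTS form the hold-out set.\n",
    "    ts = pd.to_datetime(df[time_col])\n",
    "    cutoff = ts.max() - pd.Timedelta(days=horizon_days - 1)  # first hold-out date\n",
    "    return df[ts < cutoff], df[ts >= cutoff], cutoff\n",
    "\n",
    "\n",
    "# toy TTS: two items with ten daily observations each (made-up values)\n",
    "dates = pd.date_range('2010-11-30', '2010-12-09', freq='D').strftime('%Y-%m-%d')\n",
    "toy = pd.DataFrame({'item_id': ['A'] * 10 + ['B'] * 10,\n",
    "                    'timestamp': list(dates) * 2,\n",
    "                    'target_value': range(20)})\n",
    "train, test, cutoff = holdout_split(toy)\n",
    "print(cutoff.date(), len(train), len(test))  # 2010-12-03 6 14\n",
    "```"
   ]
  },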
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "7d28a9e5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('2009-12-01', '2010-12-09')"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "min(df_tts['timestamp']), max(df_tts['timestamp'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "65803ce0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((1722270, 3), (1690035, 3), (32235, 3))"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# hold out the last 7 days (2010-12-03 onwards); ISO date strings compare correctly as text\n",
    "df_train = df_tts[df_tts['timestamp'] < '2010-12-03']\n",
    "df_test = df_tts[df_tts['timestamp'] >= '2010-12-03']\n",
    "\n",
    "df_tts.shape, df_train.shape, df_test.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "ba2c090c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(4605, 4605, 4605)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# verify that we have adequate coverage across train and test\n",
    "df_tts.item_id.nunique(), df_train.item_id.nunique(), df_test.item_id.nunique()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6fbdbb96",
   "metadata": {},
   "source": [
    "## Split data into clusters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "adf72f6d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1 1 0 ... 0 0 1]\n"
     ]
    }
   ],
   "source": [
    "# if restarting, re-define the configuration and reload the cluster labels\n",
    "algo = \"DTW_kmeans\"\n",
    "N_CLUSTERS = 2\n",
    "y_pred = np.load(f\"./data/tls_{algo}_cluster_labels.npy\")\n",
    "\n",
    "print(y_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "7e4a6deb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((4605, 2),\n",
       " item_id    object\n",
       " label       int64\n",
       " dtype: object)"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# lookup dataframe with item_ids and corresponding labels\n",
    "df_lbl = pd.DataFrame()\n",
    "df_lbl['item_id'] = df_pivoted['StockCode']\n",
    "df_lbl['label'] = y_pred\n",
    "\n",
    "df_lbl.shape, df_lbl.dtypes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "b03bc991",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1690035, 4) (1690035, 3)\n",
      "(32235, 4) (32235, 3)\n"
     ]
    }
   ],
   "source": [
    "# merge labels back to the TTS\n",
    "df_mrg = df_train.merge(df_lbl, on='item_id', how='left')\n",
    "df_mrg_test = df_test.merge(df_lbl, on='item_id', how='left')\n",
    "\n",
    "print(df_mrg.shape, df_train.shape)\n",
    "print(df_mrg_test.shape, df_test.shape)"
   ]
  },
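  {
   "cell_type": "markdown",
   "id": "5d0e1a07",
   "metadata": {},
   "source": [
    "A left merge has two silent failure modes. It leaves labels missing when an item has no label, and it duplicates rows when an item appears twice in the lookup table. pandas can guard against both. `validate='many_to_one'` raises a `MergeError` if `item_id` is not unique in the right-hand frame. `indicator=True` adds a `_merge` column that shows which rows found a match. The toy frames below are hypothetical.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "tts = pd.DataFrame({'item_id': ['10002', '10002', '10080'],\n",
    "                    'timestamp': ['2009-12-01', '2009-12-02', '2009-12-01'],\n",
    "                    'target_value': [12.0, 0.0, 0.0]})\n",
    "labels = pd.DataFrame({'item_id': ['10002', '10080'], 'label': [1, 0]})\n",
    "\n",
    "merged = tts.merge(labels, on='item_id', how='left',\n",
    "                   validate='many_to_one',  # each item_id must appear once in labels\n",
    "                   indicator=True)          # adds a '_merge' column\n",
    "assert (merged['_merge'] == 'both').all(), 'some items have no cluster label'\n",
    "merged = merged.drop(columns='_merge')\n",
    "print(merged)\n",
    "\n",
    "# a duplicated lookup table is rejected instead of silently duplicating rows\n",
    "dup_detected = False\n",
    "try:\n",
    "    tts.merge(pd.concat([labels, labels]), on='item_id', how='left', validate='many_to_one')\n",
    "except pd.errors.MergeError:\n",
    "    dup_detected = True\n",
    "print(dup_detected)  # True\n",
    "```"
   ]
  },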
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "01d4a5ad",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>item_id</th>\n",
       "      <th>timestamp</th>\n",
       "      <th>target_value</th>\n",
       "      <th>label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>1191247</th>\n",
       "      <td>35638A</td>\n",
       "      <td>2010-08-16</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>320695</th>\n",
       "      <td>90188</td>\n",
       "      <td>2010-02-08</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1111322</th>\n",
       "      <td>22322</td>\n",
       "      <td>2010-07-30</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1531250</th>\n",
       "      <td>72778</td>\n",
       "      <td>2010-10-29</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1668645</th>\n",
       "      <td>22450</td>\n",
       "      <td>2010-11-28</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        item_id   timestamp  target_value  label\n",
       "1191247  35638A  2010-08-16           0.0      0\n",
       "320695    90188  2010-02-08           0.0      0\n",
       "1111322   22322  2010-07-30           0.0      1\n",
       "1531250   72778  2010-10-29           0.0      1\n",
       "1668645   22450  2010-11-28           0.0      0"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_mrg.sample(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "1d5b04b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# create one output directory per cluster\n",
    "dir_paths = ['./train'] + [f\"./train/cls_{i+1}_DTW\" for i in range(N_CLUSTERS)]\n",
    "\n",
    "for dir_path in dir_paths:\n",
    "    os.makedirs(dir_path, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "a910bf10",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1690035 1690035\n"
     ]
    }
   ],
   "source": [
    "# split and save TTS\n",
    "record_count = 0\n",
    "for i in range(N_CLUSTERS):\n",
    "    df_tmp = df_mrg.loc[df_mrg['label'] == i, ['item_id', 'timestamp', 'target_value']]\n",
    "    df_tmp.to_csv(f\"./train/cls_{i+1}_DTW/tts_{i+1}_DTW.csv\", header=None, index=None)\n",
    "    df_tmp2 = df_mrg_test.loc[df_mrg_test['label'] == i, ['item_id', 'timestamp', 'target_value']]\n",
    "    df_tmp2.to_csv(f\"./train/cls_{i+1}_DTW/test_{i+1}_DTW.csv\", header=None, index=None)\n",
    "    record_count += df_tmp.shape[0]\n",
    "\n",
    "print(record_count, df_mrg.shape[0])  # verify that no training records were lost in the split"
   ]
  },
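  {
   "cell_type": "markdown",
   "id": "5d0e1a08",
   "metadata": {},
   "source": [
    "An equivalent and somewhat more idiomatic way to write one file per cluster is to iterate over `groupby('label')`, which visits each label present in the data once. The sketch below uses a hypothetical toy frame and writes to a temporary directory. The file names mirror the `tts_{i+1}_DTW.csv` pattern used above. Unlike the loop over `range(N_CLUSTERS)`, it writes no file for a cluster that has no members.\n",
    "\n",
    "```python\n",
    "import os\n",
    "import tempfile\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "# hypothetical toy frame in the same layout as df_mrg\n",
    "df_toy = pd.DataFrame({'item_id': ['10002', '10080', '10109', '10002'],\n",
    "                       'timestamp': ['2009-12-01', '2009-12-01', '2009-12-01', '2009-12-02'],\n",
    "                       'target_value': [12.0, 0.0, 0.0, 0.0],\n",
    "                       'label': [1, 0, 0, 1]})\n",
    "\n",
    "out_dir = tempfile.mkdtemp()\n",
    "record_count = 0\n",
    "for label, grp in df_toy.groupby('label'):\n",
    "    # one headerless CSV per cluster; label 0 -> tts_1_DTW.csv, label 1 -> tts_2_DTW.csv\n",
    "    path = os.path.join(out_dir, f'tts_{label + 1}_DTW.csv')\n",
    "    grp[['item_id', 'timestamp', 'target_value']].to_csv(path, header=False, index=False)\n",
    "    record_count += len(grp)\n",
    "\n",
    "print(sorted(os.listdir(out_dir)))  # ['tts_1_DTW.csv', 'tts_2_DTW.csv']\n",
    "print(record_count == len(df_toy))  # True\n",
    "```"
   ]
  },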
  {
   "cell_type": "markdown",
   "id": "985b9119",
   "metadata": {},
   "source": [
    "### Processing Complete\n",
    "\n",
    "The per-cluster TTS files and matching hold-out files in `./train/` can now be uploaded to Amazon S3 and used to train Amazon Forecast models, as described in the [Forecast Developer Guide](https://docs.aws.amazon.com/forecast/latest/dg/what-is-forecast.html)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}