{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Hierarchical Forecasting"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook we take the example of demand forecasting on synthetic retail data and show you how to train and tune multiple hierarchichal time series models across algorithms and hyper-parameter combinations using the `scikit-hts` toolkit on Amazon SageMaker. We will first show you how to setup scikit-hts on SageMaker using the SKLearn estimator, then train multiple models using SageMaker Experiments, and finally use SageMaker Debugger to monitor suboptimal training and improve training efficiencies. We will walk you through the following steps:\n",
    "\n",
    "1.\t[Setup](#Setup)\n",
    "2.\t[Prepare Time Series Data](#Prepare-Time-Series-Data)\n",
    "    - [Data Visualization](#Data-Visualization)\n",
    "    - [Split data into train and test](#Split-data-into-train-and-test)\n",
    "    - [Hierarchical Representation](#Hierarchical-Representation)\n",
    "    - [Visualizing the tree structure](#Visualizing-the-tree-structure)\n",
    "3.\t[Setup the scikit-hts training script](#section3)\n",
    "4.  [Setup Amazon SageMaker Experiment and Trials](#section4)\n",
    "5.\t[Setup the SKLearn Estimator](#section5)\n",
    "6.\t[Evaluate metrics and select a winning candidate](#section7)\n",
    "7.\t[Run time series forecasts](#Run-time-series-forecasts)\n",
    "    - [Visualization at Region Level](#Visualization-at-Region-Level)\n",
    "    - [Visualization at State Level](#Visualization-at-State-Level)\n",
    "\n",
    "\n",
    "Before getting started we need to first install a few packages:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.simplefilter(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "! pip install upgrade pip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install --user scikit-hts[prophet]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#! pip install --user plotly -q \n",
    "! pip install --user scikit-hts -q \n",
    "! pip install --user -U sagemaker -q"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!conda install -c plotly plotly --yes\n",
    "!conda install -c conda-forge fbprophet --yes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Important** -- Make sure to restart the kernel."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import plotly.graph_objects as go\n",
    "from hts.hierarchy import HierarchyTree\n",
    "from hts import HTSRegressor\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pandas import DataFrame\n",
    "from pandas import Grouper\n",
    "from matplotlib import pyplot\n",
    "import pandas as pd\n",
    "import dataset_prep"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "logger = logging.getLogger()\n",
    "logger.setLevel(logging.CRITICAL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import boto3\n",
    "import sagemaker \n",
    "\n",
    "s3_client = boto3.client('s3')\n",
    "s3res = boto3.resource('s3')\n",
    "\n",
    "sess = sagemaker.Session()\n",
    "bucket = sess.default_bucket()\n",
    "\n",
    "pref = 'hierarchical-forecast-retail/scikit-hts'\n",
    "s3_train_channel = \"s3://\" + bucket + \"/\" + pref + \"/train.csv\"\n",
    "s3_test_channel = \"s3://\" + bucket + \"/\" + pref + \"/test.csv\"\n",
    "print(s3_train_channel)\n",
    "print(s3_test_channel)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare Time Series Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Read cleaned, joined, featurized data.\n",
    "df_raw = pd.read_csv(\"retail-usa-clothing.csv\"\n",
    "                          , parse_dates=True\n",
    "                          , header=0\n",
    "                          , names=['date', 'state'\n",
    "                                   , 'item', 'quantity', 'region'\n",
    "                                   , 'country']\n",
    "                    )\n",
    "df_raw['quantity'] = df_raw['quantity'].astype(int)\n",
    "# drop duplicates\n",
    "print(df_raw.shape)\n",
    "df_raw.drop_duplicates(inplace=True)\n",
    "\n",
    "df_raw['date'] = pd.to_datetime(df_raw[\"date\"])\n",
    "print(df_raw.shape)\n",
    "print(df_raw.dtypes)\n",
    "print(f\"Min timestamp = {df_raw.date.min()}\")\n",
    "print(f\"Max timestamp = {df_raw.date.max()}\")\n",
    "df_raw.sample(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_raw.region.unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_raw.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## # map expected column names\n",
    "item_id = \"item\"\n",
    "target_value = \"quantity\"\n",
    "timestamp = \"date\"\n",
    "city = \"city\"\n",
    "region = 'region'\n",
    "country = 'country'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Drop null item_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Drop null item_ids\n",
    "templist = df_raw[item_id].unique()\n",
    "print(f\"Number unique items: {len(templist)}\")\n",
    "print(f\"Number nulls: {pd.isnull(templist).sum()}\")\n",
    "\n",
    "if len(templist) < 20:\n",
    "    print(templist)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Drop the null item_ids, if any exist\n",
    "if pd.isnull(templist).sum() > 0:\n",
    "    print(df_raw.shape)\n",
    "    df_raw = df_raw.loc[(~df_raw[item_id].isna()), :].copy()\n",
    "    print(df_raw.shape)\n",
    "    print(len(df_raw[item_id].unique()))\n",
    "else:\n",
    "    print(\"No missing item_ids found.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Drop null timestamps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# check null timestamps\n",
    "templist = df_raw.loc[(df_raw[timestamp].isna()), :].shape[0]\n",
    "print(f\"Number nulls: {templist}\")\n",
    "\n",
    "if (templist < 10) & (templist > 0) :\n",
    "    print(df_raw.loc[(df_raw[timestamp].isna()), :])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Drop the null quantities and dates\n",
    "if templist > 0:\n",
    "    print(df_raw.shape)\n",
    "    df_raw = df_raw.loc[(~df_raw[timestamp].isna()), :].copy()\n",
    "    print(df_raw.shape)\n",
    "    print(df_raw['date'].isna().sum())\n",
    "else:\n",
    "    print(\"No null dates found.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pandas import DataFrame\n",
    "from pandas import Grouper\n",
    "from matplotlib import pyplot\n",
    "groups = df_raw.groupby(Grouper(key=item_id))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Plots\n",
    "for name, group in groups:\n",
    "    group.plot(subplots=True,title=name, x=timestamp, y=target_value, legend=True)\n",
    "    pyplot.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Split data into train and test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_train = df_raw.query(f'date <= \"2009-04-29\"').copy()\n",
    "df_train.to_csv(\"train.csv\")\n",
    "s3_client.upload_file(\"train.csv\", bucket, pref+\"/train.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_test = df_raw.query(f'date > \"2009-04-29\"').copy()\n",
    "df_test.to_csv(\"test.csv\")\n",
    "s3_client.upload_file(\"test.csv\", bucket, pref+\"/test.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_train.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "product = df_train[(df_train['item'] == \"mens_clothing\")]\n",
    "product.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_region_columns(df, region):\n",
    "    return [col for col in df.columns if region in col]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "product[\"region_state\"] = product.apply(lambda x: f\"{x['region']}_{x['state']}\", axis=1)\n",
    "region_states = product[\"region_state\"].unique()\n",
    "grouped_sections = product.groupby([\"region\", \"region_state\"])\n",
    "edges_hierarchy = list(grouped_sections.groups.keys())\n",
    "# Now, we must not forget that total is our root node.\n",
    "second_level_nodes = product.region.unique()\n",
    "root_node = \"total\"\n",
    "root_edges = [(root_node, second_level_node) for second_level_node in second_level_nodes]\n",
    "root_edges += edges_hierarchy\n",
    "product_bottom_level = product.pivot(index=\"date\", columns=\"region_state\", values=\"quantity\")\n",
    "regions = product[\"region\"].unique().tolist()\n",
    "for region in regions:\n",
    "    region_cols = get_region_columns(product_bottom_level, region)\n",
    "    product_bottom_level[region] = product_bottom_level[region_cols].sum(axis=1)\n",
    "\n",
    "product_bottom_level[\"total\"] = product_bottom_level[regions].sum(axis=1)\n",
    "\n",
    "# create hierarchy\n",
    "# Now that we have our dataset ready, let's define our hierarchy tree. \n",
    "# We need a dictionary, where each key is a column (node) in our hierarchy and a list of its children.\n",
    "hierarchy = dict()\n",
    "\n",
    "for edge in root_edges:\n",
    "    parent, children = edge[0], edge[1]\n",
    "    hierarchy.get(parent)\n",
    "    if not hierarchy.get(parent):\n",
    "        hierarchy[parent] = [children]\n",
    "    else:\n",
    "        hierarchy[parent] += [children]\n",
    "\n",
    "product_bottom_level.index = pd.to_datetime(product_bottom_level.index)\n",
    "product_bottom_level = product_bottom_level.resample(\"D\").sum()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hierarchical Representation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*scikit-hts* requires that each column in our DataFrame is a time series of its own, for all hierarchy levels. Let's do that. Remember that our data is in a long format.\n",
    "\n",
    "The steps are the following:\n",
    "\n",
    "1. Transform dataset into a column oriented one\n",
    "2. Create the hierarchy representation as a dictionary\n",
    " \n",
    "For a complete description of how that is done under the hood, and for a sense of what the API accepts, see [scikit-hts' docs](https://scikit-hts.readthedocs.io/en/latest/hierarchy.html)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Let's take a look at the training script\n",
    "!pygmentize code/dataset_prep.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_hierarchy, train_product_bottom_level, region_states = dataset_prep.prepare_data(df_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_hierarchy, test_product_bottom_level, region_states = dataset_prep.prepare_data(df_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualizing the tree structure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from hts.hierarchy import HierarchyTree\n",
    "\n",
    "ht = HierarchyTree.from_nodes(nodes=train_hierarchy, df=train_product_bottom_level)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(ht)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(ht.children[1].key)\n",
    "print(ht.children[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(ht.get_node('NewEngland'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "regions = df_raw[\"region\"].unique().tolist()\n",
    "regions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create the algorithm and hyper-parameters combinatorial matrix <a name=section2></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "d = {'revision_method': [\"BU\", \"AHP\"], 'seasonality_mode':[\"additive\", \"multiplicative\"]}\n",
    "df_hps = pd.DataFrame(data=d)\n",
    "df_hps.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will use the 'product' function to derive combinations of these parameters from the base set into separate rows in the dataframe. Each row corresponds to a training job configuration that we will subsequently pass to the SKLearn Estimator to run the training job.\n",
    "\n",
    "Note Please check your AWS account limits before you setup the product function below. The training process in the sections below will run one training job per row from this dataframe. Based on your account limit for the maximum number of concurrent training jobs, you may get an error that the limit has been exceeded."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import product\n",
    "\n",
    "prod = product(df_hps['seasonality_mode'].unique(), df_hps['revision_method'].unique())\n",
    "\n",
    "df_hps_combo = pd.DataFrame([list(p) for p in prod],\n",
    "                   columns=list(['seasonality_mode', 'revision_method']))\n",
    "\n",
    "df_hps_combo['jobnumber'] = df_hps_combo.index"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's take a look on the different combinations. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_hps_combo"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup the scikit-hts training script"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Let's take a look at the training script\n",
    "!pygmentize code/train.py"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup a SageMaker Experiment  <a name=section4></a>\n",
    "\n",
    "Before create the training job, we first create a SageMaker Experiment that will allow us to track the different training jobs. We use the `smexperiments` libraray to create the experiment:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datetime import datetime\n",
    "from smexperiments.experiment import Experiment\n",
    "\n",
    "sagemaker_boto_client = boto3.client(\"sagemaker\")\n",
    "\n",
    "#name of experiment\n",
    "timestep = datetime.now()\n",
    "timestep = timestep.strftime(\"%d-%m-%Y-%H-%M-%S\")\n",
    "experiment_name = \"hierarchical-forecast-models-\" + timestep\n",
    "\n",
    "#create experiment\n",
    "Experiment.create(\n",
    "    experiment_name=experiment_name, \n",
    "    description=\"Hierarchical Timeseries models\", \n",
    "    sagemaker_boto_client=sagemaker_boto_client)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For each job we define a new Trial component within that experiment:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from smexperiments.trial import Trial\n",
    "\n",
    "trial = Trial.create(\n",
    "    experiment_name=experiment_name,\n",
    "    sagemaker_boto_client=sagemaker_boto_client\n",
    ")\n",
    "print(trial)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "experiment_config = { \"ExperimentName\": experiment_name, \n",
    "                      \"TrialName\":  trial.trial_name,\n",
    "                      \"TrialComponentDisplayName\": \"Training\"}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fitting models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will use scikit-hts to fit `Prophet` model in our data and compare results.\n",
    "- Prophet: \n",
    "    - *daily_seasonality* : By default daily seasonality is set to `False`, therefore, explicitly changing it to `True`\n",
    "    - *changepoint_prior_scale* : If the trend changes are being overfit (too much flexibility) or underfit (not enough flexibility), you can adjust the strength of the sparse prior using the input argument changepoint_prior_scale. By default, this parameter is set to `0.05`. Increasing it will make the trend more flexible.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "metric_definitions=[\n",
    "    {'Name': 'Total:MSE', 'Regex': 'Total: MSE: ([0-9\\\\.]+)'},\n",
    "    {'Name': 'Mid-Alantic:MSE', 'Regex': 'Mid-Alantic: MSE: ([0-9\\\\.]+)'},\n",
    "    {'Name': 'SouthCentral:MSE', 'Regex': 'SouthCentral: MSE: ([0-9\\\\.]+)'},\n",
    "    {'Name': 'Pacific:MSE', 'Regex': 'Pacific: MSE: ([0-9\\\\.]+)'},\n",
    "    {'Name': 'EastNorthCentral:MSE', 'Regex': 'EastNorthCentral: MSE: ([0-9\\\\.]+)'},\n",
    "    {'Name': 'NewEngland:MSE', 'Regex': 'NewEngland: MSE: ([0-9\\\\.]+)'},\n",
    "    {'Name': 'NewYork:MSE', 'Regex': 'NewYork:MSE: ([0-9\\\\.]+)'},\n",
    "    {'Name': 'Alabama:MSE', 'Regex': 'Alabama:MSE: ([0-9\\\\.]+)'},\n",
    "    {'Name': 'Alaska:MSE', 'Regex': 'Alaska:MSE: ([0-9\\\\.]+)'},\n",
    "    {'Name': 'Kentucky:MSE', 'Regex': 'Kentucky:MSE: ([0-9\\\\.]+)'},\n",
    "    {'Name': 'Illinois:MSE', 'Regex': 'Illinois:MSE: ([0-9\\\\.]+)'},\n",
    "    {'Name': 'Mississippi:MSE', 'Regex': 'Mississippi:MSE: ([0-9\\\\.]+)'},\n",
    "    {'Name': 'Hawaii:MSE', 'Regex': 'Hawaii:MSE: ([0-9\\\\.]+)'},\n",
    "    {'Name': 'Indiana:MSE', 'Regex': 'Indiana:MSE: ([0-9\\\\.]+)'},\n",
    "    {'Name': 'NewJersey:MSE', 'Regex': 'NewJersey:MSE: ([0-9\\\\.]+)'},\n",
    "    {'Name': 'Pennsylvania:MSE', 'Regex': 'Pennsylvania:MSE: ([0-9\\\\.]+)'},\n",
    "    {'Name': 'Tennessee:MSE', 'Regex': 'Tennessee:MSE: ([0-9\\\\.]+)'},\n",
    "    {'Name': 'California:MSE', 'Regex': 'California:MSE: ([0-9\\\\.]+)'},\n",
    "    {'Name': 'RhodeIsland:MSE', 'Regex': 'RhodeIsland:MSE: ([0-9\\\\.]+)'},\n",
    "    {'Name': 'Oregon:MSE', 'Regex': 'Oregon:MSE: ([0-9\\\\.]+)'},\n",
    "    {'Name': 'Connecticut:MSE', 'Regex': 'Connecticut:MSE: ([0-9\\\\.]+)'},\n",
    "    {'Name': 'Maine:MSE', 'Regex': 'Maine:MSE: ([0-9\\\\.]+)'},\n",
    "    {'Name': 'Ohio:MSE', 'Regex': 'Ohio:MSE: ([0-9\\\\.]+)'},\n",
    "    {'Name': 'Vermont:MSE', 'Regex': 'Vermont:MSE: ([0-9\\\\.]+)'},\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create the SKLearn Estimator  <a name=section5></a>\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import sagemaker\n",
    "from sagemaker.sklearn import SKLearn\n",
    "\n",
    "for idx, row in df_hps_combo.iterrows():\n",
    "    trial = Trial.create(\n",
    "        experiment_name=experiment_name,\n",
    "        sagemaker_boto_client=sagemaker_boto_client\n",
    "    )\n",
    "\n",
    "    experiment_config = { \"ExperimentName\": experiment_name, \n",
    "                      \"TrialName\":  trial.trial_name,\n",
    "                      \"TrialComponentDisplayName\": \"Training\"}\n",
    "    \n",
    "\n",
    "    sklearn_estimator = SKLearn('train.py',\n",
    "                                source_dir='code',\n",
    "                                instance_type='ml.m4.xlarge',\n",
    "                                framework_version='0.23-1',\n",
    "                                role=sagemaker.get_execution_role(),\n",
    "                                debugger_hook_config=False,\n",
    "                                hyperparameters = {'bucket': bucket,\n",
    "                                                   'algo': \"Prophet\", \n",
    "                                                   'daily_seasonality': True,\n",
    "                                                   'changepoint_prior_scale': 0.5,\n",
    "                                                   'seasonality_mode': row['seasonality_mode'],\n",
    "                                                   'revision_method' : row['revision_method']\n",
    "                                                  },\n",
    "                                metric_definitions = metric_definitions,\n",
    "                               )\n",
    "    sklearn_estimator.fit({'train': s3_train_channel, \"test\": s3_test_channel},\n",
    "                     experiment_config=experiment_config, wait=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once the experiment is finished we can determine how many seconds it ran. First we define a helper function to compute the billabale seconds and how many training jobs were auto-terminated."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate metrics and select a winning candidate <a name=section8></a>\n",
    "Amazon SageMaker Studio provides an experiments browser that you can use to view lists of experiments, trials, and trial components. You can choose one of these entities to view detailed information about the entity or choose multiple entities for comparison. For more details please refer to [the documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/experiments-view-compare.html#experiments-view). Once the training jobs are running we can use the experiment view in Studio (see screenshot below) or the `ExperimentAnalytics` module to track the status of our training jobs and their metrics. \n",
    "![](screenshot.png)\n",
    "\n",
    "\n",
    "In the training script we used SageMaker Debugger's function `save_scalar` to store metrics such as MAPE, MSE, RMSE in the experiment. We can access the recorded metrics via the ExperimentAnalytics function and convert it to a Pandas dataframe.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_job_statistics(df):\n",
    "    total_cost  = 0\n",
    "    stopped = 0\n",
    "    for name in df['sagemaker_job_name']:\n",
    "        description = sagemaker_boto_client.describe_training_job(TrainingJobName=name[1:-1])\n",
    "        total_cost += description['BillableTimeInSeconds']\n",
    "        if description['TrainingJobStatus'] == \"Stopped\":\n",
    "            stopped += 1\n",
    "    return stopped, total_cost"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.analytics import ExperimentAnalytics\n",
    "trial_component_analytics = ExperimentAnalytics(experiment_name=experiment_name)\n",
    "\n",
    "stopped, total_cost = compute_job_statistics(trial_component_analytics.dataframe())\n",
    "print(\"Billable seconds for overall experiment:\", total_cost, \"seconds. Number of training jobs auto-terminated:\", stopped)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This setup is especially useful if you run a parameter sweep with training jobs that train for hours. In our case each job only trained for less than 10 minutes. Until the Debugger data is uploaded, fetched and downloaded into the processing job, a few minutes may pass, so the potential cost reduction will be less for smaller training jobs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.analytics import ExperimentAnalytics\n",
    "\n",
    "trial_component_analytics = ExperimentAnalytics(experiment_name=experiment_name)\n",
    "tc_df = trial_component_analytics.dataframe()\n",
    "tc_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "total_mse = []\n",
    "model_url = []\n",
    "for name in tc_df['sagemaker_job_name']:\n",
    "        description = sagemaker_boto_client.describe_training_job(TrainingJobName=name[1:-1])\n",
    "        total_mse.append(description['FinalMetricDataList'][0]['Value'])\n",
    "        model_url.append(description['ModelArtifacts']['S3ModelArtifacts'])\n",
    "tc_df['total_mse'] = total_mse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "description['FinalMetricDataList'][0]['Value']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's take a look on the metrics and hyperparameter combinations:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "new_df = tc_df[['sagemaker_job_name','algo', 'changepoint_prior_scale', 'revision_method', 'total_mse', 'seasonality_mode']]\n",
    "new_df  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mse_min = new_df['total_mse'].min()\n",
    "mse_min"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's select the winner model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_winner = new_df[new_df['total_mse'] == mse_min]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Download the winning model for running forecasts."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for name in df_winner['sagemaker_job_name']:\n",
    "    model_dir = sagemaker_boto_client.describe_training_job(TrainingJobName = name[1:-1])['ModelArtifacts']['S3ModelArtifacts']\n",
    "model_dir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_dir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "key = model_dir.split('s3://{}/'.format(bucket))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "s3_client.download_file(bucket, key[1], 'model.tar.gz')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!tar -xvzf model.tar.gz"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run time series forecasts\n",
    "First, let's load the model. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import joblib\n",
    "def model_fn(model_dir):\n",
    "    clf = joblib.load(model_dir)\n",
    "    return clf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = model_fn('model.joblib')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, let's make forecasts 90 days in future."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "predictions = model.predict(steps_ahead=90)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's visualize the model results and fitted values for all states."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Helper function for plotting results.\n",
    "import matplotlib\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import plotly.graph_objects as go\n",
    "def plot_results(cols, axes, preds):\n",
    "    axes = np.hstack(axes)\n",
    "    for ax, col in zip(axes, cols):\n",
    "        preds[col].plot(ax=ax, label=\"Predicted\")\n",
    "        train_product_bottom_level[col].plot(ax=ax, label=\"Observed\")\n",
    "        ax.legend()\n",
    "        ax.set_title(col)\n",
    "        ax.set_xlabel(\"Date\")\n",
    "        ax.set_ylabel(\"Quantity\")    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Visualization at Region Level"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(len(regions), 1, figsize=(10, 20))\n",
    "plot_results(regions, axes, predictions)\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Visualization at State Level"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(len(region_states), 1, figsize=(20, 70))\n",
    "plot_results(region_states, axes, predictions)\n",
    "plt.tight_layout()"
   ]
  }
 ],
 "metadata": {
  "instance_type": "ml.t3.medium",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}