{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model Parallel Training\n",
    "\n",
    "This notebook trains the ENet model across the eight GPUs of an `ml.p3.16xlarge` instance\n",
    "using [SageMaker's distributed model parallel](https://docs.aws.amazon.com/sagemaker/latest/dg/model-parallel.html) library.\n",
    "\n",
    "Model training requires a preprocessed dataset, which is created in a [separate notebook](preprocess-camvid.ipynb)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports and Paths\n",
    "\n",
    "The next cell imports modules from the [Amazon SageMaker Python SDK](https://sagemaker.readthedocs.io/en/stable/)\n",
    "that we need for training the model, sets up a SageMaker session,\n",
    "and then defines the S3 URIs for the preprocessed data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "%reload_ext dotenv\n",
    "%dotenv\n",
    "\n",
    "import sagemaker\n",
    "from sagemaker.tensorflow import TensorFlow\n",
    "from sagemaker.inputs import TrainingInput\n",
    "\n",
    "session = sagemaker.Session()\n",
    "bucket = session.default_bucket()\n",
    "role = sagemaker.get_execution_role()\n",
    "training_role = role\n",
    "\n",
    "prefix = 'enet-tensorflow-distributed'\n",
    "train_path = f's3://{bucket}/{prefix}/preprocessed-data/camvid/train/'\n",
    "train_labels_path = f's3://{bucket}/{prefix}/preprocessed-data/camvid/train_labels/'\n",
    "val_path = f's3://{bucket}/{prefix}/preprocessed-data/camvid/val/'\n",
    "val_labels_path = f's3://{bucket}/{prefix}/preprocessed-data/camvid/val_labels/'\n",
    "test_path = f's3://{bucket}/{prefix}/preprocessed-data/camvid/test/'\n",
    "test_labels_path = f's3://{bucket}/{prefix}/preprocessed-data/camvid/test_labels/'\n",
    "report_path = f's3://{bucket}/{prefix}/preprocessed-data/camvid/report/'\n",
    "preprocessing_report_path = f'{report_path}preprocessing_report.json'\n",
    "class_dict_path = f'{report_path}class_dict.json'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define Training Job\n",
    "\n",
    "Since the ENet model is implemented in TensorFlow, we train it on Amazon SageMaker with the [`TensorFlow` estimator](https://sagemaker.readthedocs.io/en/stable/frameworks/tensorflow/sagemaker.tensorflow.html)\n",
    "and a custom [training script](../scripts/train_model_parallel.py) (set via the `source_dir` and `entry_point` arguments).\n",
    "\n",
    "We also set the model's hyperparameters,\n",
    "as well as metric definitions, which are regular expressions that SageMaker uses to extract training metrics from the log output.\n",
    "\n",
    "For cost efficiency, we use [managed spot training](https://docs.aws.amazon.com/sagemaker/latest/dg/model-managed-spot-training.html) by setting `use_spot_instances=True` and providing `max_run` and `max_wait` (both in seconds; `max_wait` must not be smaller than `max_run`).\n",
    "\n",
    "For model parallel training, we pass a [`distribution`](https://sagemaker.readthedocs.io/en/stable/api/training/smd_model_parallel_general.html#configuration-parameters-for-distribution) argument that enables the model parallel library and configures the MPI processes it runs on."
   ]
  },
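  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the metric definitions more concrete, the following sketch applies a few regular expressions of the kind used in this notebook to hypothetical log lines.\n",
    "The log format is an assumption inferred from the patterns; the actual output of the training script may differ.\n",
    "The text matched by the parenthesized group is what is reported as the metric value."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "# Hypothetical log lines; the real training script's output may look different.\n",
    "example_log_lines = [\n",
    "    '# epoch = 3',\n",
    "    '# loss = 1.2345e-01',\n",
    "    '# val_loss = 0.2',\n",
    "]\n",
    "example_patterns = {\n",
    "    'Epoch': r'# epoch = (\\d+)',\n",
    "    'Loss': r'# loss = ([\\d.\\-\\+e]+)',\n",
    "    'Val Loss': r'# val_loss = ([\\d.\\-\\+e]+)',\n",
    "}\n",
    "example_log = '\\n'.join(example_log_lines)\n",
    "for example_name, example_pattern in example_patterns.items():\n",
    "    # Search the whole log text and keep the first capture group as the value.\n",
    "    example_match = re.search(example_pattern, example_log)\n",
    "    print(example_name, '->', example_match.group(1) if example_match else None)"
   ]
  },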
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Passed to the training script as command-line arguments\n",
    "hyperparameters = {\n",
    "    'dropout-rate1': 0.01,\n",
    "    'dropout-rate2': 0.1,\n",
    "    'batch-size': 4,\n",
    "    'learning-rate': 0.001,\n",
    "    'epochs': 25,\n",
    "}\n",
    "# Regular expressions for extracting metrics from the training log\n",
    "metric_definitions = [\n",
    "    {'Name': 'Epoch', 'Regex': r'# epoch = (\\d+)'},\n",
    "    {'Name': 'Loss', 'Regex': r'# loss = ([\\d.\\-\\+e]+)'},\n",
    "    {'Name': 'Val Loss', 'Regex': r'# val_loss = ([\\d.\\-\\+e]+)'},\n",
    "    {'Name': 'Mean IoU', 'Regex': r'# mean_iou = ([\\d.\\-\\+e]+)'},\n",
    "    {'Name': 'Val Mean IoU', 'Regex': r'# val_mean_iou = ([\\d.\\-\\+e]+)'},\n",
    "]\n",
    "# Settings for the SageMaker model parallel library\n",
    "smp_options = {\n",
    "    'enabled': True,\n",
    "    'parameters': {\n",
    "        'partitions': 8,  # split the model into 8 partitions, one per GPU\n",
    "        'microbatches': 2,  # split each batch into 2 microbatches for pipelining\n",
    "        'pipeline': 'interleaved',\n",
    "        'optimize': 'speed',\n",
    "        'auto_partition': True,  # let the library decide how to partition the model\n",
    "    },\n",
    "}\n",
    "# MPI settings: one process per GPU of the ml.p3.16xlarge instance\n",
    "mpi_options = {\n",
    "    'enabled': True,\n",
    "    'processes_per_host': 8,\n",
    "    'custom_mpi_options': '-verbose -x NCCL_DEBUG=VERSION',\n",
    "}\n",
    "estimator = TensorFlow(\n",
    "    base_job_name='enet-tf-mp-train',\n",
    "    py_version='py38',\n",
    "    framework_version='2.6.3',\n",
    "    model_dir='/opt/ml/model',\n",
    "    checkpoint_local_path='/opt/ml/checkpoints',\n",
    "    entry_point='scripts/train_model_parallel.py',\n",
    "    source_dir='../',\n",
    "    hyperparameters=hyperparameters,\n",
    "    metric_definitions=metric_definitions,\n",
    "    role=training_role,\n",
    "    sagemaker_session=session,\n",
    "    instance_count=1,\n",
    "    instance_type='ml.p3.16xlarge',\n",
    "    distribution={\n",
    "        'smdistributed': {'modelparallel': smp_options},\n",
    "        'mpi': mpi_options,\n",
    "    },\n",
    "    use_spot_instances=True,\n",
    "    max_run=10*3600,\n",
    "    max_wait=16*3600,\n",
    ")"
   ]
  },
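  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model parallel settings interact with the batch size and the number of MPI processes.\n",
    "As a rough sanity check, the sketch below tests two constraints that we assume here based on the library's documentation:\n",
    "the batch size should be divisible by the number of microbatches, and the number of partitions should not exceed the number of processes.\n",
    "The helper `check_model_parallel_config` is hypothetical and not part of the SageMaker SDK; the values from this notebook are copied in by hand."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def check_model_parallel_config(batch_size, microbatches, partitions, processes):\n",
    "    # Hypothetical helper: returns a list of problems found in the configuration.\n",
    "    problems = []\n",
    "    if batch_size % microbatches != 0:\n",
    "        problems.append(\n",
    "            f'batch size {batch_size} is not divisible by {microbatches} microbatches'\n",
    "        )\n",
    "    if partitions > processes:\n",
    "        problems.append(\n",
    "            f'{partitions} partitions exceed the {processes} available processes'\n",
    "        )\n",
    "    return problems\n",
    "\n",
    "\n",
    "# Values copied from the hyperparameters, smp_options and mpi_options used in this notebook.\n",
    "example_problems = check_model_parallel_config(\n",
    "    batch_size=4, microbatches=2, partitions=8, processes=8\n",
    ")\n",
    "print(example_problems or 'no problems found')"
   ]
  },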
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run Training Job\n",
    "\n",
    "We then start the training job by calling the TensorFlow estimator's `fit` method.\n",
    "As its argument, we pass the data [inputs](https://sagemaker.readthedocs.io/en/stable/api/utility/inputs.html), one channel per location of the preprocessed dataset in S3.\n",
    "Because we set `wait=False`, the call returns once the job has been started instead of blocking until training completes."
   ]
  },
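  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each key of the dictionary passed to `fit` becomes an input channel.\n",
    "Inside the training container, SageMaker makes a channel's data available under `/opt/ml/input/data/<channel>` and exposes that path through an `SM_CHANNEL_<NAME>` environment variable.\n",
    "The sketch below shows how a training script might resolve these directories.\n",
    "The helper `channel_dir` is hypothetical, and whether `train_model_parallel.py` reads its inputs this way is an assumption.\n",
    "When run in this notebook rather than in the container, it simply falls back to the default paths."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "\n",
    "def channel_dir(channel):\n",
    "    # Hypothetical helper: local directory of an input channel in the training container.\n",
    "    default = f'/opt/ml/input/data/{channel}'\n",
    "    return os.environ.get(f'SM_CHANNEL_{channel.upper()}', default)\n",
    "\n",
    "\n",
    "for example_channel in ['train', 'train_labels', 'val', 'val_labels']:\n",
    "    print(example_channel, '->', channel_dir(example_channel))"
   ]
  },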
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "estimator.fit({\n",
    "    'train': TrainingInput(train_path),\n",
    "    'train_labels': TrainingInput(train_labels_path),\n",
    "    'val': TrainingInput(val_path),\n",
    "    'val_labels': TrainingInput(val_labels_path),\n",
    "    'test': TrainingInput(test_path),\n",
    "    'test_labels': TrainingInput(test_labels_path),\n",
    "    'report': TrainingInput(report_path),\n",
    "}, wait=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "While the training job is running, we can stream its logs to the notebook to follow its progress."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "estimator.logs()"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "23e9edbf101ec1229ee15d5e8950818b02dd17cfc3730b3f1ee235cf2fb9b8d3"
  },
  "kernelspec": {
   "display_name": "Python 3.9.11 64-bit ('3.9.11')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.11"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}