{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train Detectron2 with SageMaker Training Jobs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install the pinned SageMaker SDK first, so that nothing below is imported\n",
    "# before pip has finished changing the environment. `%pip` (unlike `! pip`)\n",
    "# installs into the environment of the running kernel.\n",
    "%pip install -q sagemaker==2.15.0\n",
    "\n",
    "import os\n",
    "import re\n",
    "\n",
    "import boto3\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sagemaker import get_execution_role\n",
    "\n",
    "# IAM role that is later passed to the estimator\n",
    "role = get_execution_role()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sagemaker\n",
    "from time import gmtime, strftime\n",
    "\n",
    "sess = sagemaker.Session() # can use LocalSession() to run container locally\n",
    "\n",
    "# Bucket used for both the training input and the training output; replace the placeholder\n",
    "bucket = '<MY_BUCKET>'\n",
    "# Read the region off the session rather than hand-typing a placeholder that could disagree with it\n",
    "region = sess.boto_region_name\n",
    "# AWS account id of the credentials behind the session\n",
    "account = sess.boto_session.client('sts').get_caller_identity()['Account']\n",
    "\n",
    "# Note: upload your COCO data from the previous step into S3 at the `prefix_input` location below.\n",
    "# We recommend using `aws s3 sync`.\n",
    "# Where your COCO data resides in S3\n",
    "prefix_input = 'training/data'\n",
    "\n",
    "# Where you'd like your training output to be stored\n",
    "prefix_output = 'training/d2-output'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Configure Training Job"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define the algorithm metrics that SageMaker will scrape from the training logs, persist, and render in the training job console."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Metric definitions passed to the estimator: one {\"Name\", \"Regex\"} dict per metric.\n",
    "# Every regex has exactly one capture group, wrapped around the number to record.\n",
    "# The patterns are raw strings so that `\\s`, `\\W`, `\\S` and `\\.` reach the regex engine\n",
    "# unchanged instead of being (invalid) Python string escapes.\n",
    "metric_definitions = [\n",
    "    {\n",
    "        \"Name\": \"total_loss\",\n",
    "        \"Regex\": r\".*total_loss:\\s([0-9\\.]+)\\s*\"\n",
    "    },\n",
    "    {\n",
    "        \"Name\": \"loss_cls\",\n",
    "        \"Regex\": r\".*loss_cls:\\s([0-9\\.]+)\\s*\"\n",
    "    },\n",
    "    {\n",
    "        \"Name\": \"loss_box_reg\",\n",
    "        \"Regex\": r\".*loss_box_reg:\\s([0-9\\.]+)\\s*\"\n",
    "    },\n",
    "    {\n",
    "        \"Name\": \"loss_mask\",\n",
    "        \"Regex\": r\".*loss_mask:\\s([0-9\\.]+)\\s*\"\n",
    "    },\n",
    "    {\n",
    "        \"Name\": \"loss_rpn_cls\",\n",
    "        \"Regex\": r\".*loss_rpn_cls:\\s([0-9\\.]+)\\s*\"\n",
    "    },\n",
    "    {\n",
    "        \"Name\": \"loss_rpn_loc\",\n",
    "        \"Regex\": r\".*loss_rpn_loc:\\s([0-9\\.]+)\\s*\"\n",
    "    },\n",
    "    {\n",
    "        \"Name\": \"overall_training_speed\",\n",
    "        \"Regex\": r\".*Overall training speed:\\s([0-9\\.]+)\\s*\"\n",
    "    },\n",
    "    {\n",
    "        \"Name\": \"lr\",\n",
    "        \"Regex\": r\".*lr:\\s([0-9\\.]+)\\s*\"\n",
    "    },\n",
    "    {\n",
    "        \"Name\": \"iter\",\n",
    "        \"Regex\": r\".*iter:\\s([0-9\\.]+)\\s*\"\n",
    "    },\n",
    "    {\n",
    "        \"Name\": \"AP <OBJECT1>\", # Duplicate this dictionary object as much as you want! This will plot Average Precision for <OBJECT1> on your training job\n",
    "        \"Regex\": r\".*<OBJECT1>\\s*\\W\\S\\s([0-9\\.]+)\\s*\"\n",
    "    },\n",
    "    {\n",
    "        \"Name\": \"AP <OBJECT2>\", # Duplicate this dictionary object as much as you want! This will plot Average Precision for <OBJECT2> on your training job\n",
    "        \"Regex\": r\".*<OBJECT2>\\s*\\W\\S\\s([0-9\\.]+)\\s*\"\n",
    "    },\n",
    "    {\n",
    "        # NOTE(review): `[0-9\\.]+` stops at the first character that is not a digit or a dot.\n",
    "        # If the log prints eta as H:MM:SS, only the leading number is captured -- confirm that is intended.\n",
    "        \"Name\": \"Estimated Training Time Left\",\n",
    "        \"Regex\": r\".*eta:\\s([0-9\\.]+)\\s*\"\n",
    "    }\n",
    "]\n",
    "\n",
    "# Echo the S3 URI built from `bucket` and `prefix_input`\n",
    "print(f\"s3://{bucket}/{prefix_input}/\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Detectron2 config overrides, kept as (key, value) pairs and flattened into the alternating\n",
    "# KEY, VALUE token list that gets space-joined into the `opts` hyperparameter.\n",
    "d2_configs = [\n",
    "    token\n",
    "    for override in (\n",
    "        ('MODEL.ROI_HEADS.NUM_CLASSES', '2'),\n",
    "        ('SOLVER.REFERENCE_WORLD_SIZE', '8'),\n",
    "        ('SOLVER.MAX_ITER', '100'), # small experiment; comment this pair out to stop overriding MAX_ITER\n",
    "        ('MODEL.WEIGHTS', 'https://dl.fbaipublicfiles.com/detectron2/COCO-Detection/faster_rcnn_R_101_FPN_3x/137851257/model_final_f6e8b1.pkl'),\n",
    "        ('MODEL.BACKBONE.FREEZE_AT', '2'), # There are 5 stages in ResNet. The first is a convolution, and the following stages are each a group of residual blocks.\n",
    "    )\n",
    "    for token in override\n",
    "]\n",
    "\n",
    "# Other keys, currently not overridden (kept for reference):\n",
    "#     'INPUT.MIN_SIZE_TRAIN'\n",
    "#     'INPUT.MAX_SIZE_TRAIN'\n",
    "#     'INPUT.MIN_SIZE_TEST'\n",
    "#     'INPUT.MAX_SIZE_TEST'\n",
    "#     'INPUT.CROP.TYPE', 'relative_range',\n",
    "#     'INPUT.CROP.SIZE', '(0.9, 0.9)',\n",
    "#     INPUT.FORMAT -- TODO: verify whether this is needed (maybe only for segmentation)\n",
    "#     MODEL.ANCHOR_GENERATOR.NAME = \"DefaultAnchorGenerator\"\n",
    "#     MODEL.ANCHOR_GENERATOR.SIZES = [[32, 64, 128, 256, 512]]\n",
    "#     SOLVER.LR_SCHEDULER_NAME = \"WarmupMultiStepLR\"\n",
    "#     SOLVER.BASE_LR = 0.001\n",
    "#     SOLVER.MOMENTUM = 0.9\n",
    "#     SOLVER.WEIGHT_DECAY = 0.0001\n",
    "#     SOLVER.WEIGHT_DECAY_NORM = 0.0\n",
    "#     SOLVER.GAMMA = 0.1\n",
    "#     SOLVER.STEPS = (30000,)\n",
    "\n",
    "# Show the flattened overrides as one space-separated string\n",
    "\" \".join(d2_configs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from time import gmtime, strftime\n",
    "\n",
    "from sagemaker.pytorch import PyTorch\n",
    "\n",
    "# Hyperparameters handed to the training entry point (`train.py` in `source/`)\n",
    "hyperparameters = {\n",
    "    \"local-config-file\": \"faster_rcnn_R_101_FPN_3x.yaml\",\n",
    "    \"resume\": \"True\",\n",
    "    \"opts\": ' '.join(d2_configs), # https://detectron2.readthedocs.io/modules/config.html#config-references\n",
    "}\n",
    "\n",
    "d2 = PyTorch('train.py',\n",
    "             role=role,\n",
    "             max_run=3*24*60*60, # 3 days in seconds\n",
    "             source_dir='source',\n",
    "             framework_version='1.6.0',\n",
    "             py_version='py3',\n",
    "             instance_count=1,\n",
    "             instance_type='ml.p3.16xlarge',\n",
    "             volume_size=100,\n",
    "             output_path=\"s3://{}/{}\".format(bucket, prefix_output),\n",
    "             metric_definitions=metric_definitions,\n",
    "             hyperparameters=hyperparameters,\n",
    "             sagemaker_session=sess)\n",
    "\n",
    "# Training job names must be unique within an account and region, so a fixed name\n",
    "# makes every re-run of this cell fail. Suffix the name with a UTC timestamp.\n",
    "job_name = \"d2-model-\" + strftime(\"%Y-%m-%d-%H-%M-%S\", gmtime())\n",
    "\n",
    "d2.fit(f\"s3://{bucket}/{prefix_input}\",\n",
    "       job_name=job_name,\n",
    "       wait=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display what `describe()` reports for the estimator's latest training job\n",
    "d2.latest_training_job.describe()"
   ]
  }
 ],
 "metadata": {
  "instance_type": "ml.t3.medium",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}