{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cat and Dog dataset training notebook\n",
    "The purpose of this notebook is to show an end-to-end machine learning workflow:\n",
    "1. Use the labeled dataset created by SageMaker Ground Truth, then split it into training and validation sets.\n",
    "2. Train the model using a SageMaker training container.\n",
    "3. Deploy the model using a SageMaker endpoint.\n",
    "4. Lastly, perform model inference.\n",
    "\n",
    "This notebook is based on the lab in the repository [https://github.com/mahendrabairagi/GroundTruth_lab](https://github.com/mahendrabairagi/GroundTruth_lab)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The block below shows how to use the Ground Truth labeled dataset and split the data into training and validation sets.\n",
    "### IMPORTANT: Please change \"BUCKET=\" to your S3 bucket"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import numpy as np\n",
    "import boto3\n",
    "import sagemaker\n",
    "import time\n",
    "\n",
    "BUCKET = 'escience-workshop-{{FIXME}}'\n",
    "\n",
    "LABEL_OUTPUT_PREFIX = 'cat_dog_images_labeled' # Labeled images prefix from label job\n",
    "LABEL_JOB_NAME = 'groundtruth-labeling-job-cat-dog' # Label job name from previous notebook\n",
    "OUTPUT_MANIFEST = 's3://{0}/{1}/{2}/manifests/output/output.manifest'.format(BUCKET, LABEL_OUTPUT_PREFIX, LABEL_JOB_NAME)\n",
    "print(OUTPUT_MANIFEST)\n",
    "\n",
    "EXP_NAME = 'catanddog-smalldataset-ml-lab' # Any valid S3 prefix.\n",
    "TRAIN_FRACTION = 0.8 # Share of labeled records used for training; the rest is validation.\n",
    "SPLIT_SEED = 42 # Fixed seed so the train/validation split is the same on every run.\n",
    "\n",
    "role = sagemaker.get_execution_role()\n",
    "region = boto3.session.Session().region_name\n",
    "s3 = boto3.client('s3')\n",
    "bucket_region = s3.head_bucket(Bucket=BUCKET)['ResponseMetadata']['HTTPHeaders']['x-amz-bucket-region']\n",
    "assert bucket_region == region, \"Your S3 bucket {} and this notebook need to be in the same region.\".format(BUCKET)\n",
    "\n",
    "!aws s3 cp {OUTPUT_MANIFEST} 'output.manifest'\n",
    "\n",
    "# The manifest holds one JSON record per line; skip blank lines so that a\n",
    "# trailing newline cannot make json.loads fail.\n",
    "with open('output.manifest', 'r') as f:\n",
    "    output = [json.loads(line) for line in f if line.strip()]\n",
    "assert output, 'output.manifest contains no labeled records'\n",
    "\n",
    "# Shuffle output in place with a private, seeded generator so the split is\n",
    "# reproducible and the global NumPy random state is left untouched.\n",
    "np.random.RandomState(SPLIT_SEED).shuffle(output)\n",
    "\n",
    "dataset_size = len(output)\n",
    "train_test_split_index = round(dataset_size * TRAIN_FRACTION)\n",
    "\n",
    "train_data = output[:train_test_split_index]\n",
    "validation_data = output[train_test_split_index:]\n",
    "\n",
    "# Used later as the num_training_samples hyperparameter.\n",
    "num_training_samples = len(train_data)\n",
    "\n",
    "# Write each split back out as a JSON-lines manifest.\n",
    "for manifest_path, records in (('mllab.train.manifest', train_data),\n",
    "                               ('mllab.validation.manifest', validation_data)):\n",
    "    with open(manifest_path, 'w') as f:\n",
    "        for record in records:\n",
    "            f.write(json.dumps(record))\n",
    "            f.write('\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Upload the training and validation datasets to the S3 bucket so that SageMaker training jobs can use them later\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copy both local manifests to s3://BUCKET/EXP_NAME/<manifest name>.\n",
    "for manifest_name in ('mllab.train.manifest', 'mllab.validation.manifest'):\n",
    "    s3.upload_file(manifest_name, BUCKET, '{}/{}'.format(EXP_NAME, manifest_name))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create SageMaker training job. \n",
    "\n",
    "Change the hyperparameters to suit your training needs. For this workshop we will set the `mini_batch_size` to 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create unique job name\n",
    "nn_job_name_prefix = 'groundtruth-augmented-manifest-demo'\n",
    "timestamp = time.strftime('-%Y-%m-%d-%H-%M-%S', time.gmtime())\n",
    "nn_job_name = nn_job_name_prefix + timestamp\n",
    "num_classes = 2\n",
    "training_image = sagemaker.amazon.amazon_estimator.get_image_uri(boto3.Session().region_name, 'image-classification', repo_version='latest')\n",
    "\n",
    "\n",
    "def manifest_channel(channel_name, manifest_file):\n",
    "    # Build one InputDataConfig entry for an augmented manifest stored under\n",
    "    # the experiment prefix. The train and validation channels differ only in\n",
    "    # their channel name and manifest file.\n",
    "    return {\n",
    "        \"ChannelName\": channel_name,\n",
    "        \"DataSource\": {\n",
    "            \"S3DataSource\": {\n",
    "                \"S3DataType\": \"AugmentedManifestFile\",\n",
    "                \"S3Uri\": 's3://{}/{}/{}'.format(BUCKET, EXP_NAME, manifest_file),\n",
    "                \"S3DataDistributionType\": \"FullyReplicated\",\n",
    "                \"AttributeNames\": [\"source-ref\",\"groundtruth-labeling-job-cat-dog\"]\n",
    "            }\n",
    "        },\n",
    "        \"ContentType\": \"application/x-recordio\",\n",
    "        \"RecordWrapperType\": \"RecordIO\",\n",
    "        \"CompressionType\": \"None\"\n",
    "    }\n",
    "\n",
    "\n",
    "# Every hyperparameter value is passed as a string.\n",
    "hyperparameters = {\n",
    "    \"epochs\": \"30\",\n",
    "    \"image_shape\": \"3,224,224\",\n",
    "    \"learning_rate\": \"0.01\",\n",
    "    \"lr_scheduler_step\": \"10,20\",\n",
    "    \"mini_batch_size\": \"1\",\n",
    "    \"num_classes\": str(num_classes),\n",
    "    \"num_layers\": \"18\",\n",
    "    \"num_training_samples\": str(num_training_samples),\n",
    "    \"resize\": \"224\",\n",
    "    \"use_pretrained_model\": \"1\"\n",
    "}\n",
    "\n",
    "# Full request body for create_training_job.\n",
    "training_params = {\n",
    "    \"TrainingJobName\": nn_job_name,\n",
    "    \"RoleArn\": role,\n",
    "    \"AlgorithmSpecification\": {\n",
    "        \"TrainingImage\": training_image,\n",
    "        \"TrainingInputMode\": \"Pipe\"\n",
    "    },\n",
    "    \"HyperParameters\": hyperparameters,\n",
    "    \"InputDataConfig\": [\n",
    "        manifest_channel(\"train\", 'mllab.train.manifest'),\n",
    "        manifest_channel(\"validation\", 'mllab.validation.manifest')\n",
    "    ],\n",
    "    \"OutputDataConfig\": {\n",
    "        \"S3OutputPath\": 's3://{}/{}/output/'.format(BUCKET, EXP_NAME)\n",
    "    },\n",
    "    \"ResourceConfig\": {\n",
    "        \"InstanceCount\": 1,\n",
    "        \"InstanceType\": \"ml.p3.2xlarge\",\n",
    "        \"VolumeSizeInGB\": 50\n",
    "    },\n",
    "    \"StoppingCondition\": {\n",
    "        \"MaxRuntimeInSeconds\": 86400\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### [Now we will create the SageMaker training job.](https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTrainingJob.html)\n",
    "\n",
    "We will create a SageMaker training job based on the configuration data above.\n",
    "\n",
    "[sagemaker_client.create_training_job](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_training_job) boto3 documentation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sagemaker_client = boto3.client('sagemaker')\n",
    "sagemaker_client.create_training_job(**training_params)\n",
    "\n",
    "# Poll every 30 seconds until the training job reaches a terminal state.\n",
    "print('Training job started')\n",
    "while True:\n",
    "    training_job_info = sagemaker_client.describe_training_job(TrainingJobName=nn_job_name)\n",
    "    status = training_job_info['TrainingJobStatus']\n",
    "    if status == 'Completed':\n",
    "        print(\"Training job ended with status: \" + status)\n",
    "        break\n",
    "    # 'Stopped' is terminal as well; without this branch a stopped job would\n",
    "    # keep the loop polling forever.\n",
    "    if status in ('Failed', 'Stopped'):\n",
    "        # Read the reason from the describe call made in this iteration and\n",
    "        # fall back to the status when no FailureReason is present.\n",
    "        message = training_job_info.get('FailureReason', status)\n",
    "        print('Training job did not complete: {}'.format(message))\n",
    "        raise Exception('Training job failed')\n",
    "    time.sleep(30)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### [Deploy the Model](https://docs.aws.amazon.com/sagemaker/latest/dg/how-it-works-hosting.html)\n",
    "Now that we've fully labeled our dataset and have a trained model, we want to use the model to perform inference.\n",
    "\n",
    "This section involves several steps:\n",
    "\n",
    "- Create Model - Create a model from the training output.\n",
    "- Host the model for realtime inference - Create an inference endpoint and perform realtime inference.\n",
    "\n",
    "[sagemaker_client.create_model](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_model) boto3 documentation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Suffix the model name with the current UTC time.\n",
    "timestamp = time.strftime('-%Y-%m-%d-%H-%M-%S', time.gmtime())\n",
    "model_name = '{}{}'.format('groundtruth-demo-mllab-cat-dog-model', timestamp)\n",
    "print(model_name)\n",
    "\n",
    "# Look up where the training job stored its model artifacts.\n",
    "info = sagemaker_client.describe_training_job(TrainingJobName=nn_job_name)\n",
    "model_data = info['ModelArtifacts']['S3ModelArtifacts']\n",
    "print(model_data)\n",
    "\n",
    "# The model container reuses the image that was used for training.\n",
    "primary_container = dict(Image=training_image, ModelDataUrl=model_data)\n",
    "\n",
    "create_model_response = sagemaker_client.create_model(\n",
    "    ModelName=model_name,\n",
    "    ExecutionRoleArn=role,\n",
    "    PrimaryContainer=primary_container,\n",
    ")\n",
    "\n",
    "print(create_model_response['ModelArn'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### [Realtime Inference Endpoint Configuration](https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateEndpointConfig.html)\n",
    "We now host the model with an endpoint and perform realtime inference.\n",
    "\n",
    "- Create endpoint configuration - Create a configuration defining an endpoint.\n",
    "- Create endpoint - Use the configuration to create an inference endpoint.\n",
    "- Perform inference - Perform inference on some input data using the endpoint.\n",
    "- Clean up - Delete the endpoint and model.\n",
    "\n",
    "[sagemaker_client.create_endpoint_config](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_endpoint_config) boto3 documentation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "timestamp = time.strftime('-%Y-%m-%d-%H-%M-%S', time.gmtime())\n",
    "job_name = 'mllab-ground-truth-demo-' + str(int(time.time()))\n",
    "endpoint_config_name = job_name + '-epc' + timestamp\n",
    "\n",
    "# A single variant named 'AllTraffic' that serves the model created above\n",
    "# from one ml.m4.xlarge instance.\n",
    "production_variant = {\n",
    "    'VariantName': 'AllTraffic',\n",
    "    'ModelName': model_name,\n",
    "    'InstanceType': 'ml.m4.xlarge',\n",
    "    'InitialInstanceCount': 1,\n",
    "}\n",
    "\n",
    "endpoint_config_response = sagemaker_client.create_endpoint_config(\n",
    "    EndpointConfigName=endpoint_config_name,\n",
    "    ProductionVariants=[production_variant],\n",
    ")\n",
    "\n",
    "print('Endpoint configuration name: {}'.format(endpoint_config_name))\n",
    "print('Endpoint configuration arn:  {}'.format(endpoint_config_response['EndpointConfigArn']))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### [Create Endpoint](https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateEndpoint.html)\n",
    "Lastly, the customer creates the endpoint that serves up the model, through specifying the name and configuration defined above. The end result is an endpoint that can be validated and incorporated into production applications. This takes about 10 minutes to complete.\n",
    "\n",
    "[sagemaker_client.create_endpoint](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_endpoint) boto3 documentation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "timestamp = time.strftime('-%Y-%m-%d-%H-%M-%S', time.gmtime())\n",
    "endpoint_name = job_name + '-ep' + timestamp\n",
    "print('Endpoint name: {}'.format(endpoint_name))\n",
    "\n",
    "# Create the endpoint from the configuration defined earlier.\n",
    "endpoint_params = {\n",
    "    'EndpointName': endpoint_name,\n",
    "    'EndpointConfigName': endpoint_config_name,\n",
    "}\n",
    "endpoint_response = sagemaker_client.create_endpoint(**endpoint_params)\n",
    "print('EndpointArn = {}'.format(endpoint_response['EndpointArn']))\n",
    "\n",
    "# Report the status right after the create call.\n",
    "response = sagemaker_client.describe_endpoint(EndpointName=endpoint_name)\n",
    "status = response['EndpointStatus']\n",
    "print('EndpointStatus = {}'.format(status))\n",
    "\n",
    "# Block on the 'endpoint_in_service' waiter.\n",
    "in_service_waiter = sagemaker_client.get_waiter('endpoint_in_service')\n",
    "in_service_waiter.wait(EndpointName=endpoint_name)\n",
    "\n",
    "# Re-read the status once the waiter returns and fail if it is not InService.\n",
    "endpoint_response = sagemaker_client.describe_endpoint(EndpointName=endpoint_name)\n",
    "status = endpoint_response['EndpointStatus']\n",
    "print('Endpoint creation ended with EndpointStatus = {}'.format(status))\n",
    "\n",
    "if status != 'InService':\n",
    "    raise Exception('Endpoint creation failed.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test endpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import Image\n",
    "\n",
    "# Test image sent to the endpoint in the next cell; point this at another\n",
    "# local file (for example 'image.jpg') to try a different picture.\n",
    "file_name = 'traindata_cat_dog_images_20/0.jpg'\n",
    "Image(file_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import numpy as np\n",
    "import boto3\n",
    "\n",
    "runtime = boto3.Session().client(service_name='runtime.sagemaker')\n",
    "\n",
    "# Send the raw image bytes as the request body.\n",
    "with open(file_name, 'rb') as image_file:\n",
    "    payload = bytearray(image_file.read())\n",
    "\n",
    "response = runtime.invoke_endpoint(\n",
    "    EndpointName=endpoint_name,\n",
    "    ContentType='application/x-image',\n",
    "    Body=payload,\n",
    ")\n",
    "\n",
    "# The response body is parsed as JSON.\n",
    "result = json.loads(response['Body'].read())\n",
    "\n",
    "# Pick the position of the largest value and map it to a class name.\n",
    "index = np.argmax(result)\n",
    "object_categories = ['cat', 'dog']\n",
    "print('Result: label - {}, probability - {}'.format(object_categories[index], result[index]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Delete the endpoint\n",
    "\n",
    "After you have finished with this example, remember to delete the endpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Delete the endpoint first, then the endpoint configuration and the model\n",
    "# created above so that none of them is left behind in the account.\n",
    "sagemaker_client.delete_endpoint(EndpointName=endpoint_name)\n",
    "sagemaker_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)\n",
    "sagemaker_client.delete_model(ModelName=model_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_mxnet_p36",
   "language": "python",
   "name": "conda_mxnet_p36"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}