{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Workshop - Human in the Loop for SageMaker Models - Module 2\n",
    "\n",
    "In this module, you will learn how you can use **Amazon A2I with an Amazon SageMaker Hosted Endpoint.**.\n",
    "\n",
    "Amazon Augmented AI (Amazon A2I) makes it easy to build the workflows required for human review of ML predictions. Amazon A2I brings human review to all developers, removing the undifferentiated heavy lifting associated with building human review systems or managing large numbers of human reviewers. \n",
    "\n",
    "You can create your own workflows for ML models built on Amazon SageMaker or any other tools. Using Amazon A2I, you can allow human reviewers to step in when a model is unable to make a high confidence prediction or to audit its predictions on an on-going basis. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import io\n",
    "import uuid\n",
    "import time\n",
    "import random\n",
    "import numpy as np\n",
    "import json\n",
    "import boto3\n",
    "import botocore\n",
    "import sagemaker\n",
    "from sagemaker import get_execution_role\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.patches as patches    \n",
    "import matplotlib.image as mpimg"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We need to set up the following data:\n",
    "* `region` - Region to call A2I.\n",
    "* `BUCKET` - A S3 bucket accessible by the given role\n",
    "    * Used to store the sample images & output results\n",
    "    * Must be within the same region A2I is called from\n",
    "* `role` - The IAM role used as part of StartHumanLoop. By default, this notebook will use the execution role\n",
    "* `workteam` - Group of people to send the work to"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "region = 'us-east-1'\n",
    "role = get_execution_role()\n",
    "sess = sagemaker.Session()\n",
    "BUCKET = sess.default_bucket()\n",
    "OUTPUT_PATH = f's3://{BUCKET}/a2i-results'\n",
    "MODEL_PATH = f's3://{BUCKET}/model'\n",
    "WORKFORCE_ARN = '' # CREATE THE WORKFORCE BEFORE RUNNING THE NEXT CELLS\n",
    "endpoint_name = 'DEMO-ObjectDetection-endpoint-pretrained'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the Model and Create an Endpoint\n",
    "\n",
    "For workshop purpose we are going to deploy a pre-trained model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# download pre-trained model\n",
    "source_model_data_s3_uri = 's3://aws-sagemaker-augmented-ai-example/model/model.tar.gz'\n",
    "\n",
    "!aws s3 cp {source_model_data_s3_uri} {MODEL_PATH}/model.tar.gz\n",
    "\n",
    "model_data_s3_uri = f'{MODEL_PATH}/model.tar.gz'\n",
    "\n",
    "image = sagemaker.image_uris.retrieve('object-detection', region, version='1')\n",
    "\n",
    "model = sagemaker.model.Model(image, \n",
    "                              model_data = model_data_s3_uri,\n",
    "                              role = role,\n",
    "                              predictor_cls = sagemaker.predictor.Predictor,\n",
    "                              sagemaker_session = sess)\n",
    "\n",
    "object_detector = model.deploy(initial_instance_count = 1,\n",
    "                               instance_type = 'ml.m5.xlarge',\n",
    "                               endpoint_name = endpoint_name,\n",
    "                               serializer = sagemaker.serializers.IdentitySerializer('image/jpeg'))"
   ]
  },
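  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before invoking the endpoint, you can optionally confirm that it is in service. The following is an illustrative check (not part of the original workshop flow) using the standard boto3 `describe_endpoint` call:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check (illustrative): make sure the endpoint finished creating.\n",
    "sm = boto3.client('sagemaker', region)\n",
    "status = sm.describe_endpoint(EndpointName=endpoint_name)['EndpointStatus']\n",
    "print(f'Endpoint status: {status}')\n",
    "assert status == 'InService', 'Wait for the endpoint to be InService before proceeding.'"
   ]
  },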
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test Endpoint"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### object detection helper functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def visualize_detection(img_file, dets, classes=[], thresh=0.6):\n",
    "    \"\"\"\n",
    "    visualize detections in one image\n",
    "    Parameters:\n",
    "    ----------\n",
    "    img : numpy.array\n",
    "        image, in bgr format\n",
    "    dets : numpy.array\n",
    "        ssd detections, numpy.array([[id, score, x1, y1, x2, y2]...])\n",
    "        each row is one object\n",
    "    classes : tuple or list of str\n",
    "        class names\n",
    "    thresh : float\n",
    "        score threshold\n",
    "    \"\"\"\n",
    "    img=mpimg.imread(img_file)\n",
    "    f, ax = plt.subplots(1, 1)\n",
    "    ax.imshow(img)\n",
    "    height = img.shape[0]\n",
    "    width = img.shape[1]\n",
    "    colors = dict()\n",
    "    output = []\n",
    "    for det in dets:\n",
    "        (klass, score, x0, y0, x1, y1) = det\n",
    "        cls_id = int(klass)\n",
    "        class_name = str(cls_id)\n",
    "        if classes and len(classes) > cls_id:\n",
    "            class_name = classes[cls_id]\n",
    "        output.append([class_name, score])\n",
    "        if score < thresh:\n",
    "            continue\n",
    "        if cls_id not in colors:\n",
    "            colors[cls_id] = (random.random(), random.random(), random.random())\n",
    "        xmin = int(x0 * width)\n",
    "        ymin = int(y0 * height)\n",
    "        xmax = int(x1 * width)\n",
    "        ymax = int(y1 * height)\n",
    "        rect = patches.Rectangle((xmin, ymin), xmax - xmin,\n",
    "                                 ymax - ymin, fill=False,\n",
    "                                 edgecolor=colors[cls_id],\n",
    "                                 linewidth=3.5)\n",
    "        ax.add_patch(rect)\n",
    "\n",
    "\n",
    "        ax.text(xmin, ymin - 2,\n",
    "                '{:s} {:.3f}'.format(class_name, score),\n",
    "                bbox=dict(facecolor=colors[cls_id], alpha=0.5),\n",
    "                          fontsize=12, color='white')\n",
    "\n",
    "    return f, output\n",
    "    \n",
    "def load_and_predict(file_name, predictor, threshold=0.5):\n",
    "    \"\"\"\n",
    "    load an image, make object detection to an predictor, and visualize detections\n",
    "    Parameters:\n",
    "    ----------\n",
    "    file_name : str\n",
    "        image file location, in str format\n",
    "    predictor : sagemaker.predictor.RealTimePredictor\n",
    "        a predictor loaded from hosted endpoint\n",
    "    threshold : float\n",
    "        score threshold for bounding box display\n",
    "    \"\"\"\n",
    "    with open(file_name, 'rb') as image:\n",
    "        f = image.read()\n",
    "        b = bytearray(f)\n",
    "    results = predictor.predict(b)\n",
    "    detections = json.loads(results)\n",
    "    \n",
    "    fig, detection_filtered = visualize_detection(file_name, detections['prediction'], \n",
    "                                                   object_categories, threshold)\n",
    "    \n",
    "    return results, detection_filtered, fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "object_categories = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', \n",
    "                     'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', \n",
    "                     'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_photos_index = ['980382', '276517', '1571457']\n",
    "\n",
    "if not os.path.isdir('sample-a2i-images'):\n",
    "    os.mkdir('sample-a2i-images')\n",
    "    \n",
    "for ind in test_photos_index:\n",
    "    !curl https://images.pexels.com/photos/{ind}/pexels-photo-{ind}.jpeg > sample-a2i-images/pexels-photo-{ind}.jpeg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_photos = ['sample-a2i-images/pexels-photo-980382.jpeg', # motorcycle\n",
    "               'sample-a2i-images/pexels-photo-276517.jpeg', # bicycle\n",
    "               'sample-a2i-images/pexels-photo-1571457.jpeg'] # sofa"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results = load_and_predict(test_photos[2], object_detector, threshold=0.2)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dict_results = json.loads(results.decode('utf8'))['prediction']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "visualize_detection(test_photos[2], dict_results, object_categories, thresh=0.4)[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Probability of 0.465 is considered quite low in modern computer vision and there is a mislabeling. This is due to the fact that the SSD model was under-trained for demonstration purposes. However this under-trained model serves as a perfect example of brining human reviewers when a model is unable to make a high confidence prediction."
   ]
  },
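  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make \"low confidence\" concrete, here is a minimal sketch (not part of the original workshop flow) of the decision rule we will automate below, using the `dict_results` parsed above. Each prediction row has the form `[class_id, score, x0, y0, x1, y1]`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch: find the highest-scoring detection and check\n",
    "# whether it clears a 50% confidence bar.\n",
    "best = max(dict_results, key=lambda det: det[1])\n",
    "print(f'Top prediction: {object_categories[int(best[0])]} at {best[1]:.3f}')\n",
    "needs_human_review = best[1] < 0.50\n",
    "print('Needs human review:', needs_human_review)"
   ]
  },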
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Creating Human Review"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sagemaker_client = boto3.client('sagemaker', region)\n",
    "a2i = boto3.client('sagemaker-a2i-runtime')\n",
    "s3 = boto3.client('s3', region)\n",
    "\n",
    "# Flow definition name - this value is unique per account and region. You can also provide your own value here.\n",
    "flowDefinitionName = 'fd-sagemaker-object-detection-demo'\n",
    "# Task UI name - this value is unique per account and region. You can also provide your own value here.\n",
    "taskUIName = 'ui-sagemaker-object-detection-demo'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create Human Task UI\n",
    "\n",
    "Create a human task UI resource, giving a UI template in liquid html. This template will be rendered to the human workers whenever human loop is required.\n",
    "\n",
    "For pre built UIs, check: https://github.com/aws-samples/amazon-a2i-sample-task-uis.\n",
    "\n",
    "We will be taking an [object detection UI](https://github.com/aws-samples/amazon-a2i-sample-task-uis/blob/master/images/bounding-box.liquid.html) and filling in the object categories in the `labels` variable in the template."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "template = r\"\"\"\n",
    "<script src=\"https://assets.crowd.aws/crowd-html-elements.js\"></script>\n",
    "\n",
    "<crowd-form>\n",
    "  <crowd-bounding-box\n",
    "    name=\"annotatedResult\"\n",
    "    src=\"{{ task.input.taskObject | grant_read_access }}\"\n",
    "    header=\"Draw bounding boxes around all the objects in this image\"\n",
    "    labels=\"['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']\"\n",
    "  >\n",
    "    <full-instructions header=\"Bounding Box Instructions\" >\n",
    "      <p>Use the bounding box tool to draw boxes around the requested target of interest:</p>\n",
    "      <ol>\n",
    "        <li>Draw a rectangle using your mouse over each instance of the target.</li>\n",
    "        <li>Make sure the box does not cut into the target, leave a 2 - 3 pixel margin</li>\n",
    "        <li>\n",
    "          When targets are overlapping, draw a box around each object,\n",
    "          include all contiguous parts of the target in the box.\n",
    "          Do not include parts that are completely overlapped by another object.\n",
    "        </li>\n",
    "        <li>\n",
    "          Do not include parts of the target that cannot be seen,\n",
    "          even though you think you can interpolate the whole shape of the target.\n",
    "        </li>\n",
    "        <li>Avoid shadows, they're not considered as a part of the target.</li>\n",
    "        <li>If the target goes off the screen, label up to the edge of the image.</li>\n",
    "      </ol>\n",
    "    </full-instructions>\n",
    "\n",
    "    <short-instructions>\n",
    "      Draw boxes around the requested target of interest.\n",
    "    </short-instructions>\n",
    "  </crowd-bounding-box>\n",
    "</crowd-form>\n",
    "\"\"\"\n",
    "\n",
    "def create_task_ui():\n",
    "    '''\n",
    "    Creates a Human Task UI resource.\n",
    "\n",
    "    Returns:\n",
    "    struct: HumanTaskUiArn\n",
    "    '''\n",
    "    response = sagemaker_client.create_human_task_ui(\n",
    "        HumanTaskUiName=taskUIName,\n",
    "        UiTemplate={'Content': template})\n",
    "    return response"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create task UI\n",
    "humanTaskUiResponse = create_task_ui()\n",
    "humanTaskUiArn = humanTaskUiResponse['HumanTaskUiArn']\n",
    "print(humanTaskUiArn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create the Flow Definition"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this section, we're going to create a flow definition definition. Flow Definitions allow us to specify:\n",
    "\n",
    "* The workforce that your tasks will be sent to.\n",
    "* The instructions that your workforce will receive. This is called a worker task template.\n",
    "* The configuration of your worker tasks, including the number of workers that receive a task and time limits to complete tasks.\n",
    "* Where your output data will be stored.\n",
    "\n",
    "This demo is going to use the API, but you can optionally create this workflow definition in the console as well. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "create_workflow_definition_response = sagemaker_client.create_flow_definition(\n",
    "        FlowDefinitionName= flowDefinitionName,\n",
    "        RoleArn= role,\n",
    "        HumanLoopConfig= {\n",
    "            \"WorkteamArn\": WORKFORCE_ARN,\n",
    "            \"HumanTaskUiArn\": humanTaskUiArn,\n",
    "            \"TaskCount\": 1,\n",
    "            \"TaskDescription\": \"Identify and locate the object in an image.\",\n",
    "            \"TaskTitle\": \"Object detection a2i demo\"\n",
    "        },\n",
    "        OutputConfig={\n",
    "            \"S3OutputPath\" : OUTPUT_PATH\n",
    "        }\n",
    "    )\n",
    "flowDefinitionArn = create_workflow_definition_response['FlowDefinitionArn'] # let's save this ARN for future use"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Describe flow definition - status should be active\n",
    "for x in range(60):\n",
    "    describeFlowDefinitionResponse = sagemaker_client.describe_flow_definition(FlowDefinitionName=flowDefinitionName)\n",
    "    print(describeFlowDefinitionResponse['FlowDefinitionStatus'])\n",
    "    if (describeFlowDefinitionResponse['FlowDefinitionStatus'] == 'Active'):\n",
    "        print(\"Flow Definition is active\")\n",
    "        break\n",
    "    time.sleep(2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we have setup our Flow Definition, we are ready to call our object detection endpoint on SageMaker and start our human loops. In this tutorial, we are interested in starting a HumanLoop only if the highest prediction probability score returned by our model for objects detected is less than 50%. \n",
    "\n",
    "So, with a bit of logic, we can check the response for each call to the SageMaker endpoint using `load_and_predict` helper function, and if the highest score is less than 50%, we will kick off a HumanLoop to engage our workforce for a human review. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the sample images to s3 bucket for a2i UI to display\n",
    "!aws s3 sync ./sample-a2i-images/ s3://{BUCKET}/a2i-results/sample-a2i-images/\n",
    "    \n",
    "human_loops_started = []\n",
    "SCORE_THRESHOLD = .50\n",
    "for fname in test_photos:\n",
    "    # Call SageMaker endpoint and not display any object detected with probability lower than 0.4.\n",
    "    response, score_filtered, fig = load_and_predict(fname, object_detector, threshold=0.4)\n",
    "    \n",
    "    # Sort by prediction score so that the first item has the highest probability\n",
    "    score_filtered.sort(key=lambda x: x[1], reverse=True)\n",
    "\n",
    "    # Our condition for triggering a human review\n",
    "    if (score_filtered[0][1] < SCORE_THRESHOLD):\n",
    "        s3_fname='s3://%s/a2i-results/%s' % (BUCKET, fname)\n",
    "        print(s3_fname)\n",
    "        humanLoopName = str(uuid.uuid4())\n",
    "        inputContent = {\n",
    "            \"initialValue\": score_filtered[0][0],\n",
    "            \"taskObject\": s3_fname # the s3 object will be passed to the worker task UI to render\n",
    "        }\n",
    "        # start an a2i human review loop with an input\n",
    "        start_loop_response = a2i.start_human_loop(\n",
    "            HumanLoopName=humanLoopName,\n",
    "            FlowDefinitionArn=flowDefinitionArn,\n",
    "            HumanLoopInput={\n",
    "                \"InputContent\": json.dumps(inputContent)\n",
    "            }\n",
    "        )\n",
    "        human_loops_started.append(humanLoopName)\n",
    "        print(f'Object detection Confidence Score of %s is less than the threshold of %.2f' % (score_filtered[0][0], SCORE_THRESHOLD))\n",
    "        print(f'Starting human loop with name: {humanLoopName}  \\n')\n",
    "    else:\n",
    "        print(f'Object detection Confidence Score of %s is above than the threshold of %.2f' % (score_filtered[0][0], SCORE_THRESHOLD))\n",
    "        print('No human loop created. \\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Check Status of Human Loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "completed_human_loops = []\n",
    "for human_loop_name in human_loops_started:\n",
    "    resp = a2i.describe_human_loop(HumanLoopName=human_loop_name)\n",
    "    print(f'HumanLoop Name: {human_loop_name}')\n",
    "    print(f'HumanLoop Status: {resp[\"HumanLoopStatus\"]}')\n",
    "    print(f'HumanLoop Output Destination: {resp[\"HumanLoopOutput\"]}')\n",
    "    print('\\n')\n",
    "    \n",
    "    if resp[\"HumanLoopStatus\"] == \"Completed\":\n",
    "        completed_human_loops.append(resp)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since we are using private workteam, we should go to the labling UI to perform the inspection ourselves."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "workteamName = WORKFORCE_ARN[WORKFORCE_ARN.rfind('/') + 1:]\n",
    "print(\"Navigate to the private worker portal and do the tasks. Make sure you've invited yourself to your workteam!\")\n",
    "print('https://' + sagemaker_client.describe_workteam(WorkteamName=workteamName)['Workteam']['SubDomain'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Check Status of Human Loop Again"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "completed_human_loops = []\n",
    "for human_loop_name in human_loops_started:\n",
    "    resp = a2i.describe_human_loop(HumanLoopName=human_loop_name)\n",
    "    print(f'HumanLoop Name: {human_loop_name}')\n",
    "    print(f'HumanLoop Status: {resp[\"HumanLoopStatus\"]}')\n",
    "    print(f'HumanLoop Output Destination: {resp[\"HumanLoopOutput\"]}')\n",
    "    print('\\n')\n",
    "    \n",
    "    if resp[\"HumanLoopStatus\"] == \"Completed\":\n",
    "        completed_human_loops.append(resp)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### View Task Results  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once work is completed, Amazon A2I stores results in your S3 bucket and sends a Cloudwatch event. Your results should be available in the S3 OUTPUT_PATH when all work is completed. Note that the human answer, the label and the bounding box, is returned and saved in the json file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "import pprint\n",
    "\n",
    "pp = pprint.PrettyPrinter(indent=4)\n",
    "\n",
    "for resp in completed_human_loops:\n",
    "    splitted_string = re.split('s3://' +  BUCKET + '/', resp['HumanLoopOutput']['OutputS3Uri'])\n",
    "    output_bucket_key = splitted_string[1]\n",
    "\n",
    "    response = s3.get_object(Bucket=BUCKET, Key=output_bucket_key)\n",
    "    content = response[\"Body\"].read()\n",
    "    json_output = json.loads(content)\n",
    "    pp.pprint(json_output)\n",
    "    print('\\n')"
   ]
  },
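  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you want to use the human answers programmatically, the worker-drawn boxes live under `humanAnswers[*].answerContent.annotatedResult.boundingBoxes` in the output JSON (this follows the documented A2I output format for the `crowd-bounding-box` element; adjust the keys if your template differs). A minimal extraction sketch, reusing the `json_output` from the last completed loop above:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch: list each box a worker drew in the last parsed output.\n",
    "for answer in json_output['humanAnswers']:\n",
    "    for box in answer['answerContent']['annotatedResult']['boundingBoxes']:\n",
    "        print(box['label'], box['left'], box['top'], box['width'], box['height'])"
   ]
  },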
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Clean Up (finish module 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# object_detector.delete_endpoint()"
   ]
  }
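  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once you are fully done, you may also want to remove the A2I resources created in this module. Below is a sketch using the corresponding SageMaker delete APIs (`delete_flow_definition` and `delete_human_task_ui`), left commented out like the endpoint deletion above so you don't remove resources prematurely:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uncomment to delete the flow definition and task UI created in this module.\n",
    "# sagemaker_client.delete_flow_definition(FlowDefinitionName=flowDefinitionName)\n",
    "# sagemaker_client.delete_human_task_ui(HumanTaskUiName=taskUIName)"
   ]
  }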
 ],
 "metadata": {
  "instance_type": "ml.t3.medium",
  "kernelspec": {
   "display_name": "Python 3 (Data Science)",
   "language": "python",
   "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/datascience-1.0"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}