{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Visualize AWS Samples Mask-RCNN Detection Results\n",
    "\n",
    "This notebook visualizes detection results predicted by a trained [AWS Samples Mask-RCNN](https://github.com/aws-samples/mask-rcnn-tensorflow) model. \n",
    "\n",
    "## Load trained model\n",
    "\n",
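    "First, we add the Mask-RCNN source directory to the Python system path so that its modules can be imported. The same cell also sets the `TENSORPACK_FP16` environment variable."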
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "os.environ['TENSORPACK_FP16'] = 'true'\n",
    "sys.path.append('/mask-rcnn-tensorflow/MaskRCNN')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we initialize the ResNet FPN Mask RCNN model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from model.generalized_rcnn import ResNetFPNModel\n",
    "\n",
    "# create a mask r-cnn model\n",
    "mask_rcnn_model = ResNetFPNModel(True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we specify the `model_dir` below and locate the most recent trained model checkpoint in it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "\n",
    "# find the latest trained model checkpoint (highest step number)\n",
    "model_dir = f\"{os.environ['LOGDIR']}/train_log/maskrcnn/\"\n",
    "\n",
    "print(f\"Using model directory: {model_dir}\")\n",
    "model_search_path = os.path.join(model_dir, \"model-*.index\" )\n",
    "model_files = glob.glob(model_search_path)\n",
    "\n",
    "def sort_key(path):\n",
    "    index = path.rindex(\"model-\")\n",
    "    key = int(path[index+6:-6])\n",
    "    return key\n",
    "\n",
    "try:\n",
    "    model_files = sorted(model_files, key = sort_key)\n",
    "    latest_trained_model = model_files[-1]\n",
    "\n",
    "    trained_model = latest_trained_model[:-6]\n",
    "    print(f'Using model: {trained_model}')\n",
    "except (IndexError, ValueError):\n",
    "    print(f\"No model found in: {model_dir}\")"
   ]
  },
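  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick check of the checkpoint ordering, the sketch below applies the same step-number parsing to a few made-up checkpoint paths. The paths and the helper name `checkpoint_step` are hypothetical; they only illustrate why a numeric sort key is used instead of a plain string sort."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical checkpoint paths, for illustration only\n",
    "example_files = [\n",
    "    \"/logs/model-9000.index\",\n",
    "    \"/logs/model-120000.index\",\n",
    "    \"/logs/model-30000.index\",\n",
    "]\n",
    "\n",
    "def checkpoint_step(path):\n",
    "    # same parsing as sort_key: the digits between \"model-\" and \".index\"\n",
    "    start = path.rindex(\"model-\") + len(\"model-\")\n",
    "    return int(path[start:-len(\".index\")])\n",
    "\n",
    "# a plain string sort compares character by character, so it ranks model-9000 last;\n",
    "# the numeric key ranks the checkpoint with the highest step last\n",
    "print(sorted(example_files)[-1])\n",
    "print(sorted(example_files, key=checkpoint_step)[-1])"
   ]
  },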
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we initialize the model configuration to match the configuration we used for training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataset import DetectionDataset\n",
    "from config import finalize_configs, config as cfg\n",
    "\n",
    "# setup config\n",
    "cfg.MODE_FPN = True\n",
    "cfg.MODE_MASK = True\n",
    "cfg.DATA.BASEDIR = '/efs/data/'\n",
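    "# instantiated only for its side effect, presumably registering dataset metadata (such as class names) in cfg\n",
    "DetectionDataset()\n",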
    "\n",
    "finalize_configs(is_training=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create Predictor\n",
    "\n",
    "Next, we create a predictor that uses our trained model to make predictions on test images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorpack.predict.base import OfflinePredictor\n",
    "from tensorpack.tfutils.sessinit import get_model_loader\n",
    "from tensorpack.predict.config import PredictConfig\n",
    "\n",
    "# Create an inference predictor\n",
    "predictor = OfflinePredictor(PredictConfig(\n",
    "        model=mask_rcnn_model,\n",
    "        session_init=get_model_loader(trained_model),\n",
    "        input_names=['images', 'orig_image_dims'],\n",
    "        output_names=[\n",
    "            'generate_{}_proposals_topk_per_image/boxes'.format('fpn' if cfg.MODE_FPN else 'rpn'),\n",
    "            'generate_{}_proposals_topk_per_image/scores'.format('fpn' if cfg.MODE_FPN else 'rpn'),\n",
    "            'fastrcnn_all_scores',\n",
    "            'output/boxes',\n",
    "            'output/scores',\n",
    "            'output/labels',\n",
    "            'output/masks'\n",
    "        ]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Download COCO 2017 Test dataset\n",
    "Below, we download the [COCO 2017 Test dataset](http://cocodataset.org/#download) and extract the archive to `/tmp`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!wget -O /tmp/test2017.zip http://images.cocodataset.org/zips/test2017.zip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!unzip -q -d /tmp/ /tmp/test2017.zip\n",
    "!rm /tmp/test2017.zip"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define visualization helper functions\n",
    "\n",
    "Next we define helper functions to visualize the results.\n",
    "\n",
    "The function `get_mask` resizes the detection mask to the size of the object bounding box and applies the mask to the image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cv2\n",
    "import numpy as np\n",
    "from tensorpack.utils.palette import PALETTE_RGB\n",
    "\n",
    "def get_mask(img, box, mask, threshold=.5):\n",
    "    box = box.astype(int)\n",
    "    color = PALETTE_RGB[np.random.choice(len(PALETTE_RGB))][::-1]\n",
    "    a_mask = np.stack([(cv2.resize(mask, (box[2]-box[0], box[3]-box[1])) > threshold).astype(np.int8)]*3, axis=2)\n",
    "    sub_image = img[box[1]:box[3],box[0]:box[2],:].astype(np.uint8)\n",
    "    sub_image = np.where(a_mask==1, sub_image * (1 - 0.5) + color * 0.5, sub_image)\n",
    "    new_image = img.copy()\n",
    "    new_image[box[1]:box[3],box[0]:box[2],:] = sub_image\n",
    "    return new_image\n"
   ]
  },
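  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The blending step inside `get_mask` can be seen in isolation with a small NumPy-only sketch. The 4x4 patch, the binary mask, and the overlay color below are made-up values chosen for illustration; the real function obtains the binary mask by resizing and thresholding the predicted mask."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Made-up 4x4 gray patch and a binary mask covering its center\n",
    "patch = np.full((4, 4, 3), 100, dtype=np.uint8)\n",
    "binary_mask = np.zeros((4, 4), dtype=np.int8)\n",
    "binary_mask[1:3, 1:3] = 1\n",
    "color = np.array([200, 0, 0])  # hypothetical overlay color\n",
    "\n",
    "# repeat the mask across the 3 color channels, then blend 50/50 where the mask is set\n",
    "mask3 = np.stack([binary_mask] * 3, axis=2)\n",
    "blended = np.where(mask3 == 1, patch * 0.5 + color * 0.5, patch)\n",
    "\n",
    "print(blended[0, 0])  # pixel outside the mask keeps the patch value\n",
    "print(blended[1, 1])  # pixel inside the mask is the average of patch and color"
   ]
  },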
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The function `show_detection_results` applies the masks and bounding boxes to the image and visualizes the detection results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from matplotlib import patches\n",
    "\n",
    "def show_detection_results(img, boxes, scores, labels, masks, score_threshold=0.7, mask_threshold=0.5):\n",
    "    fig, ax = plt.subplots(figsize=(img.shape[1]//50, img.shape[0]//50))\n",
    "    \n",
    "    for bbox, score, label, mask in zip(boxes, scores, labels, masks):\n",
    "        if score >= score_threshold:\n",
    "            img = get_mask(img, bbox, mask, mask_threshold)\n",
    "            \n",
    "            # Show bounding box\n",
    "            x1, y1, x2, y2 = bbox\n",
    "            bbox_y = y1\n",
    "            bbox_x = x1\n",
    "            \n",
    "            bbox_w = (x2 - x1)\n",
    "            bbox_h = (y2 - y1)\n",
    "\n",
    "            color = PALETTE_RGB[np.random.choice(len(PALETTE_RGB))][::-1]/255\n",
    "            box_patch = patches.Rectangle((bbox_x, bbox_y), bbox_w, bbox_h, \n",
    "                        linewidth=1,\n",
    "                        alpha=0.7, linestyle=\"dashed\",\n",
    "                        edgecolor=color, facecolor='none')\n",
    "            ax.add_patch(box_patch)\n",
    "            class_name=cfg.DATA.CLASS_NAMES[label]\n",
    "            ax.text(bbox_x, bbox_y + 8, class_name,\n",
    "                color='w', size=11, backgroundcolor=\"none\")\n",
    "            \n",
    "    ax.imshow(img.astype(int))\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load test image\n",
    "\n",
    "Next, we pick a random image from the COCO 2017 Test dataset. You can come back to this cell whenever you want to load another test image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import random\n",
    "\n",
    "test2017_dir=os.path.join(\"/tmp\", \"test2017\")\n",
    "img_id=random.choice(os.listdir(test2017_dir))\n",
    "img_local_path = os.path.join(test2017_dir,img_id)\n",
    "print(img_local_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we load the random test image and convert its color channel order from BGR to RGB."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cv2\n",
    "\n",
    "img=cv2.imread(img_local_path, cv2.IMREAD_COLOR)\n",
    "print(img.shape)\n",
    "img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)"
   ]
  },
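  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For a 3-channel image, converting BGR to RGB amounts to reversing the order of the last axis. The NumPy-only sketch below shows this on a made-up 1x2 image; the array `bgr` is a hypothetical example and is not read from the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Made-up 1x2 image in BGR order: one pure-blue pixel, one pure-red pixel\n",
    "bgr = np.array([[[255, 0, 0], [0, 0, 255]]], dtype=np.uint8)\n",
    "\n",
    "# reversing the last axis swaps the B and R channels\n",
    "rgb = bgr[..., ::-1]\n",
    "\n",
    "print(rgb[0, 0])  # the blue pixel, now in RGB order\n",
    "print(rgb[0, 1])  # the red pixel, now in RGB order"
   ]
  },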
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we show the raw image we randomly loaded from the COCO 2017 Test dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig,ax = plt.subplots(figsize=(img.shape[1]//50, img.shape[0]//50))\n",
    "ax.imshow(img.astype(int))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Predict\n",
    "Next, we use the predictor to predict detection results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "_, _, all_scores, final_boxes, final_scores, final_labels, masks = predictor(np.expand_dims(img, axis=0),\n",
    "                                                            np.expand_dims(np.array(img.shape), axis=0))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualize \n",
    "Next, we visualize the detection results on our image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_detection_results(img, final_boxes, final_scores, final_labels, masks)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Go back to the <b>Load test image</b> cell if you want to test more images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}