{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Visualize AWS Samples Mask-RCNN Detection Results\n", "\n", "This notebook visualizes detection results predicted by a trained [AWS Samples Mask-RCNN](https://github.com/aws-samples/mask-rcnn-tensorflow) model.\n", "\n", "## Load trained model\n", "\n", "First, we add the Mask-RCNN source directory to the Python module search path." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sys\n", "import os\n", "\n", "sys.path.append('/mask-rcnn-tensorflow/MaskRCNN')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we initialize the ResNet FPN Mask RCNN model." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from model.generalized_rcnn import ResNetFPNModel\n", "\n", "# create a mask r-cnn model\n", "mask_rcnn_model = ResNetFPNModel()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we specify `model_dir` below and find the latest trained model checkpoint in it." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import glob\n", "\n", "# find the latest trained model checkpoint\n", "model_dir = f\"{os.environ['LOGDIR']}/train_log/maskrcnn/\"\n", "\n", "print(f\"Using model directory: {model_dir}\")\n", "model_search_path = os.path.join(model_dir, \"model-*.index\")\n", "model_files = glob.glob(model_search_path)\n", "\n", "def sort_key(path):\n", "    # extract the training step from '.../model-<step>.index'\n", "    index = path.rindex(\"model-\")\n", "    return int(path[index + len(\"model-\"):-len(\".index\")])\n", "\n", "\n", "# fail early: the predictor below cannot be created without a checkpoint\n", "if not model_files:\n", "    raise FileNotFoundError(f\"No model found in: {model_dir}\")\n", "\n", "# the checkpoint with the highest step number is the latest one\n", "latest_trained_model = sorted(model_files, key=sort_key)[-1]\n", "\n", "# strip the '.index' suffix to get the checkpoint prefix\n", "trained_model = latest_trained_model[:-len(\".index\")]\n", "print(f'Using model: {trained_model}')\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we initialize the model configuration to match the configuration we used for training." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from dataset import DetectionDataset\n", "from config import finalize_configs, config as cfg\n", "\n", "# setup config\n", "cfg.MODE_FPN = True\n", "cfg.MODE_MASK = True\n", "cfg.DATA.BASEDIR = '/efs/data/'\n", "DetectionDataset()\n", "\n", "finalize_configs(is_training=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create Predictor\n", "\n", "Next, we create a predictor that uses our trained model to make predictions on test images." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from tensorpack.predict.base import OfflinePredictor\n", "from tensorpack.tfutils.sessinit import get_model_loader\n", "from tensorpack.predict.config import PredictConfig\n", "\n", "# Create an inference predictor \n", "predictor = OfflinePredictor(PredictConfig(\n", " model=mask_rcnn_model,\n", " session_init=get_model_loader(trained_model),\n", " input_names=['images', 'orig_image_dims'],\n", " output_names=[\n", " 'fastrcnn_all_scores',\n", " 'output/boxes',\n", " 'output/scores',\n", " 'output/labels',\n", " 'output/masks'\n", " ]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Download COCO Test 2017 dataset\n", "Below we download the [COCO 2017 Test dataset](http://cocodataset.org/#download) and extract the downloaded dataset." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!wget -O /tmp/test2017.zip http://images.cocodataset.org/zips/test2017.zip" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!unzip -q -d /tmp/ /tmp/test2017.zip\n", "!rm /tmp/test2017.zip" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define visualization helper functions\n", "\n", "Next we define helper functions to visualize the results.\n", "\n", "The function `get_mask` resizes the detection mask to the size of the object bounding box and applies the mask to the image." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import cv2\n", "import numpy as np\n", "from tensorpack.utils.palette import PALETTE_RGB\n", "\n", "def get_mask(img, box, mask, threshold=.5):\n", "    box = box.astype(int)\n", "    # pick a random overlay color from the palette\n", "    color = PALETTE_RGB[np.random.choice(len(PALETTE_RGB))][::-1]\n", "    # resize the mask to the box size, binarize it, and repeat it for the 3 color channels\n", "    a_mask = np.stack([(cv2.resize(mask, (box[2]-box[0], box[3]-box[1])) > threshold).astype(np.int8)]*3, axis=2)\n", "    sub_image = img[box[1]:box[3],box[0]:box[2],:].astype(np.uint8)\n", "    # blend the color into the masked pixels at 50% opacity\n", "    sub_image = np.where(a_mask==1, sub_image * (1 - 0.5) + color * 0.5, sub_image)\n", "    new_image = img.copy()\n", "    new_image[box[1]:box[3],box[0]:box[2],:] = sub_image\n", "    return new_image\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The function `show_detection_results` applies the masks and bounding boxes to the image and visualizes the detection results." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "from matplotlib import patches\n", "\n", "def show_detection_results(img, boxes, scores, labels, masks, score_threshold=.7, mask_threshold=0.5):\n", "    fig, ax = plt.subplots(figsize=(img.shape[1]//50, img.shape[0]//50))\n", "\n", "    for bbox, score, label, mask in zip(boxes, scores, labels, masks):\n", "        # skip detections below the score threshold\n", "        if score >= score_threshold:\n", "            img = get_mask(img, bbox, mask, mask_threshold)\n", "\n", "            # Show bounding box\n", "            x1, y1, x2, y2 = bbox\n", "            bbox_y = y1\n", "            bbox_x = x1\n", "\n", "            bbox_w = (x2 - x1)\n", "            bbox_h = (y2 - y1)\n", "\n", "            color = PALETTE_RGB[np.random.choice(len(PALETTE_RGB))][::-1]/255\n", "            box_patch = patches.Rectangle((bbox_x, bbox_y), bbox_w, bbox_h,\n", "                                          linewidth=1,\n", "                                          alpha=0.7, linestyle=\"dashed\",\n", "                                          edgecolor=color, facecolor='none')\n", "            ax.add_patch(box_patch)\n", "            class_name = cfg.DATA.CLASS_NAMES[label]\n", "            ax.text(bbox_x, bbox_y + 8, class_name,\n", "                    color='w', size=11, backgroundcolor=\"none\")\n", "\n", "    ax.imshow(img.astype(int))\n", "    plt.show()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load test image\n", "\n", "Next, we pick a random test image from the COCO 2017 Test dataset. You can come back to this cell when you want to load the next test image." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "import random\n", "\n", "test2017_dir = os.path.join(\"/tmp\", \"test2017\")\n", "img_id = random.choice(os.listdir(test2017_dir))\n", "img_local_path = os.path.join(test2017_dir, img_id)\n", "print(img_local_path)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we load the random test image and convert the image color scheme from BGR to RGB." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import cv2\n", "\n", "img = cv2.imread(img_local_path, cv2.IMREAD_COLOR)\n", "print(img.shape)\n", "img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we show the raw image we randomly loaded from the COCO 2017 Test dataset." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fig, ax = plt.subplots(figsize=(img.shape[1]//50, img.shape[0]//50))\n", "ax.imshow(img.astype(int))\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Predict\n", "Next, we use the predictor to predict detection results." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "all_scores, final_boxes, final_scores, final_labels, masks = predictor(np.expand_dims(img, axis=0),\n", "                                                                      np.expand_dims(np.array(img.shape), axis=0))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Visualize\n", "Next, we visualize the detection results on our image." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "show_detection_results(img, final_boxes, final_scores, final_labels, masks)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Go back to the \"Load test image\" cell if you want to test more images." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" } }, "nbformat": 4, "nbformat_minor": 4 }