{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Transfer Learning of YoloV3 with GluonCV\n", "## Introduction\n", "\n", "This is an end-to-end example of training a GluonCV YoloV3 model inside an Amazon SageMaker notebook using Script Mode and then compiling the trained model with SageMaker Neo. In this demo, we will demonstrate how to fine-tune a model using the autonomous driving dataset labeled with SageMaker Ground Truth. We will also cover how to convert a SageMaker Ground Truth manifest file into the RecordIO format used for model training. Finally, we will demonstrate how to optimize the trained model using SageMaker Neo and deploy the compiled model.\n", "\n", "***This notebook is for demonstration purposes only and does not create a state-of-the-art ML model. Please tune the training parameters based on your own dataset.***" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!~/anaconda3/envs/mxnet_p36/bin/pip install --upgrade sagemaker" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sagemaker\n", "\n", "current_version = sagemaker.__version__\n", "print(\"current SageMaker SDK Version: {}\".format(current_version))\n", "if current_version.split(\".\")[0] == \"1\":\n", " raise Exception(\n", " \"Please upgrade sagemaker SDK by running the above cell while ensuring kernel name is the same as the one being used. 
Restart the kernel after upgrade.\"\n", " )" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install jsonlines\n", "!pip install gluoncv" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "import os\n", "import json\n", "import imageio\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import boto3\n", "import random\n", "import jsonlines\n", "import mxnet\n", "import pandas as pd\n", "import PIL.Image\n", "from sagemaker.mxnet import MXNet\n", "from sagemaker.tuner import (\n", " HyperparameterTuner,\n", " IntegerParameter,\n", " ContinuousParameter,\n", " CategoricalParameter,\n", ")\n", "from sagemaker.inputs import TrainingInput\n", "from sagemaker.estimator import Estimator\n", "from sagemaker.serializers import JSONSerializer\n", "from sagemaker.deserializers import JSONDeserializer\n", "\n", "role = sagemaker.get_execution_role()\n", "sess = sagemaker.Session()\n", "BUCKET = sess.default_bucket()\n", "region = boto3.session.Session().region_name" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# exist_ok=True lets this cell be re-run without raising FileExistsError\n", "os.makedirs('./data/', exist_ok=True)\n", "os.makedirs('./data/train', exist_ok=True)\n", "os.makedirs('./data/val', exist_ok=True)\n", "os.makedirs('./data/test', exist_ok=True)\n", "os.makedirs('./img/', exist_ok=True)\n", "os.makedirs('./img/labelled-images', exist_ok=True)\n", "os.makedirs('./img/raw-images', exist_ok=True)\n", "os.makedirs('./img/test-images', exist_ok=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We have provided a sample Ground Truth manifest file in this repository under `ground_truth/output.manifest` so that you can run the notebook.\n", "If you have run your own Ground Truth bounding box labeling job and plan to use its output in this notebook, insert the name of that job below." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "job_name = \"av-labeling\"" ] }, { "cell_type": "markdown", "metadata": {}, 
"source": [ "The manifest file of the Ground Truth job contains a `source-ref` field with the S3 path where each input image is stored. We therefore need to update the S3 path with the bucket in your account that was created by the CDK stack:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "s3 = boto3.resource('s3')\n", "s3_bucket_raw_images = [bucket.name for bucket in s3.buckets.all() if bucket.name.startswith('my-vsi-rosbag-stack-destbucket')][0]\n", "my_bucket = s3.Bucket(s3_bucket_raw_images)\n", "s3_folder_raw_images = \"test-vehicle-01/072021/2020-11-19-22-21-36_1/\"\n", "print(f\"The raw images are expected in this bucket {s3_bucket_raw_images} with an s3 key {s3_folder_raw_images}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Show some example images to check annotation quality\n", "First, let's have a look at the annotated data from SageMaker Ground Truth. To do so, we load the output manifest file provided in this repository and replace the bucket name in the S3 keys with the bucket name defined above:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "with open(\"ground_truth/output.manifest\", \"r\") as f:\n", " output = [json.loads(line.replace('placeholder-bucket',s3_bucket_raw_images).strip()) for line in f.readlines()]\n", "print(\"there are {} images annotated\".format(len(output)))\n", "\n", "with open('ground_truth/output.manifest', 'w') as outfile:\n", " for item in output:\n", " outfile.write(\"%s\\n\" % json.dumps(item))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Afterwards we can map the annotations to their images and visualize sample annotations. We will also save the labelled images in the folder `img/labelled-images/` in case you wish to explore them further." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from ground_truth.ground_truth_od import BoundingBox, BoxedImage\n", "\n", "keys = list(output[0].keys())\n", "metakey = keys[np.where([(\"-metadata\" in k) for k in keys])[0][0]]\n", "output_images = []\n", "consolidated_boxes = []\n", "\n", "for datum_id, datum in enumerate(output):\n", " image_size = datum[job_name][\"image_size\"][0]\n", " box_annotations = datum[job_name][\"annotations\"]\n", " uri = datum[\"source-ref\"]\n", " box_confidences = datum[metakey][\"objects\"]\n", " human = int(datum[metakey][\"human-annotated\"] == \"yes\")\n", " image = BoxedImage(id=datum_id, size=image_size, uri=uri)\n", " boxes = []\n", " for i, annotation in enumerate(box_annotations):\n", " box = BoundingBox(image_id=datum_id, boxdata=annotation)\n", " box.confidence = box_confidences[i][\"confidence\"]\n", " box.image = image\n", " box.human = human\n", " boxes.append(box)\n", " consolidated_boxes.append(box)\n", " image.consolidated_boxes = boxes\n", " image.human = human\n", " output_images.append(image)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "n = 10 # specify the number of images that you would like to check\n", "rand_images = random.sample(output_images, n)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "for row, img in enumerate(rand_images):\n", " fig, axes = plt.subplots(1, 1, figsize=(10, 10), dpi=100)\n", " img.download(\"./img/raw-images/\")\n", " axes.set_title(img.uri, fontdict={\"fontsize\": 8})\n", " img.plot_consolidated_bbs(axes)\n", " fig.savefig(\n", " \"./img/labelled-images/\" + \"_\".join(img.uri.split(\"/\")[3::]),\n", " bbox_inches=\"tight\",\n", " )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Create RecordIO Format\n", "\n", "Since we are happy with the annotation quality, we can continue and prepare the dataset for 
training.\n", "The preferred object detection format for GluonCV is based on the LST file format, which can be converted to the [RecordIO format](https://mxnet.apache.org/versions/1.8.0/api/architecture/note_data_loading), which provides faster disk access and compact storage. GluonCV provides a [tutorial](https://cv.gluon.ai/build/examples_datasets/detection_custom.html) on how to convert an LST file into a RecordIO file, so we will not go into detail here." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "file = \"ground_truth/output.manifest\"\n", "train_lst = \"data/train/train.lst\"\n", "val_lst = \"data/val/val.lst\"\n", "test_lst = \"data/test/test.lst\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "with jsonlines.open(file, \"r\") as reader:\n", " lines = list(reader)\n", " # Shuffle data in place.\n", " np.random.shuffle(lines)\n", "\n", "train_data, validation_data, test_data = np.split(\n", " lines, [int(0.6 * len(lines)), int(0.8 * len(lines))]\n", ")\n", "\n", "with open(train_lst, \"w\") as f:\n", " for line in train_data:\n", " f.write(json.dumps(line))\n", " f.write(\"\\n\")\n", "\n", "with open(val_lst, \"w\") as f:\n", " for line in validation_data:\n", " f.write(json.dumps(line))\n", " f.write(\"\\n\")\n", "\n", "with open(test_lst, \"w\") as f:\n", " for line in test_data:\n", " f.write(json.dumps(line))\n", " f.write(\"\\n\")\n", "\n", "print(\n", " \"training samples: {}, validation samples: {}, test samples: {}\".format(\n", " len(train_data), len(validation_data), len(test_data)\n", " )\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def write_line(img_path, h, w, c, boxes, ids, idx):\n", " \"\"\"Build one LST line for an image; does not handle images without any annotations.\"\"\"\n", " # for header, we use minimal length 2, plus width and height\n", " # with A: 4, B: 5, C: width, D: height\n", " 
A = 4\n", " B = 5\n", " C = w\n", " D = h\n", " # concat id and bboxes\n", " labels = np.hstack((ids.reshape(-1, 1), boxes)).astype(\"float\")\n", " # normalized bboxes (recommended)\n", " labels[:, (1, 3)] /= float(w)\n", " labels[:, (2, 4)] /= float(h)\n", " # flatten\n", " labels = labels.flatten().tolist()\n", " str_idx = [str(idx)]\n", " str_header = [str(x) for x in [A, B, C, D]]\n", " str_labels = [str(x) for x in labels]\n", " str_path = [img_path]\n", " line = \"\\t\".join(str_idx + str_header + str_labels + str_path) + \"\\n\"\n", " return line" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "with open(train_lst, \"w\") as fw:\n", " for idx, line in enumerate(train_data):\n", " all_boxes = []\n", " all_ids = []\n", " h, w = (\n", " line[job_name][\"image_size\"][0][\"height\"],\n", " line[job_name][\"image_size\"][0][\"width\"],\n", " )\n", " for item in line[job_name][\"annotations\"]:\n", " all_boxes.append(\n", " [\n", " item[\"left\"],\n", " item[\"top\"],\n", " item[\"width\"] + item[\"left\"],\n", " item[\"top\"] + item[\"height\"],\n", " ]\n", " )\n", " all_ids.append((item[\"class_id\"]))\n", " line = write_line(\n", " line[\"source-ref\"],\n", " h,\n", " w,\n", " line[job_name][\"image_size\"][0][\"depth\"],\n", " np.array(all_boxes),\n", " np.array(all_ids),\n", " idx,\n", " )\n", " fw.write(line)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "with open(val_lst, \"w\") as fw:\n", " for idx, line in enumerate(validation_data):\n", " all_boxes = []\n", " all_ids = []\n", " h, w = (\n", " line[job_name][\"image_size\"][0][\"height\"],\n", " line[job_name][\"image_size\"][0][\"width\"],\n", " )\n", " for item in line[job_name][\"annotations\"]:\n", " all_boxes.append(\n", " [\n", " item[\"left\"],\n", " item[\"top\"],\n", " item[\"width\"] + item[\"left\"],\n", " item[\"top\"] + item[\"height\"],\n", " ]\n", " )\n", " 
all_ids.append((item[\"class_id\"]))\n", " line = write_line(\n", " line[\"source-ref\"],\n", " h,\n", " w,\n", " line[job_name][\"image_size\"][0][\"depth\"],\n", " np.array(all_boxes),\n", " np.array(all_ids),\n", " idx,\n", " )\n", " fw.write(line)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "with open(test_lst, \"w\") as fw:\n", " for idx, line in enumerate(test_data):\n", " all_boxes = []\n", " all_ids = []\n", " h, w = (\n", " line[job_name][\"image_size\"][0][\"height\"],\n", " line[job_name][\"image_size\"][0][\"width\"],\n", " )\n", " for item in line[job_name][\"annotations\"]:\n", " all_boxes.append(\n", " [\n", " item[\"left\"],\n", " item[\"top\"],\n", " item[\"width\"] + item[\"left\"],\n", " item[\"top\"] + item[\"height\"],\n", " ]\n", " )\n", " all_ids.append((item[\"class_id\"]))\n", " line = write_line(\n", " line[\"source-ref\"],\n", " h,\n", " w,\n", " line[job_name][\"image_size\"][0][\"depth\"],\n", " np.array(all_boxes),\n", " np.array(all_ids),\n", " idx,\n", " )\n", " fw.write(line)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "MXNet provides a tool called [im2rec.py](https://mxnet.apache.org/versions/1.8.0/api/faq/recordio) to convert LST files into the RecordIO file format. We slightly modified the script to read data from S3 instead of your local machine." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "os.system(\n", " \"python ./ground_truth/im2rec.py ./data/train {} --pack-label --resize 512\".format(\n", " job_name\n", " )\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "os.system(\n", " \"python ./ground_truth/im2rec.py ./data/val {} --pack-label --resize 512\".format(\n", " job_name\n", " )\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "os.system(\n", " \"python ./ground_truth/im2rec.py ./data/test {} --pack-label --resize 512\".format(\n", " job_name\n", " )\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the next step, we upload the data to S3 so that it can be used for training:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "os.system(\"aws s3 cp ./data s3://{}/ --recursive\".format(BUCKET))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# SageMaker Training\n", "In order to train a customized neural network, we will use [Amazon SageMaker Script Mode](https://github.com/aws-samples/amazon-sagemaker-script-mode), which allows us to bring our own training script while leveraging the simple interface of Amazon SageMaker.\n", "To do so, we need to specify the S3 paths where, for example, the output of the training job is stored." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "s3_output_path = \"s3://{}/model\".format(BUCKET)\n", "s3_code_path = \"s3://{}/customcode\".format(BUCKET)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we specify the S3 paths that point to the folders containing the .rec, .lst and .idx files:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "training = \"s3://{}/train\".format(BUCKET)\n", "validation = \"s3://{}/val\".format(BUCKET)\n", "test = \"s3://{}/test\".format(BUCKET)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "channels = {\n", " \"train\": TrainingInput(training),\n", " \"val\": TrainingInput(validation),\n", " \"test\": TrainingInput(test),\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you wish to log certain metrics, such as mean average precision (mAP) or losses, during the SageMaker training job, you need to pass a regex for each metric in the **metric_definitions** parameter when creating the Estimator:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "metric_definitions = [\n", " {\"Name\": \"val:car mAP\", \"Regex\": \"val:mAP=(.*?),\"},\n", " {\"Name\": \"test:car mAP\", \"Regex\": \"test:mAP=(.*?),\"},\n", " {\"Name\": \"BoxCenterLoss\", \"Regex\": \"BoxCenterLoss=(.*?),\"},\n", " {\"Name\": \"ObjLoss\", \"Regex\": \"ObjLoss=(.*?),\"},\n", " {\"Name\": \"BoxScaleLoss\", \"Regex\": \"BoxScaleLoss=(.*?),\"},\n", " {\"Name\": \"ClassLoss\", \"Regex\": \"ClassLoss=(.*?),\"},\n", "]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we create the Estimator, providing the `entry_point` script and specifying the instance type.
" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model_estimator = MXNet(\n", " entry_point=\"train_yolov3.py\",\n", " role=role,\n", " instance_count=1, # value can be more than 1 for multi node training\n", " instance_type=\"ml.p3.8xlarge\",\n", " framework_version=\"1.8.0\",\n", " metric_definitions=metric_definitions,\n", " output_path=s3_output_path,\n", " code_location=s3_code_path,\n", " py_version=\"py37\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Hyperparameter optimization \n", "When training neural networks, there are many hyperparameters that can be tuned to the use case and the custom dataset. This process is usually referred to as [automatic model tuning in SageMaker](https://aws.amazon.com/blogs/aws/sagemaker-automatic-model-tuning/) or hyperparameter optimization. Amazon SageMaker launches multiple training jobs, each with a unique combination of hyperparameters, and searches for the configuration that achieves the highest mean average precision (mAP) on our validation data.\n", "\n", "In the following cell we will specify the ranges of the hyperparameters we want to optimize, as well as the objective metric that is monitored when selecting the best hyperparameter configuration. You can select any metric that is logged in your training script and captured by `metric_definitions`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "hyperparameter_ranges = {\n", " \"lr\": ContinuousParameter(0.001, 0.1),\n", " \"wd\": ContinuousParameter(0.0001, 0.001),\n", " \"epochs\": IntegerParameter(5, 7),\n", "}\n", "objective_metric_name = \"val:car mAP\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "hpo_tuner = HyperparameterTuner(\n", " model_estimator,\n", " objective_metric_name,\n", " hyperparameter_ranges,\n", " metric_definitions,\n", " max_jobs=2, # maximum number of jobs that should be run\n", " max_parallel_jobs=2, # maximum number of jobs that will be executed in parallel\n", " base_tuning_job_name=\"yolov3\",\n", ")\n", "\n", "hpo_tuner.fit(channels)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Model Compilation with SageMaker Neo\n", "When training in the cloud we have no hard constraints on the model environment, but we do need to keep the production environment in mind when running inference with trained models: no powerful GPUs and limited storage are common challenges. [Amazon SageMaker Neo](https://aws.amazon.com/sagemaker/neo/) lets you train once and run anywhere in the cloud and at the edge, optimizing performance while reducing the memory footprint of your model. SageMaker Neo optimizes models for a specific hardware type to run up to twice as fast with no loss in accuracy.\n", "\n", "What does Neo do under the hood? Neo converts your model (e.g. a json file) into an efficient common format: it simplifies the computational graph of a machine learning model, for example by combining transformations and reducing their number. To compile the model with Neo, we need to provide information such as the input shape of the data, the framework and the target instance family." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "\n", "best_estimator = hpo_tuner.best_estimator()\n", "compiled_model = best_estimator.compile_model(\n", " target_instance_family=\"ml_c4\",\n", " role=role,\n", " input_shape={\"data\": [1, 3, 512, 512]},\n", " output_path=s3_output_path,\n", " framework=\"mxnet\",\n", " framework_version=\"1.8\",\n", " env={\"MMS_DEFAULT_RESPONSE_TIMEOUT\": \"500\"},\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Deploy endpoint\n", "Deploying a model requires a few additional hosting functions, such as `model_fn` and `transform_fn`, in the `train_yolov3.py` script. Bear in mind that the `target_instance_family` of your compilation job and the endpoint `instance_type` should match." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "\n", "predictor = compiled_model.deploy(\n", " initial_instance_count=1,\n", " instance_type=\"ml.c4.xlarge\",\n", " endpoint_name=\"YOLO-DEMO-endpoint\",\n", " deserializer=JSONDeserializer(),\n", " serializer=JSONSerializer(),\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Test predictions\n", "Once the model is deployed to an endpoint, we can test inference with some sample images. To do so, we will use the function below, which draws a bounding box around each prediction and displays the object name and the confidence score associated with the prediction." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def visualize_detection(img_file, dets, classes=(), thresh=0.6, data_shape=512):\n", " \"\"\"\n", " visualize detections in one image\n", " Parameters:\n", " ----------\n", " img_file : str\n", " path to the image file\n", " dets : list\n", " endpoint response [class_ids, scores, bboxes],\n", " each entry with a leading batch dimension\n", " classes : tuple or list of str\n", " class names\n", " thresh : float\n", " score threshold\n", " data_shape : integer\n", " size the image was resized to before inference\n", " \"\"\"\n", " import random\n", " import matplotlib.pyplot as plt\n", " import matplotlib.image as mpimg\n", " from matplotlib.patches import Rectangle\n", " img=mpimg.imread(img_file)\n", " plt.imshow(img)\n", " height = img.shape[0]\n", " width = img.shape[1]\n", " colors = dict()\n", " klasses = dets[0][0]\n", " scores = dets[1][0]\n", " bbox = dets[2][0]\n", " for i in range(len(klasses)):\n", " klass = klasses[i][0]\n", " score = scores[i][0]\n", " x0, y0, x1, y1 = bbox[i]\n", " if score < thresh:\n", " continue\n", " cls_id = int(klass)\n", " if cls_id not in colors:\n", " colors[cls_id] = (random.random(), random.random(), random.random())\n", " xmin = int(x0 * width / data_shape)\n", " ymin = int(y0 * height / data_shape)\n", " xmax = int(x1 * width / data_shape)\n", " ymax = int(y1 * height / data_shape)\n", " rect = Rectangle((xmin, ymin), xmax - xmin,\n", " ymax - ymin, fill=False,\n", " edgecolor=colors[cls_id],\n", " linewidth=3.5)\n", " plt.gca().add_patch(rect)\n", " class_name = str(cls_id)\n", " if classes and len(classes) > cls_id:\n", " class_name = classes[cls_id]\n", " plt.gca().text(xmin, ymin-2,\n", " '{:s} {:.3f}'.format(class_name, score),\n", " bbox=dict(facecolor=colors[cls_id], alpha=0.5),\n", " fontsize=12, color='white')\n", " plt.tight_layout(rect=[0, 0, 2, 2])\n", " plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now 
we can test the model on some example images. To do so, we download some images from our S3 bucket. Alternatively, you can upload your own images into the directory `./img/test-images/`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "os.system(\"aws s3 cp s3://{}/{}left0004.png ./img/test-images/left0004.png\".format(s3_bucket_raw_images, s3_folder_raw_images))\n", "os.system(\"aws s3 cp s3://{}/{}left0305.png ./img/test-images/left0305.png\".format(s3_bucket_raw_images, s3_folder_raw_images))\n", "os.system(\"aws s3 cp s3://{}/{}right0234.png ./img/test-images/right0234.png\".format(s3_bucket_raw_images, s3_folder_raw_images))\n", "os.system(\"aws s3 cp s3://{}/{}right0103.png ./img/test-images/right0103.png\".format(s3_bucket_raw_images, s3_folder_raw_images))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "object_categories = [\"car\"] # list of objects to detect\n", "threshold = 0.5 # confidence threshold for the predictions of your model\n", "directory = \"./img/test-images/\" # directory containing the images you want to use for test predictions" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "for file_name in os.listdir(directory):\n", " print(file_name)\n", " test_image = PIL.Image.open(os.path.join(directory, file_name))\n", " test_image = np.asarray(test_image.resize((512, 512)))\n", " endpoint_response = predictor.predict(test_image)\n", " visualize_detection(os.path.join(directory, file_name), endpoint_response, object_categories, threshold)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "conda_mxnet_p36", "language": "python", "name": "conda_mxnet_p36" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", 
"nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.13" } }, "nbformat": 4, "nbformat_minor": 4 }