{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Model Optimization with an Image Classification Example\n", "1. [Introduction](#Introduction)\n", "2. [Prerequisites and Preprocessing](#Prequisites-and-Preprocessing)\n", "3. [Train the model](#Train-the-model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Introduction\n", "\n", "***\n", "\n", "Welcome to our model optimization example for image classification. In this demo, we will use the Amazon SageMaker Image Classification algorithm to train on the [caltech-256 dataset](http://www.vision.caltech.edu/Image_Datasets/Caltech256/)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Prequisites and Preprocessing\n", "\n", "***\n", "\n", "### Setup\n", "\n", "To get started, we need to define a few variables and obtain certain permissions that will be needed later in the example. These are:\n", "* A SageMaker session\n", "* IAM role to give learning, storage & hosting access to your data\n", "* An S3 bucket, a folder & sub folders that will be used to store data and artifacts\n", "* SageMaker's specific Image Classification training image which should not be changed\n", "\n", "We also need to upgrade the [SageMaker SDK for Python](https://sagemaker.readthedocs.io/en/stable/v2.html) to v2.33.0 or greater and restart the kernel." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!~/anaconda3/envs/mxnet_p36/bin/pip install --upgrade sagemaker>=2.33.0" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sagemaker\n", "from sagemaker import session, get_execution_role\n", "\n", "role = get_execution_role()\n", "sagemaker_session = session.Session()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# S3 bucket and folders for saving code and model artifacts.\n", "# Change to the name of your Greengrass Component bucket.\n", "bucket = '-gg-components'\n", "folder = 'models/uncompiled'\n", "model_with_custom_code_sub_folder = folder + '/model-with-custom-code'\n", "validation_data_sub_folder = folder + '/validation-data'\n", "training_data_sub_folder = folder + '/training-data'\n", "training_output_sub_folder = folder + '/training-output'" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sagemaker import session, get_execution_role\n", "from sagemaker.amazon.amazon_estimator import get_image_uri\n", "\n", "# S3 Location to save the model artifact after training\n", "s3_training_output_location = 's3://{}/{}'.format(bucket, training_output_sub_folder)\n", "\n", "# S3 Location to save your custom code in tar.gz format\n", "s3_model_with_custom_code_location = 's3://{}/{}'.format(bucket, model_with_custom_code_sub_folder)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sagemaker.image_uris import retrieve\n", "aws_region = sagemaker_session.boto_region_name\n", "training_image = retrieve(framework='image-classification', region=aws_region, image_scope='training')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Data preparation\n", "\n", "In this demo, we are using [Caltech-256](http://www.vision.caltech.edu/Image_Datasets/Caltech256/) dataset, pre-converted into `RecordIO` format using MXNet's [im2rec](https://mxnet.apache.org/versions/1.7/api/faq/recordio) tool. Caltech-256 dataset contains 30608 images of 256 objects. 
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "import urllib.request\n", "\n", "def download(url):\n", "    # Fetch a file from `url` unless it already exists locally\n", "    filename = url.split(\"/\")[-1]\n", "    if not os.path.exists(filename):\n", "        urllib.request.urlretrieve(url, filename)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Download caltech-256 data files from MXNet's website\n", "download('http://data.mxnet.io/data/caltech-256/caltech-256-60-train.rec')\n", "download('http://data.mxnet.io/data/caltech-256/caltech-256-60-val.rec')\n", "\n", "# Upload the files to S3\n", "s3_training_data_location = sagemaker_session.upload_data('caltech-256-60-train.rec', bucket, training_data_sub_folder)\n", "s3_validation_data_location = sagemaker_session.upload_data('caltech-256-60-val.rec', bucket, validation_data_sub_folder)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class_labels = ['ak47', 'american-flag', 'backpack', 'baseball-bat', 'baseball-glove', 'basketball-hoop', 'bat',\n", "                'bathtub', 'bear', 'beer-mug', 'billiards', 'binoculars', 'birdbath', 'blimp', 'bonsai-101',\n", "                'boom-box', 'bowling-ball', 'bowling-pin', 'boxing-glove', 'brain-101', 'breadmaker', 'buddha-101',\n", "                'bulldozer', 'butterfly', 'cactus', 'cake', 'calculator', 'camel', 'cannon', 'canoe', 'car-tire',\n", "                'cartman', 'cd', 'centipede', 'cereal-box', 'chandelier-101', 'chess-board', 'chimp', 'chopsticks',\n", "                'cockroach', 'coffee-mug', 'coffin', 'coin', 'comet', 'computer-keyboard', 'computer-monitor',\n", "                'computer-mouse', 'conch', 'cormorant', 'covered-wagon', 'cowboy-hat', 'crab-101', 'desk-globe',\n", "                'diamond-ring', 'dice', 'dog', 'dolphin-101', 'doorknob', 'drinking-straw', 'duck', 'dumb-bell',\n", "                'eiffel-tower', 'electric-guitar-101', 'elephant-101', 'elk', 'ewer-101', 'eyeglasses', 'fern',\n", "                'fighter-jet', 'fire-extinguisher', 'fire-hydrant', 'fire-truck', 'fireworks', 'flashlight',\n", "                'floppy-disk', 'football-helmet', 'french-horn', 'fried-egg', 'frisbee', 'frog', 'frying-pan',\n", "                'galaxy', 'gas-pump', 'giraffe', 'goat', 'golden-gate-bridge', 'goldfish', 'golf-ball', 'goose',\n", "                'gorilla', 'grand-piano-101', 'grapes', 'grasshopper', 'guitar-pick', 'hamburger', 'hammock',\n", "                'harmonica', 'harp', 'harpsichord', 'hawksbill-101', 'head-phones', 'helicopter-101', 'hibiscus',\n", "                'homer-simpson', 'horse', 'horseshoe-crab', 'hot-air-balloon', 'hot-dog', 'hot-tub', 'hourglass',\n", "                'house-fly', 'human-skeleton', 'hummingbird', 'ibis-101', 'ice-cream-cone', 'iguana', 'ipod', 'iris',\n", "                'jesus-christ', 'joy-stick', 'kangaroo-101', 'kayak', 'ketch-101', 'killer-whale', 'knife', 'ladder',\n", "                'laptop-101', 'lathe', 'leopards-101', 'license-plate', 'lightbulb', 'light-house', 'lightning',\n", "                'llama-101', 'mailbox', 'mandolin', 'mars', 'mattress', 'megaphone', 'menorah-101', 'microscope',\n", "                'microwave', 'minaret', 'minotaur', 'motorbikes-101', 'mountain-bike', 'mushroom', 'mussels',\n", "                'necktie', 'octopus', 'ostrich', 'owl', 'palm-pilot', 'palm-tree', 'paperclip', 'paper-shredder',\n", "                'pci-card', 'penguin', 'people', 'pez-dispenser', 'photocopier', 'picnic-table', 'playing-card',\n", "                'porcupine', 'pram', 'praying-mantis', 'pyramid', 'raccoon', 'radio-telescope', 'rainbow', 'refrigerator',\n", "                'revolver-101', 'rifle', 'rotary-phone', 'roulette-wheel', 'saddle', 'saturn', 'school-bus',\n", "                'scorpion-101', 'screwdriver', 'segway', 'self-propelled-lawn-mower', 'sextant', 'sheet-music',\n", "                'skateboard', 'skunk', 'skyscraper', 'smokestack', 'snail', 'snake', 'sneaker', 'snowmobile',\n", "                'soccer-ball', 'socks', 'soda-can', 'spaghetti', 'speed-boat', 'spider', 'spoon', 'stained-glass',\n", "                'starfish-101', 'steering-wheel', 'stirrups', 'sunflower-101', 'superman', 'sushi', 'swan',\n", "                'swiss-army-knife', 'sword', 'syringe', 'tambourine', 'teapot', 'teddy-bear', 'teepee',\n", "                'telephone-box', 'tennis-ball', 'tennis-court', 'tennis-racket', 'theodolite', 'toaster', 'tomato',\n", "                'tombstone', 'top-hat', 'touring-bike', 'tower-pisa', 'traffic-light', 'treadmill', 'triceratops',\n", "                'tricycle', 'trilobite-101', 'tripod', 't-shirt', 'tuning-fork', 'tweezer', 'umbrella-101', 'unicorn',\n", "                'vcr', 'video-projector', 'washing-machine', 'watch-101', 'waterfall', 'watermelon', 'welding-mask',\n", "                'wheelbarrow', 'windmill', 'wine-bottle', 'xylophone', 'yarmulke', 'yo-yo', 'zebra', 'airplanes-101',\n", "                'car-side-101', 'faces-easy-101', 'greyhound', 'tennis-shoes', 'toad', 'clutter']" ] },
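{ "cell_type": "markdown", "metadata": {}, "source": [ "At inference time, the Image Classification algorithm returns one probability per class. Assuming the output ordering matches `class_labels`, a predicted index can be mapped back to a readable name. The next cell sketches this lookup with a placeholder probability vector (`probabilities` is a stand-in; the trained model is not invoked in this notebook)." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "# Placeholder for the per-class probability vector a deployed model would return;\n", "# this sketch assumes the model's output ordering matches class_labels.\n", "probabilities = np.random.dirichlet(np.ones(len(class_labels)))\n", "\n", "predicted_index = int(np.argmax(probabilities))\n", "print('Predicted class: {} (p={:.3f})'.format(class_labels[predicted_index],\n", "                                              probabilities[predicted_index]))" ] },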
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Train the model\n", "\n", "***\n", "\n", "Now that the setup is done, we are ready to train our image classifier. To begin, let us create a ``sagemaker.estimator.Estimator`` object. This estimator is required to launch the training job.\n", "\n", "We specify the following parameters while creating the estimator:\n", "\n", "* ``image_uri``: This is set to the training_image URI we defined previously. Once set, this image will be used later while running the training job.\n", "* ``role``: This is the IAM role we defined previously.\n", "* ``instance_count``: This is the number of instances on which to run the training. When the number of instances is greater than one, the image classification algorithm runs in a distributed setting.\n", "* ``instance_type``: This indicates the type of machine on which to run the training. For this example we will use `ml.p3.8xlarge`.\n", "* ``volume_size``: This is the size in GB of the EBS volume to use for storing input data during training. It must be large enough to store the training data, because File mode is used.\n", "* ``max_run``: This is the timeout value in seconds for training. After this amount of time, SageMaker terminates the job regardless of its current status.\n", "* ``input_mode``: This is set to `File` in this example. SageMaker copies the training dataset from the S3 location to a local directory.\n", "* ``output_path``: This is the S3 path in which the training output is stored. We are assigning it the `s3_training_output_location` defined previously." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ic_estimator = sagemaker.estimator.Estimator(image_uri=training_image,\n", "                                             role=role,\n", "                                             instance_count=1,\n", "                                             instance_type='ml.p3.8xlarge',\n", "                                             volume_size=50,\n", "                                             max_run=360000,\n", "                                             input_mode='File',\n", "                                             output_path=s3_training_output_location,\n", "                                             base_job_name='img-classification-training')" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "We also set the following hyperparameters that are specific to the algorithm:\n", "\n", "* ``num_layers``: The number of layers (depth) for the network. We use 18 in this sample, but other values such as 50 or 152 can also be used.\n", "* ``image_shape``: The input image dimensions, 'num_channels, height, width', for the network. They should be no larger than the actual image size, and the number of channels should be the same as in the actual images.\n", "* ``num_classes``: This is the number of output classes for the new dataset. ImageNet was trained with 1000 output classes, but the number of output classes can be changed for fine-tuning. For Caltech-256 we use 257, because the dataset has 256 object categories plus 1 clutter class.\n", "* ``num_training_samples``: This is the total number of training samples. It is set to 15420 for the Caltech-256 dataset with the current split (257 classes x 60 images).\n", "* ``mini_batch_size``: The number of training samples used for each mini batch. In distributed training, the number of training samples used per batch is N * mini_batch_size, where N is the number of hosts on which training is run.\n", "* ``epochs``: Number of training epochs.\n", "* ``learning_rate``: Learning rate for training.\n", "* ``top_k``: Report the top-k accuracy during training.\n", "* ``use_pretrained_model``: Set to 1 to initialize the network with pre-trained weights (transfer learning) instead of training from scratch.\n", "* ``precision_dtype``: Training datatype precision (default: float32). If set to 'float16', training runs in mixed-precision mode and is faster than float32 mode." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ic_estimator.set_hyperparameters(num_layers=18,\n", "                                 image_shape='3,224,224',\n", "                                 num_classes=257,\n", "                                 num_training_samples=15420,\n", "                                 mini_batch_size=128,\n", "                                 epochs=5,\n", "                                 learning_rate=0.01,\n", "                                 top_k=2,\n", "                                 use_pretrained_model=1,\n", "                                 precision_dtype='float32')" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Next we set up the input ``data_channels`` to be used later for training." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_data = sagemaker.inputs.TrainingInput(s3_training_data_location,\n", "                                            content_type='application/x-recordio',\n", "                                            s3_data_type='S3Prefix')\n", "\n", "validation_data = sagemaker.inputs.TrainingInput(s3_validation_data_location,\n", "                                                 content_type='application/x-recordio',\n", "                                                 s3_data_type='S3Prefix')\n", "\n", "data_channels = {'train': train_data, 'validation': validation_data}" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "After we've created the estimator object, we can train the model using the ``fit()`` API." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "ic_estimator.fit(inputs=data_channels, logs=True)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "After the training job completes, your trained model is stored in the bucket specified above, in the `models/uncompiled` folder of your Greengrass component bucket. Check in S3 that you can see the output of the training job; the optional cell below sketches one way to do that from the notebook." ] },
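{ "cell_type": "markdown", "metadata": {}, "source": [ "As an optional check, the next cell lists the objects under the training output prefix with `boto3`; the trained model artifact is written as a `model.tar.gz` under a job-specific key. You can equally verify this in the S3 console." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import boto3\n", "\n", "# List the objects written by the training job under the output prefix;\n", "# expect a <job-name>/output/model.tar.gz key among the results.\n", "s3_client = boto3.client('s3')\n", "response = s3_client.list_objects_v2(Bucket=bucket, Prefix=training_output_sub_folder)\n", "for obj in response.get('Contents', []):\n", "    print(obj['Key'])" ] }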
], "metadata": { "kernelspec": { "display_name": "conda_mxnet_p36", "language": "python", "name": "conda_mxnet_p36" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.13" }, "notice": "Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License." }, "nbformat": 4, "nbformat_minor": 4 }