{ "cells": [ { "cell_type": "markdown", "id": "d80d3f7d", "metadata": {}, "source": [ "# Comparison of Image Classification models and algorithms in Amazon SageMaker JumpStart " ] }, { "cell_type": "markdown", "id": "d422b206", "metadata": {}, "source": [ "---\n", "When you solve a business problem using machine learning (ML), you might want to try multiple ML algorithms and compare them against each other to see which model gives you the best results on the dimensions that you care about, such as model accuracy, inference time, and training time.\n", "\n", "In this notebook, we demonstrate how you can compare multiple image classification models and algorithms offered by SageMaker JumpStart on dimensions such as model accuracy, inference time, and training time. Models in JumpStart are brought from hubs such as TensorFlow Hub and PyTorch Hub, and training scripts (algorithms) were written separately for each of these frameworks. In this notebook, you can also alter some of the hyper-parameters and examine their effect on the results. \n", "\n", "Image Classification refers to assigning an image to one of the class labels in the training dataset.\n", "\n", "Amazon [SageMaker JumpStart](https://docs.aws.amazon.com/sagemaker/latest/dg/studio-jumpstart.html) offers a large suite of ML algorithms. You can use JumpStart to solve many Machine Learning tasks with one click in SageMaker Studio, or through the [SageMaker JumpStart API](https://sagemaker.readthedocs.io/en/stable/overview.html#use-prebuilt-models-with-sagemaker-jumpstart). \n", "\n", "Note: This notebook was tested on an ml.t3.medium instance in Amazon SageMaker Studio with the Python 3 (Data Science) kernel and in an Amazon SageMaker Notebook instance with the conda_python3 kernel.\n", "\n", "---" ] }, { "cell_type": "markdown", "id": "8c59ed72", "metadata": {}, "source": [ "1. [Set Up](#1.-Set-Up)\n", "2. [Specify training and validation data paths](#2.-Specify-training-and-validation-data-paths)\n", "3. 
[Set hyper-parameters](#3.-Hyper-parameters)\n", "4. [List of models to run](#4.-Specify-models-to-run)\n", "5. [Helper functions](#5.-Helper-functions)\n", "6. [Run all models](#6.-Run-all-models)" ] }, { "cell_type": "markdown", "id": "14244a1a", "metadata": {}, "source": [ "## 1. Set-Up\n", "***\n", "Before executing the notebook, there are some initial steps required for setup. This notebook requires the latest versions of sagemaker and ipywidgets.\n", "***" ] }, { "cell_type": "code", "execution_count": null, "id": "849f548c", "metadata": {}, "outputs": [], "source": [ "!pip install sagemaker ipywidgets --upgrade --quiet" ] }, { "cell_type": "code", "execution_count": null, "id": "d2d54072", "metadata": {}, "outputs": [], "source": [ "import sagemaker, boto3, json\n", "from sagemaker import get_execution_role\n", "import uuid\n", "import pandas as pd\n", "\n", "aws_role = get_execution_role()\n", "aws_region = boto3.Session().region_name\n", "sess = sagemaker.Session()\n", "s3 = boto3.client(\"s3\")\n", "\n", "# unique id to connect all runs\n", "# if you run this notebook multiple times, this master id helps you \n", "# save each run's results as a separate csv file\n", "master_uuid = str(uuid.uuid4())\n", "print(\"master id for this run: \", master_uuid)\n", "\n", "# Lists to store results\n", "nameList = []\n", "accList = []\n", "timeList = []" ] }, { "cell_type": "markdown", "id": "43d8ea6c", "metadata": {}, "source": [ "## 2. Specify training and validation data paths\n", "***\n", "Training and validation data must be stored in the format specified below:\n", "- A directory with as many sub-directories as the number of classes. \n", " - Each sub-directory should have images belonging to that class in .jpg format. 
\n", " \n", "The input directory should look like below if \n", "the training data contains images from two classes: roses and dandelion.\n", "\n", " input_directory\n", " |--roses\n", " |--abc.jpg\n", " |--def.jpg\n", " |--dandelion\n", " |--ghi.jpg\n", " |--jkl.jpg\n", "\n", "We provide the tf_flowers dataset as an example dataset for training and validation. This is for illustration purposes only. When you use this notebook, you need to replace the bucket and prefix references below with your own buckets containing separate datasets for training and validation.\n", "\n", "tf_flowers comprises images of five types of flowers. \n", "The dataset has been downloaded from [TensorFlow](https://www.tensorflow.org/datasets/catalog/tf_flowers). \n", "[Apache 2.0 License](https://jumpstart-cache-prod-us-west-2.s3-us-west-2.amazonaws.com/licenses/Apache-License/LICENSE-2.0.txt).\n", "Citation:\n", "\n", "@ONLINE {tfflowers,\n", "author = \"The TensorFlow Team\",\n", "title = \"Flowers\",\n", "month = \"jan\",\n", "year = \"2019\",\n", "url = \"http://download.tensorflow.org/example_images/flower_photos.tgz\" }\n", " source: [TensorFlow Hub](model_url). \n", "***" ] }, { "cell_type": "code", "execution_count": null, "id": "b208e373", "metadata": {}, "outputs": [], "source": [ "# Set references to training data\n", "training_data_bucket = f\"jumpstart-cache-prod-{aws_region}\"\n", "training_data_prefix = \"training-datasets/tf_flowers\"\n", "\n", "# Set references to validation data\n", "validation_data_bucket = f\"jumpstart-cache-prod-{aws_region}\"\n", "validation_data_prefix = \"training-datasets/tf_flowers\"" ] }, { "cell_type": "markdown", "id": "3a5e89ff", "metadata": {}, "source": [ "## 3. 
Hyper-parameters\n", "As explained above, you can modify the three hyper-parameters shown below and examine their effect on the results." ] }, { "cell_type": "code", "execution_count": null, "id": "d53fd3a0", "metadata": {}, "outputs": [], "source": [ "# Setting below hyper-parameters for this run\n", "\n", "# Number of epochs\n", "EPOCHS = \"5\"\n", "\n", "# Learning rate\n", "LR = \"0.001\"\n", "\n", "# Batch size\n", "BATCH_SIZE = \"16\"" ] }, { "cell_type": "markdown", "id": "65437528", "metadata": {}, "source": [ "## 4. Specify models to run" ] }, { "cell_type": "code", "execution_count": null, "id": "37840b13-1736-497d-9f09-d521a7d1806c", "metadata": {}, "outputs": [], "source": [ "from sagemaker.jumpstart.notebook_utils import list_jumpstart_models\n", "\n", "# All available models in JumpStart can be seen with this code\n", "# We are showing only the first five models for illustration purposes\n", "\n", "filter_value = \"task == ic\"\n", "ic_models = list_jumpstart_models(filter=filter_value)\n", "\n", "print(\"Total image classification models available in JumpStart: \", len(ic_models))\n", "print()\n", "print(\"Showing five image classification models from JumpStart: \\n\", ic_models[0:5])" ] }, { "cell_type": "code", "execution_count": null, "id": "09f6c086", "metadata": {}, "outputs": [], "source": [ "# We arbitrarily picked four models. You can replace the list below with other models\n", "\n", "# The number of models you add to this list shouldn't exceed the number of training and inference instances\n", "# available to your account in SageMaker, as all these models will be trained and inferred in parallel\n", "models = [\"tensorflow-ic-imagenet-mobilenet-v2-075-224-classification-4\", \n", " \"tensorflow-ic-imagenet-inception-v3-classification-4\", \n", " \"pytorch-ic-googlenet\",\n", " \"pytorch-ic-alexnet\"]" ] }, { "cell_type": "markdown", "id": "2b5cc860", "metadata": {}, "source": [ "## 5. 
Helper functions" ] }, { "cell_type": "code", "execution_count": null, "id": "4946500d", "metadata": {}, "outputs": [], "source": [ "import os\n", "import time\n", "import random\n", "\n", "# Function to query the endpoint\n", "def query_endpoint(img, endpoint_name):\n", " client = boto3.client('runtime.sagemaker')\n", " response = client.invoke_endpoint(EndpointName=endpoint_name, ContentType='application/x-image', Body=img, Accept='application/json;verbose')\n", " return response\n", "\n", "# Function to parse prediction response\n", "def parse_prediction(query_response):\n", " model_predictions = json.loads(query_response['Body'].read())\n", " predicted_label = model_predictions['predicted_label']\n", " labels = model_predictions['labels']\n", " probabilities = model_predictions['probabilities']\n", " return predicted_label, probabilities, labels \n", "\n", "# Function that returns all files under a given S3 bucket prefix\n", "def listS3Files(bucket, prefix):\n", " file_prefix = []\n", " file_name = []\n", " s3 = boto3.resource('s3')\n", " my_bucket = s3.Bucket(bucket)\n", " for object_summary in my_bucket.objects.filter(Prefix=prefix):\n", " if object_summary.key[-1] != \"/\": # don't append parent directory name\n", " file_prefix.append(object_summary.key)\n", " split = object_summary.key.split(\"/\")\n", " file_name.append(split[-1])\n", " return file_prefix\n", "\n", "# Function to calculate model accuracy\n", "# It will calculate validation accuracy if you supply a validation dataset in the setting above\n", "from sklearn.metrics import accuracy_score\n", "def calcModelAccuracy(endpoint_name, bucket, file_prefixes):\n", " # maximum number of images to test against\n", " size = 100\n", " if len(file_prefixes)