# End-to-End Multiclass Image Classification Example in C#

1. [Introduction](#Introduction)
2. [Prerequisites](#Prerequisites)
  1. [Lifecycle Configuration](#Lifecycle-Configuration)
  2. [Required Packages](#Required-Packages)
  3. [Data Preparation](#Data-Preparation)
3. [Train the ResNet Model](#Train-the-ResNet-Model)
4. [Deploy the Model](#Deploy-the-Model)
5. [Use the Model to perform Inferences](#Use-the-Model-to-perform-Inferences)
  1. [Create endpoint configuration](#Create-endpoint-configuration) 
  2. [Create endpoint](#Create-endpoint) 
  3. [Perform Inferences](#Perform-Inferences)
  4. [Clean up](#Clean-up)

## Introduction

Welcome to our end-to-end example of the distributed image classification algorithm. In this demo, we will use the Amazon SageMaker image classification algorithm to train on the [Caltech-256 dataset](http://www.vision.caltech.edu/Image_Datasets/Caltech256/).

This example builds upon the Python-based example found in the [AWS Samples GitHub Repo](https://github.com/awslabs/amazon-sagemaker-examples/blob/master/introduction_to_amazon_algorithms/imageclassification_caltech/Image-classification-fulltraining.ipynb).

## Prerequisites

### Lifecycle Configuration

In order to enable your SageMaker Notebook environments to run this Notebook with the C# Kernel, you need to have already created and associated a customized [Lifecycle Configuration](https://docs.aws.amazon.com/sagemaker/latest/dg/notebook-lifecycle-config.html) with this Notebook Instance using the Configuration Script provided in this repo.

### Required Packages

Start by including all of the required NuGet packages for the Notebook:
- SageMaker SDK
- SageMaker Runtime SDK
- AWS S3 SDK
- Newtonsoft.Json

In [None]:
#r "nuget:AWSSDK.SageMaker, 3.3.112.3"
#r "nuget:AWSSDK.SageMakerRuntime, 3.3.101.49"
#r "nuget:AWSSDK.S3, 3.3.110.45"
#r "nuget:Newtonsoft.Json, 12.0.3"

After the above packages have been installed, bring the relevant namespaces into scope.

In [None]:
using Amazon.SageMaker;
using Amazon.SageMaker.Model;

using Amazon.SageMakerRuntime;
using Amazon.SageMakerRuntime.Model;

using Amazon.S3;
using Amazon.S3.Model;

using Newtonsoft.Json;

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Threading.Tasks;

Create the AWS SDK client objects that will be used to make API calls to the AWS services, and the `WebClient` that will be used to download files.

In [None]:
static AmazonS3Client s3Client = new AmazonS3Client();
AmazonSageMakerClient smClient = new AmazonSageMakerClient();
AmazonSageMakerRuntimeClient smrClient = new AmazonSageMakerRuntimeClient();
WebClient webClient = new WebClient();

### Data Preparation

Download the required Training Data and Validation Data.

In [None]:
webClient.DownloadFile("http://data.mxnet.io/data/caltech-256/caltech-256-60-train.rec", "caltech-256-60-train.rec");
webClient.DownloadFile("http://data.mxnet.io/data/caltech-256/caltech-256-60-val.rec", "caltech-256-60-val.rec");

Set up the S3 paths/locations where the training/validation data will be stored.

In [None]:
String bucketName = "dotnet-sagemaker-bucket";
String trainKey = "image-classification-full-training/train";
String validationKey = "image-classification-full-training/validation";
String s3Train = String.Format("s3://{0}/{1}/",bucketName,trainKey);
String s3Validation = String.Format("s3://{0}/{1}/",bucketName,validationKey);
String trainFile = String.Format("{0}/caltech-256-60-train.rec",trainKey);
String validationFile = String.Format("{0}/caltech-256-60-val.rec",validationKey);
Console.WriteLine(s3Train);
Console.WriteLine(s3Validation);
Console.WriteLine(trainFile);
Console.WriteLine(validationFile);

Define the function that will be used to upload files into the S3 bucket

In [None]:
static async Task UploadToS3(string s3Bucket, string s3Key, string filePath)
{
    try
    {
        var putRequest = new PutObjectRequest
        {
            BucketName = s3Bucket,
            Key = s3Key,
            FilePath = filePath,
            // The .rec files are binary RecordIO data, not plain text
            ContentType = "application/x-recordio"
        };
        PutObjectResponse response = await s3Client.PutObjectAsync(putRequest);
        Console.WriteLine("Uploaded {0} to s3://{1}/{2}", filePath, s3Bucket, s3Key);
    }
    catch (AmazonS3Exception e)
    {
        // S3 rejected the request (for example, a missing bucket or insufficient permissions)
        Console.WriteLine("S3 error encountered. Message:'{0}' when writing an object", e.Message);
    }
    catch (Exception e)
    {
        Console.WriteLine("Unknown error encountered. Message:'{0}' when writing an object", e.Message);
    }
}

Copy all of the downloaded Training and Validation data into the S3 bucket

In [None]:
await UploadToS3(bucketName,trainFile,@"./caltech-256-60-train.rec");
await UploadToS3(bucketName,validationFile,@"./caltech-256-60-val.rec");

## Train the ResNet Model

In this demo, we use the [Caltech-256](http://www.vision.caltech.edu/Image_Datasets/Caltech256/) dataset, which contains 30608 images of 256 objects. For the training and validation data, we follow the splitting scheme in this MXNet [example](https://github.com/apache/incubator-mxnet/blob/master/example/image-classification/data/caltech256.sh). In particular, it randomly selects 60 images per class for training and uses the remaining data for validation. The algorithm takes a `RecordIO` file as input. The user can also provide the image files as input, which will be converted into `RecordIO` format using MXNet's [im2rec](https://mxnet.incubator.apache.org/how_to/recordio.html?highlight=im2rec) tool. It takes around 50 seconds to convert the entire Caltech-256 dataset (~1.2GB) on a p2.xlarge instance. However, for this demo, we will use the `RecordIO` format.

Once we have the data available in the correct format for training, the next step is to actually train the model using the data. After setting training parameters, we kick off training, and poll for status until training is completed.  

An IAM role is required for the SageMaker service to use when it is running the training job and creating the Model. Extract this role from the currently running Notebook instance. Replace the `NotebookInstanceName` value below with the name of your own Notebook instance.

In [None]:
DescribeNotebookInstanceRequest dniReq = new DescribeNotebookInstanceRequest() {
    NotebookInstanceName = "dotNetV3-1"
};
DescribeNotebookInstanceResponse dniResp = await smClient.DescribeNotebookInstanceAsync(dniReq);
Console.WriteLine(dniResp.RoleArn);

Create a new Training Job Request object with all required parameters. To see a detailed description of what the parameters mean, refer to the [Python Sample](https://github.com/awslabs/amazon-sagemaker-examples/blob/master/introduction_to_amazon_algorithms/imageclassification_caltech/Image-classification-fulltraining.ipynb). Note that the `TrainingImage` URI below points to the `us-west-2` region; if your Notebook runs in a different region, use the image URI for that region.

In [None]:
// Use a 24-hour clock (HH) so that job names created 12 hours apart cannot collide
string jobName = String.Format("DEMO-imageclassification-{0}",DateTime.Now.ToString("yyyy-MM-dd-HH-mm-ss"));

CreateTrainingJobRequest ctrRequest = new CreateTrainingJobRequest(){
    AlgorithmSpecification = new AlgorithmSpecification(){
        TrainingImage = "433757028032.dkr.ecr.us-west-2.amazonaws.com/image-classification:1",
        TrainingInputMode = "File"  
    },
    RoleArn = dniResp.RoleArn,
    OutputDataConfig = new OutputDataConfig(){
        S3OutputPath = String.Format(@"s3://{0}/{1}/output",bucketName,jobName)
    },
    ResourceConfig = new ResourceConfig(){
        InstanceCount = 1,
        InstanceType = Amazon.SageMaker.TrainingInstanceType.MlP2Xlarge,
        VolumeSizeInGB = 50
    },
    TrainingJobName = jobName,
    HyperParameters = new Dictionary<string,string>() {
        {"image_shape","3,224,224"},
        {"num_layers","18"},
        {"num_training_samples","15420"},
        {"num_classes","257"},
        {"mini_batch_size","64"},
        {"epochs","10"},
        {"learning_rate","0.01"}
    },
    StoppingCondition = new StoppingCondition(){
        MaxRuntimeInSeconds = 360000
    },
    InputDataConfig = new List<Amazon.SageMaker.Model.Channel>(){
        new Amazon.SageMaker.Model.Channel() {
            ChannelName = "train",
            ContentType = "application/x-recordio",
            CompressionType = Amazon.SageMaker.CompressionType.None,
            DataSource = new Amazon.SageMaker.Model.DataSource(){
                S3DataSource = new Amazon.SageMaker.Model.S3DataSource(){
                    S3DataType = Amazon.SageMaker.S3DataType.S3Prefix,
                    S3Uri = s3Train,
                    S3DataDistributionType = Amazon.SageMaker.S3DataDistribution.FullyReplicated
                }
            }
        },
        new Amazon.SageMaker.Model.Channel(){
            ChannelName = "validation",
            ContentType = "application/x-recordio",
            CompressionType = Amazon.SageMaker.CompressionType.None,
            DataSource = new Amazon.SageMaker.Model.DataSource(){
                S3DataSource = new Amazon.SageMaker.Model.S3DataSource(){
                    S3DataType = Amazon.SageMaker.S3DataType.S3Prefix,
                    S3Uri = s3Validation,
                    S3DataDistributionType = Amazon.SageMaker.S3DataDistribution.FullyReplicated
                }
            }
        }
    }
};
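The `num_training_samples` and `num_classes` values in the request above follow from the split described earlier: 60 training images for each of 257 classes (the 256 object categories plus the `clutter` category). A minimal sketch of that arithmetic follows. The variable names are illustrative, and the batches-per-epoch figure assumes a final partial mini-batch counts as one batch, which may not match how the algorithm handles it internally.

```csharp
using System;

int numClasses = 257;      // 256 object categories plus "clutter"
int imagesPerClass = 60;   // training images drawn per class in the split
int miniBatchSize = 64;    // mini_batch_size hyperparameter

// Value passed as num_training_samples
int numTrainingSamples = numClasses * imagesPerClass;

// Ceiling division; assumes a trailing partial batch counts as a full step
int batchesPerEpoch = (numTrainingSamples + miniBatchSize - 1) / miniBatchSize;

Console.WriteLine(numTrainingSamples);  // 15420
Console.WriteLine(batchesPerEpoch);     // 241
```

If you change the split (for example, a different number of images per class), `num_training_samples` needs to be updated to match.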

Submit the request for the Training job to be created

In [None]:
CreateTrainingJobResponse ctrResponse = await smClient.CreateTrainingJobAsync(ctrRequest);
Console.WriteLine(ctrResponse.TrainingJobArn);

Poll the status of the submitted Training job. Run the next block a few times until the status shows `Completed`.

In [None]:
DescribeTrainingJobRequest tjReq = new DescribeTrainingJobRequest(){
    TrainingJobName = jobName
};
DescribeTrainingJobResponse tjResp = await smClient.DescribeTrainingJobAsync(tjReq);
Console.WriteLine(tjResp.TrainingJobStatus);
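Instead of re-running the cell by hand, the polling can be automated. The following is a minimal sketch of a generic polling helper, not part of the AWS SDK: `PollUntilAsync`, its parameters, and the fake status source are hypothetical names introduced for illustration. It is demonstrated here against a stand-in status sequence so that it runs without any AWS calls.

```csharp
using System;
using System.Collections.Generic;
using System.Threading.Tasks;

// Calls getStatus repeatedly until it returns one of the terminal states,
// waiting `interval` between calls. Throws if maxAttempts is exhausted.
static async Task<string> PollUntilAsync(
    Func<Task<string>> getStatus,
    ISet<string> terminalStates,
    TimeSpan interval,
    int maxAttempts)
{
    for (int attempt = 1; attempt <= maxAttempts; attempt++)
    {
        string status = await getStatus();
        Console.WriteLine($"Attempt {attempt}: {status}");
        if (terminalStates.Contains(status))
        {
            return status;
        }
        if (attempt < maxAttempts)
        {
            await Task.Delay(interval);
        }
    }
    throw new TimeoutException($"No terminal state after {maxAttempts} attempts.");
}

// Stand-in status source for illustration only (hypothetical data).
var fakeStatuses = new Queue<string>(new[] { "InProgress", "InProgress", "Completed" });
Func<Task<string>> fakeGetStatus = () => Task.FromResult(fakeStatuses.Dequeue());

string finalStatus = await PollUntilAsync(
    fakeGetStatus,
    new HashSet<string> { "Completed", "Failed", "Stopped" },
    TimeSpan.FromMilliseconds(10),
    maxAttempts: 5);
Console.WriteLine(finalStatus);
```

In the notebook, the status source could instead be `async () => (await smClient.DescribeTrainingJobAsync(tjReq)).TrainingJobStatus.Value`, with an interval of a minute or so, since training typically takes much longer than a few seconds.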

## Deploy the Model

Once the Training job above has completed, the trained model can be put to use. The first step is to create a SageMaker Model from the training artifacts: create the request object with all required parameters and make the API call to generate the Model.

In [None]:
string modelName = String.Format("DEMO-full-image-classification-model-{0}",DateTime.Now.ToString("yyyy-MM-dd-HH-mm-ss"));
Console.WriteLine(modelName);

CreateModelRequest modelRequest = new CreateModelRequest(){
    ModelName = modelName,
    ExecutionRoleArn = dniResp.RoleArn,
    PrimaryContainer = new ContainerDefinition(){
        Image = "433757028032.dkr.ecr.us-west-2.amazonaws.com/image-classification:latest",
        ModelDataUrl = tjResp.ModelArtifacts.S3ModelArtifacts
    }
};

CreateModelResponse modelResponse = await smClient.CreateModelAsync(modelRequest);
Console.WriteLine(modelResponse.ModelArn);

## Use the Model to perform Inferences

We now host the model with an endpoint and perform real-time inference.

This section involves several steps:

1. [Create endpoint configuration](#Create-endpoint-configuration) - Create a configuration defining an endpoint.
2. [Create endpoint](#Create-endpoint) - Use the configuration to create an inference endpoint.
3. [Perform Inferences](#Perform-Inferences) - Perform inference on some input data using the endpoint.
4. [Clean up](#Clean-up) - Delete the endpoint and model

### Create endpoint configuration

In [None]:
string epConfName = String.Format("{0}-EndPointConfig",jobName);
Console.WriteLine(epConfName);

CreateEndpointConfigRequest epConfReq = new CreateEndpointConfigRequest(){
    EndpointConfigName = epConfName,
    ProductionVariants = new List<ProductionVariant>(){
        new ProductionVariant() {
            InstanceType = Amazon.SageMaker.ProductionVariantInstanceType.MlP28xlarge,
            InitialInstanceCount = 1,
            ModelName = modelName,
            VariantName = "AllTraffic"
        }
    }
};

CreateEndpointConfigResponse epConfResp = await smClient.CreateEndpointConfigAsync(epConfReq);
Console.WriteLine(epConfResp.EndpointConfigArn);

### Create endpoint 

In [None]:
string epName = String.Format("{0}-EndPoint",jobName);
Console.WriteLine(epName);

CreateEndpointRequest epReq = new CreateEndpointRequest(){
    EndpointName = epName,
    EndpointConfigName = epConfName
};

CreateEndpointResponse epResp = await smClient.CreateEndpointAsync(epReq);
Console.WriteLine(epResp.EndpointArn);

Poll the status of the Endpoint creation. Run the next block a few times until the Endpoint status shows `InService`.

In [None]:
DescribeEndpointRequest deReq = new DescribeEndpointRequest(){
    EndpointName = epName
};
DescribeEndpointResponse deResp = await smClient.DescribeEndpointAsync(deReq);
Console.WriteLine(deResp.EndpointStatus);

### Perform Inferences

Once the Endpoint is `InService`, it is ready to be invoked to retrieve Inferences/Predictions. Download two sample images that will be used for this purpose.

In [None]:
webClient.DownloadFile("http://www.vision.caltech.edu/Image_Datasets/Caltech256/images/008.bathtub/008_0007.jpg", "008_0007.jpg");
webClient.DownloadFile("http://www.vision.caltech.edu/Image_Datasets/Caltech256/images/008.bathtub/008_0009.jpg", "008_0009.jpg");

For each of the two downloaded files, the next step is to load the file into memory and invoke the endpoint with the file as payload to retrieve a Prediction about the image. The response is an array of probabilities, one for each of the known categories in the Training data set, in the same order as the category list below.

In [None]:
//Load the known categories into memory
String[] categoriesArray = new String[]{"ak47", "american-flag", "backpack", "baseball-bat", "baseball-glove", "basketball-hoop", "bat", "bathtub", "bear", "beer-mug", "billiards", "binoculars", "birdbath", "blimp", "bonsai-101", "boom-box", "bowling-ball", "bowling-pin", "boxing-glove", "brain-101", "breadmaker", "buddha-101", "bulldozer", "butterfly", "cactus", "cake", "calculator", "camel", "cannon", "canoe", "car-tire", "cartman", "cd", "centipede", "cereal-box", "chandelier-101", "chess-board", "chimp", "chopsticks", "cockroach", "coffee-mug", "coffin", "coin", "comet", "computer-keyboard", "computer-monitor", "computer-mouse", "conch", "cormorant", "covered-wagon", "cowboy-hat", "crab-101", "desk-globe", "diamond-ring", "dice", "dog", "dolphin-101", "doorknob", "drinking-straw", "duck", "dumb-bell", "eiffel-tower", "electric-guitar-101", "elephant-101", "elk", "ewer-101", "eyeglasses", "fern", "fighter-jet", "fire-extinguisher", "fire-hydrant", "fire-truck", "fireworks", "flashlight", "floppy-disk", "football-helmet", "french-horn", "fried-egg", "frisbee", "frog", "frying-pan", "galaxy", "gas-pump", "giraffe", "goat", "golden-gate-bridge", "goldfish", "golf-ball", "goose", "gorilla", "grand-piano-101", "grapes", "grasshopper", "guitar-pick", "hamburger", "hammock", "harmonica", "harp", "harpsichord", "hawksbill-101", "head-phones", "helicopter-101", "hibiscus", "homer-simpson", "horse", "horseshoe-crab", "hot-air-balloon", "hot-dog", "hot-tub", "hourglass", "house-fly", "human-skeleton", "hummingbird", "ibis-101", "ice-cream-cone", "iguana", "ipod", "iris", "jesus-christ", "joy-stick", "kangaroo-101", "kayak", "ketch-101", "killer-whale", "knife", "ladder", "laptop-101", "lathe", "leopards-101", "license-plate", "lightbulb", "light-house", "lightning", "llama-101", "mailbox", "mandolin", "mars", "mattress", "megaphone", "menorah-101", "microscope", "microwave", "minaret", "minotaur", "motorbikes-101", "mountain-bike", "mushroom", "mussels", "necktie", 
"octopus", "ostrich", "owl", "palm-pilot", "palm-tree", "paperclip", "paper-shredder", "pci-card", "penguin", "people", "pez-dispenser", "photocopier", "picnic-table", "playing-card", "porcupine", "pram", "praying-mantis", "pyramid", "raccoon", "radio-telescope", "rainbow", "refrigerator", "revolver-101", "rifle", "rotary-phone", "roulette-wheel", "saddle", "saturn", "school-bus", "scorpion-101", "screwdriver", "segway", "self-propelled-lawn-mower", "sextant", "sheet-music", "skateboard", "skunk", "skyscraper", "smokestack", "snail", "snake", "sneaker", "snowmobile", "soccer-ball", "socks", "soda-can", "spaghetti", "speed-boat", "spider", "spoon", "stained-glass", "starfish-101", "steering-wheel", "stirrups", "sunflower-101", "superman", "sushi", "swan", "swiss-army-knife", "sword", "syringe", "tambourine", "teapot", "teddy-bear", "teepee", "telephone-box", "tennis-ball", "tennis-court", "tennis-racket", "theodolite", "toaster", "tomato", "tombstone", "top-hat", "touring-bike", "tower-pisa", "traffic-light", "treadmill", "triceratops", "tricycle", "trilobite-101", "tripod", "t-shirt", "tuning-fork", "tweezer", "umbrella-101", "unicorn", "vcr", "video-projector", "washing-machine", "watch-101", "waterfall", "watermelon", "welding-mask", "wheelbarrow", "windmill", "wine-bottle", "xylophone", "yarmulke", "yo-yo", "zebra", "airplanes-101", "car-side-101", "faces-easy-101", "greyhound", "tennis-shoes", "toad", "clutter"};
List<String> categories = categoriesArray.ToList();

Invoke the Endpoint to get an Inference for the first image

In [None]:
MemoryStream dataStream = new MemoryStream(File.ReadAllBytes(@"./008_0007.jpg"));
InvokeEndpointRequest invReq = new InvokeEndpointRequest(){
    EndpointName = epName,
    ContentType = "application/x-image",
    Body = dataStream
};
InvokeEndpointResponse invResp = await smrClient.InvokeEndpointAsync(invReq);

//Read the response stream back into a string so it can be reviewed
StreamReader sr = new StreamReader(invResp.Body);
String responseBody = sr.ReadToEnd();

//Load the values into a List so they can be more easily searched
List<Decimal> probabilities = JsonConvert.DeserializeObject<List<Decimal>>(responseBody);

//Determine which category returned the highest Probability match and print its value and Index
var indexAtMax = probabilities.IndexOf(probabilities.Max());
Console.WriteLine(String.Format("Index of Max Probability: {0}",indexAtMax));
Console.WriteLine(String.Format("Value of Max Probability: {0}",probabilities[indexAtMax]));

//Print which Category name matches with the Image
Console.WriteLine(String.Format("Category of Image : {0}",categories[indexAtMax]));


Invoke the Endpoint to get an Inference for the second image

In [None]:
MemoryStream dataStream2 = new MemoryStream(File.ReadAllBytes(@"./008_0009.jpg"));
InvokeEndpointRequest invReq2 = new InvokeEndpointRequest(){
    EndpointName = epName,
    ContentType = "application/x-image",
    Body = dataStream2
};
InvokeEndpointResponse invResp2 = await smrClient.InvokeEndpointAsync(invReq2);

//Read the response stream back into a string so it can be reviewed
StreamReader sr2 = new StreamReader(invResp2.Body);
String responseBody2 = sr2.ReadToEnd();

//Load the values into a List so they can be more easily searched
List<Decimal> probabilities2 = JsonConvert.DeserializeObject<List<Decimal>>(responseBody2);

//Determine which category returned the highest Probability match and print its value and Index
var indexAtMax2 = probabilities2.IndexOf(probabilities2.Max());
Console.WriteLine(String.Format("Index of Max Probability: {0}",indexAtMax2));
Console.WriteLine(String.Format("Value of Max Probability: {0}",probabilities2[indexAtMax2]));

//Print which Category name matches with the Image
Console.WriteLine(String.Format("Category of Image : {0}",categories[indexAtMax2]));
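The cells above report only the single most likely category. When the top probability is low, it can help to look at the runners-up as well. The following is a minimal sketch of a top-k helper. `TopK` and the toy data are hypothetical, introduced for illustration, and the helper works on `double` values rather than the `Decimal` list used above.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Pairs each probability with its label and returns the k most likely
// categories, highest probability first.
static List<(string Label, double Probability)> TopK(
    IReadOnlyList<double> probabilities,
    IReadOnlyList<string> labels,
    int k)
{
    if (probabilities.Count != labels.Count)
    {
        throw new ArgumentException("Expected one probability per label.");
    }

    return probabilities
        .Select((p, i) => (Label: labels[i], Probability: p))
        .OrderByDescending(pair => pair.Probability)
        .Take(k)
        .ToList();
}

// Toy data for illustration only; not real endpoint output.
var sampleLabels = new[] { "ak47", "american-flag", "backpack", "bathtub" };
var sampleProbabilities = new[] { 0.05, 0.10, 0.15, 0.70 };

var top2 = TopK(sampleProbabilities, sampleLabels, 2);
foreach (var (label, probability) in top2)
{
    Console.WriteLine($"{label}: {probability}");
}
```

With the notebook's variables, the call could look like `TopK(probabilities.Select(p => (double)p).ToList(), categories, 5)`.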


### Clean up

Delete the endpoint so that it stops incurring charges

In [None]:
DeleteEndpointRequest delReq = new DeleteEndpointRequest(){
    EndpointName = epName
};
DeleteEndpointResponse delResp = await smClient.DeleteEndpointAsync(delReq);
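The endpoint configuration and the Model created earlier remain after the endpoint is deleted. The following is a minimal sketch of a best-effort cleanup runner, in which one failed deletion does not prevent the remaining ones from running. `RunCleanupAsync` and the stand-in steps are hypothetical, introduced for illustration, so the block runs without any AWS calls.

```csharp
using System;
using System.Collections.Generic;
using System.Threading.Tasks;

// Runs each named cleanup step in order. A failing step is recorded and the
// remaining steps still run. Returns the names of the steps that failed.
static async Task<List<string>> RunCleanupAsync(
    IEnumerable<(string Name, Func<Task> Step)> steps)
{
    var failed = new List<string>();
    foreach (var (name, step) in steps)
    {
        try
        {
            await step();
            Console.WriteLine($"Deleted: {name}");
        }
        catch (Exception e)
        {
            Console.WriteLine($"Failed to delete {name}: {e.Message}");
            failed.Add(name);
        }
    }
    return failed;
}

// Stand-in steps for illustration only; the second one simulates a failure.
var demoSteps = new List<(string Name, Func<Task> Step)>
{
    ("endpoint config", () => Task.CompletedTask),
    ("model", () => Task.FromException(new InvalidOperationException("simulated failure"))),
};

List<string> failedSteps = await RunCleanupAsync(demoSteps);
Console.WriteLine(failedSteps.Count);
```

In the notebook, the steps could wrap `smClient.DeleteEndpointConfigAsync(new DeleteEndpointConfigRequest { EndpointConfigName = epConfName })` and `smClient.DeleteModelAsync(new DeleteModelRequest { ModelName = modelName })`.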