# Training and hosting SageMaker Models using the Apache MXNet Module API

In this example, we train a simple neural network using the Apache MXNet [Module API](https://mxnet.incubator.apache.org/api/python/module.html) and the MNIST dataset. The MNIST dataset is widely used for handwritten digit classification and consists of 70,000 labeled 28x28 pixel grayscale images of handwritten digits, split into 60,000 training images and 10,000 test images. There are 10 classes, one for each digit from 0 to 9. The task is to train a model on the 60,000 training images and then measure its classification accuracy on the 10,000 test images.

### Setup

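First, we set up the AWS clients and variables used later in the example. The cell reads the SageBuild stack name from ``../config.json``, collects the stack's outputs (bucket names, SNS topic, state machine, endpoint name, and parameter store name), packages the files in the current directory (including ``mnist.py``) into ``script.tar.gz``, and uploads that archive to the code bucket.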

In [None]:
from sagemaker import get_execution_role
from botocore.exceptions import ClientError
import json
import os
import tarfile
import time
import boto3

# AWS service clients used throughout the notebook
cf = boto3.client('cloudformation')
s3 = boto3.client('s3')
sns = boto3.client('sns')
step = boto3.client('stepfunctions')
# Runtime client for invoking the endpoint (this name shadows the sagemaker module,
# which is fine because get_execution_role was already imported above)
sagemaker = boto3.client('sagemaker-runtime')
ssm = boto3.client('ssm')

# Read the name of the SageBuild CloudFormation stack from the shared config file
with open('../config.json') as json_file:
    config = json.load(json_file)
StackName = config["StackName"]

# Collect the stack outputs into an {OutputKey: OutputValue} dict
result = cf.describe_stacks(StackName=StackName)
outputs = {}
for output in result['Stacks'][0]['Outputs']:
    outputs[output['OutputKey']] = output['OutputValue']

# Package the current directory (including mnist.py) and upload it to the code bucket
with tarfile.open("script.tar.gz", "w:gz") as tar:
    tar.add(os.getcwd(), arcname="")

s3.upload_file("script.tar.gz", outputs["CodeBucket"], "script.tar.gz")

# IAM execution role that gives SageMaker access to resources in your AWS account.
# We can use the SageMaker Python SDK to get the role from our notebook environment.
role = get_execution_role()

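The ``describe_stacks`` response nests the outputs as a list of ``OutputKey``/``OutputValue`` pairs under the first entry of ``Stacks``. A minimal sketch of the same flattening as a reusable function (``stack_outputs`` is a hypothetical name), exercised against a hand-written response dict whose keys mirror the ones this notebook reads and whose values are placeholders:

```python
def stack_outputs(describe_stacks_response):
    """Flatten the Outputs of the first stack into an {OutputKey: OutputValue} dict."""
    stack = describe_stacks_response["Stacks"][0]
    # A stack without outputs may omit the "Outputs" key, so default to an empty list
    return {o["OutputKey"]: o["OutputValue"] for o in stack.get("Outputs", [])}


# Hypothetical response shaped like cf.describe_stacks(); the values are placeholders
example_response = {
    "Stacks": [
        {
            "StackName": "example-stack",
            "Outputs": [
                {"OutputKey": "CodeBucket", "OutputValue": "example-code-bucket"},
                {"OutputKey": "DataBucket", "OutputValue": "example-data-bucket"},
            ],
        }
    ]
}

outputs_example = stack_outputs(example_response)
print(outputs_example["CodeBucket"])  # example-code-bucket
```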
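We need to make sure the SageBuild stack is configured for MXNet. The following cell sets the stack's ``ConfigFramework`` parameter to ``MXNET`` and waits for the stack update to complete. If the parameter already has that value, CloudFormation reports that no updates are to be performed, and the cell continues without error.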

In [None]:
params = result["Stacks"][0]["Parameters"]
for param in params:
    if param["ParameterKey"] == "ConfigFramework":
        param["ParameterValue"] = "MXNET"

try:
    cf.update_stack(
        StackName=StackName,
        UsePreviousTemplate=True,
        Parameters=params,
        Capabilities=[
            'CAPABILITY_NAMED_IAM',
        ]
    )
    waiter = cf.get_waiter('stack_update_complete')
    print("Waiting for stack update")
    waiter.wait(
        StackName=StackName,
        WaiterConfig={
            'Delay': 10,
            'MaxAttempts': 600
        }
    )

except ClientError as e:
    # An unchanged stack is not an error for our purposes; anything else is re-raised
    if e.response["Error"]["Message"] == "No updates are to be performed.":
        pass
    else:
        raise
print("stack ready!")

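Passing the parameter list from ``describe_stacks`` straight back to ``update_stack`` works for plain parameters, but CloudFormation masks ``NoEcho`` parameter values with asterisks in ``describe_stacks``, so echoing them back would overwrite the real values. A more defensive sketch, offered as an alternative rather than something SageBuild is known to require, sends ``UsePreviousValue=True`` for every parameter except the one being changed. ``build_update_parameters`` and the ``SomeSecret`` parameter are hypothetical:

```python
def build_update_parameters(current_params, key, value):
    """Return an update_stack Parameters list that changes one key and keeps the rest.

    current_params is the "Parameters" list from cf.describe_stacks(); every other
    parameter is sent with UsePreviousValue=True so masked NoEcho values are untouched.
    """
    updated = []
    found = False
    for param in current_params:
        name = param["ParameterKey"]
        if name == key:
            updated.append({"ParameterKey": name, "ParameterValue": value})
            found = True
        else:
            updated.append({"ParameterKey": name, "UsePreviousValue": True})
    if not found:
        raise KeyError(f"stack has no parameter named {key!r}")
    return updated


# Hypothetical parameter list shaped like describe_stacks()["Stacks"][0]["Parameters"]
example_params = [
    {"ParameterKey": "ConfigFramework", "ParameterValue": "TENSORFLOW"},
    {"ParameterKey": "SomeSecret", "ParameterValue": "****"},
]
new_params = build_update_parameters(example_params, "ConfigFramework", "MXNET")
print(new_params)
```

The returned list can be passed as ``Parameters=`` to ``cf.update_stack`` in place of the mutated ``params`` list.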
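### Copy the MNIST data to the data bucket

The following cell copies the MNIST training and test files from the public SageMaker sample-data bucket (in ``us-east-1``) into the stack's data bucket, under the ``train/mnist`` and ``test/mnist`` prefixes.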

In [None]:
dataBucket=outputs["DataBucket"]
!aws s3 cp s3://sagemaker-sample-data-us-east-1/mxnet/mnist/train s3://$dataBucket/train/mnist --recursive
!aws s3 cp s3://sagemaker-sample-data-us-east-1/mxnet/mnist/test s3://$dataBucket/test/mnist --recursive

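The ``aws s3 cp --recursive`` commands copy every object under the sample prefix into the data bucket. The same server-side copy can be sketched with a boto3 S3 client; ``copy_prefix`` and ``destination_key`` are hypothetical helpers, the client is passed in as an argument, and the key mapping is kept in a separate pure function so it can be checked without AWS access. ``copy_object`` handles objects up to 5 GB per request, which is ample for MNIST:

```python
def destination_key(source_key, source_prefix, dest_prefix):
    """Map a key under source_prefix/ to the same relative path under dest_prefix/."""
    folder = source_prefix.rstrip("/") + "/"
    if not source_key.startswith(folder):
        raise ValueError(f"{source_key!r} is not under {folder!r}")
    return dest_prefix.rstrip("/") + "/" + source_key[len(folder):]


def copy_prefix(s3_client, source_bucket, source_prefix, dest_bucket, dest_prefix):
    """Server-side copy of every object under source_prefix/, similar to `aws s3 cp --recursive`."""
    folder = source_prefix.rstrip("/") + "/"
    paginator = s3_client.get_paginator("list_objects_v2")
    copied = 0
    for page in paginator.paginate(Bucket=source_bucket, Prefix=folder):
        # Pages with no matching objects have no "Contents" key
        for obj in page.get("Contents", []):
            s3_client.copy_object(
                Bucket=dest_bucket,
                Key=destination_key(obj["Key"], source_prefix, dest_prefix),
                CopySource={"Bucket": source_bucket, "Key": obj["Key"]},
            )
            copied += 1
    return copied


# Example key mapping for the training data (the object name is a placeholder)
print(destination_key("mxnet/mnist/train/example.gz", "mxnet/mnist/train", "train/mnist"))
# train/mnist/example.gz
```

With the clients from the setup cell, the training copy would be ``copy_prefix(s3, "sagemaker-sample-data-us-east-1", "mxnet/mnist/train", dataBucket, "train/mnist")``.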
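### Update SageBuild Parameters

This cell reads the SageBuild configuration from the SSM Parameter Store, points both the training and hosting steps at ``mnist.py`` inside the uploaded ``script.tar.gz``, sets the learning rate and the ``train``/``test`` data channels, and writes the configuration back.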

In [None]:
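# Read the current SageBuild configuration (a JSON string) from the SSM Parameter Store
store = outputs["ParameterStore"]
result = ssm.get_parameter(Name=store)
params = json.loads(result["Parameter"]["Value"])

# Point training and hosting at mnist.py inside the uploaded script.tar.gz
params["hyperparameters"] = {'learning_rate': 0.1}
params["trainentrypoint"] = "mnist.py"
params["trainsourcefile"] = "s3://{}/script.tar.gz".format(outputs["CodeBucket"])
params["hostentrypoint"] = "mnist.py"
params["hostsourcefile"] = "s3://{}/script.tar.gz".format(outputs["CodeBucket"])
params["pyversion"] = "py2"
# Map each data channel to its prefix in the data bucket
params["channels"] = {
    "train": {
        "path": "train/mnist"
    },
    "test": {
        "path": "test/mnist"
    }
}

# Write the updated configuration back to the same parameter
ssm.put_parameter(
    Name=store,
    Type="String",
    Overwrite=True,
    Value=json.dumps(params)
)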

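Separating the JSON editing from the SSM calls makes it easy to inspect what will be stored before the parameter is overwritten. This sketch (``build_mxnet_config`` is a hypothetical name, and ``instancetype`` below is a made-up existing key) applies the same keys as the cell above, preserves any other keys already in the configuration, and checks the result against the 4 KB value limit for Standard-tier SSM parameters:

```python
import json

# SSM Standard-tier parameter values are limited to 4 KB (treated here as 4096 bytes)
STANDARD_TIER_MAX_BYTES = 4096


def build_mxnet_config(current_value, code_bucket):
    """Return the updated SageBuild config as a JSON string (mirrors the cell above)."""
    params = json.loads(current_value)
    source = "s3://{}/script.tar.gz".format(code_bucket)
    params.update({
        "hyperparameters": {"learning_rate": 0.1},
        "trainentrypoint": "mnist.py",
        "trainsourcefile": source,
        "hostentrypoint": "mnist.py",
        "hostsourcefile": source,
        "pyversion": "py2",
        "channels": {
            "train": {"path": "train/mnist"},
            "test": {"path": "test/mnist"},
        },
    })
    value = json.dumps(params)
    if len(value.encode("utf-8")) > STANDARD_TIER_MAX_BYTES:
        raise ValueError("config is too large for a Standard-tier SSM parameter")
    return value


# Hypothetical existing config; keys not set above are preserved
print(build_mxnet_config('{"instancetype": "ml.m5.large"}', "example-code-bucket"))
```

The returned string could then be passed as ``Value=`` to ``ssm.put_parameter``.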
## The training script

The ``mnist.py`` script provides all the code we need for training and hosting a SageMaker model. The script we use is adapted from the Apache MXNet [MNIST tutorial](https://mxnet.incubator.apache.org/tutorials/python/mnist.html).
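The script itself is not reproduced in this notebook. To illustrate the data-loading step a script like it needs, here is a minimal sketch of a reader for the IDX format used by the original MNIST distribution. It assumes, without verifying here, that the ``train`` and ``test`` channels contain those gzipped IDX files; ``read_idx`` and ``read_idx_gz`` are hypothetical helpers, not functions from ``mnist.py``:

```python
import gzip
import struct


def read_idx(raw):
    """Parse uncompressed IDX bytes into (shape, flat list of unsigned byte values)."""
    # Header: two zero bytes, a type code (0x08 = unsigned byte), and the number of dimensions
    zero, type_code, ndim = struct.unpack(">HBB", raw[:4])
    if zero != 0 or type_code != 0x08:
        raise ValueError("this sketch only handles unsigned-byte IDX data")
    # Each dimension size is a big-endian 32-bit unsigned integer
    shape = struct.unpack(">" + "I" * ndim, raw[4:4 + 4 * ndim])
    data = raw[4 + 4 * ndim:]
    count = 1
    for size in shape:
        count *= size
    if len(data) != count:
        raise ValueError("expected {} values, found {}".format(count, len(data)))
    return shape, list(data)


def read_idx_gz(path):
    """Read a gzipped IDX file such as train-images-idx3-ubyte.gz."""
    with gzip.open(path, "rb") as f:
        return read_idx(f.read())


# Tiny synthetic "images" file: 2 images of 2x2 pixels
header = struct.pack(">HBBIII", 0, 0x08, 3, 2, 2, 2)
shape, pixels = read_idx(header + bytes(range(8)))
print(shape, pixels)  # (2, 2, 2) [0, 1, 2, 3, 4, 5, 6, 7]
```

For the real training images the returned shape would be ``(60000, 28, 28)``, and for the training labels ``(60000,)``.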

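### Start Train/Deploy pipeline

Publishing any message to the launch SNS topic starts the SageBuild pipeline. The cell then looks up the running Step Functions execution and polls its status every 5 seconds until training and deployment finish.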

In [None]:
%%time
# Publishing to the launch topic starts the build; the message body is ignored
result = sns.publish(
    TopicArn=outputs['LaunchTopic'],
    Message="{}"
)
print(result)
time.sleep(5)
# List the RUNNING executions of the state machine to find the one we just started
result = step.list_executions(
    stateMachineArn=outputs['StateMachine'],
    statusFilter="RUNNING"
)['executions']

if len(result) > 0:
    response = step.describe_execution(
        executionArn=result[0]['executionArn']
    )
    status = response['status']
    print(status, response['name'])
    # Poll the status until the execution finishes
    while status == "RUNNING":
        print('.', end="")
        time.sleep(5)
        status = step.describe_execution(executionArn=result[0]['executionArn'])['status']
    print()
    print(status)
else:
    print("no running tasks")

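The polling loop above waits indefinitely while the execution is ``RUNNING``. A sketch of the same loop with an upper bound on the total wait; ``wait_for_execution`` is a hypothetical helper that takes a zero-argument status callable so it can be exercised without AWS. With Step Functions, that callable would wrap ``step.describe_execution(executionArn=...)['status']``:

```python
import time


def wait_for_execution(get_status, poll_seconds=5, timeout_seconds=3600, sleep=time.sleep):
    """Poll get_status() until it stops returning "RUNNING" or the timeout elapses.

    Returns the final status (e.g. "SUCCEEDED" or "FAILED"); raises TimeoutError if
    the execution is still running. The timeout counts sleep time only, not API latency.
    """
    waited = 0
    status = get_status()
    while status == "RUNNING":
        if waited >= timeout_seconds:
            raise TimeoutError("execution still RUNNING after {} seconds".format(waited))
        sleep(poll_seconds)
        waited += poll_seconds
        status = get_status()
    return status


# Simulated status sequence; the no-op sleep keeps the example instant
statuses = iter(["RUNNING", "RUNNING", "SUCCEEDED"])
print(wait_for_execution(lambda: next(statuses), sleep=lambda s: None))  # SUCCEEDED
```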
### Making an inference request

The request handling behavior of the Endpoint is determined by the ``mnist.py`` script. In this case, the script doesn't include any request handling functions, so the Endpoint will use the default handlers provided by SageMaker. These default handlers allow us to perform inference on input data encoded as a multi-dimensional JSON array.
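A multi-dimensional JSON array here means nested JSON lists. The exact shape that ``mnist.py`` expects is not shown in this notebook; the sketch below assumes a batch containing one 28x28 image (shape ``[1, 28, 28]``) and builds it from a flat, row-major list of pixel values. ``to_payload`` is a hypothetical helper:

```python
import json


def to_payload(flat_pixels, width=28, height=28):
    """Reshape a flat row-major pixel list into [1][height][width] and encode it as JSON."""
    if len(flat_pixels) != width * height:
        raise ValueError("expected {} pixels, got {}".format(width * height, len(flat_pixels)))
    rows = [flat_pixels[r * width:(r + 1) * width] for r in range(height)]
    return json.dumps([rows])  # outer list = a batch of one image


body = to_payload([0.0] * 784)
decoded = json.loads(body)
print(len(decoded), len(decoded[0]), len(decoded[0][0]))  # 1 28 28
```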

To see inference in action, draw a digit in the image box below. The pixel data from your drawing will be loaded into a ``data`` variable in this notebook. 

*Note: after drawing the image, you'll need to move to the next notebook cell.*

In [None]:
from IPython.display import HTML
HTML(open("input.html").read())

Now we can classify the handwritten digit:

In [None]:
# Send the drawing to the endpoint as a JSON array and ask for a JSON response
result = sagemaker.invoke_endpoint(
    EndpointName=outputs["SageMakerEndpoint"],
    Body=json.dumps(data),
    ContentType="application/json",
    Accept="application/json"
)

# The response holds one list of 10 class probabilities per input image
response = json.loads(result['Body'].read().decode('utf-8'))
labeled_predictions = list(zip(range(10), response[0]))
print('Labeled predictions: ')
print(labeled_predictions)

# Sort from most to least likely digit
labeled_predictions.sort(key=lambda label_and_prob: label_and_prob[1], reverse=True)
print('Most likely answer: {}'.format(labeled_predictions[0]))
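The endpoint returns one list of 10 class probabilities per input image, which the cell above pairs with digit labels and sorts. A compact equivalent, assuming that response layout, picks the most likely digits directly; ``top_predictions`` is a hypothetical helper and the probabilities below are made up for illustration:

```python
def top_predictions(probabilities, k=3):
    """Return the k (digit, probability) pairs with the highest probability."""
    labeled = list(enumerate(probabilities))
    # sorted() with reverse=True keeps the original order among ties (stable sort)
    return sorted(labeled, key=lambda pair: pair[1], reverse=True)[:k]


sample_response = [[0.01, 0.02, 0.05, 0.70, 0.02, 0.10, 0.03, 0.04, 0.02, 0.01]]
print(top_predictions(sample_response[0]))  # [(3, 0.7), (5, 0.1), (2, 0.05)]
```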