# MNIST distributed training 

This tutorial shows how to create a convolutional neural network model and train it on the [MNIST dataset](http://yann.lecun.com/exdb/mnist/) using **TensorFlow distributed training**.



### Set up the environment

In [None]:
import sagemaker as sagemakerSDK
from sagemaker import get_execution_role
import json
import boto3
import time
import os
import tarfile
from botocore.exceptions import ClientError

# boto3 clients used throughout the notebook
cf = boto3.client('cloudformation')
s3 = boto3.client('s3')
sns = boto3.client('sns')
step = boto3.client('stepfunctions')
sagemaker = boto3.client('sagemaker-runtime')
ssm = boto3.client('ssm')

# Load the name of the SageBuild CloudFormation stack
with open('../config.json') as json_file:
    config = json.load(json_file)
StackName = config["StackName"]

# Collect the stack outputs (bucket names, topic ARN, and so on) into a dictionary
result = cf.describe_stacks(
    StackName=StackName
)
outputs = {}
for output in result['Stacks'][0]['Outputs']:
    outputs[output['OutputKey']] = output['OutputValue']

# Package the current directory and upload it to the code bucket
with tarfile.open("script.tar.gz", "w:gz") as tar:
    tar.add(os.getcwd(), arcname="")

s3.upload_file("script.tar.gz", outputs["CodeBucket"], "script.tar.gz")

# IAM execution role that gives SageMaker access to resources in your AWS account.
# We can use the SageMaker Python SDK to get the role from our notebook environment.
role = get_execution_role()
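
The loop that fills `outputs` flattens the `Outputs` list of the first stack returned by `describe_stacks` into a plain dictionary. The sketch below shows the same transformation on a hand-written response, so it runs without AWS access. The helper name `stack_outputs` and the bucket names are made up for illustration.

```python
def stack_outputs(describe_stacks_response):
    """Flatten the Outputs of the first stack into {OutputKey: OutputValue}."""
    stack = describe_stacks_response["Stacks"][0]
    # Defensive: tolerate a response that has no "Outputs" key.
    return {o["OutputKey"]: o["OutputValue"] for o in stack.get("Outputs", [])}


# Hand-written stand-in for cf.describe_stacks(StackName=...); values are made up.
sample_response = {
    "Stacks": [{
        "Outputs": [
            {"OutputKey": "CodeBucket", "OutputValue": "example-code-bucket"},
            {"OutputKey": "DataBucket", "OutputValue": "example-data-bucket"},
        ]
    }]
}

sample_outputs = stack_outputs(sample_response)
print(sample_outputs["CodeBucket"])
```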

We need to make sure the SageBuild template is configured correctly for TensorFlow. The following code sets the stack's `ConfigFramework` parameter to `TENSORFLOW` and waits for the stack update to finish.

In [None]:
params = result["Stacks"][0]["Parameters"]
for param in params:
    if param["ParameterKey"] == "ConfigFramework":
        param["ParameterValue"] = "TENSORFLOW"

try:
    cf.update_stack(
        StackName=StackName,
        UsePreviousTemplate=True,
        Parameters=params,
        Capabilities=[
            'CAPABILITY_NAMED_IAM',
        ]
    )
    waiter = cf.get_waiter('stack_update_complete')
    print("Waiting for stack update")
    waiter.wait(
        StackName=StackName,
        WaiterConfig={
            'Delay': 10,
            'MaxAttempts': 600
        }
    )

except ClientError as e:
    # The stack already has the requested configuration, so there is nothing to wait for
    if e.response["Error"]["Message"] == "No updates are to be performed.":
        pass
    else:
        raise
print("stack ready!")
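
The two pieces of logic in the cell above (overriding one stack parameter and treating the "No updates are to be performed." error as success) can be isolated into small helpers. This is a sketch only: `override_parameter` and `is_noop_update` are hypothetical names, and the parameter list is made up in the shape that `describe_stacks` returns.

```python
import copy

NO_UPDATE_MESSAGE = "No updates are to be performed."


def override_parameter(params, key, value):
    """Return a copy of a CloudFormation parameter list with one value replaced."""
    updated = copy.deepcopy(params)
    for param in updated:
        if param["ParameterKey"] == key:
            param["ParameterValue"] = value
    return updated


def is_noop_update(error_response):
    """True when an update_stack error response only says that nothing changed."""
    return error_response.get("Error", {}).get("Message") == NO_UPDATE_MESSAGE


# Made-up parameter list; only the ConfigFramework key comes from the notebook.
current = [
    {"ParameterKey": "ConfigFramework", "ParameterValue": "EXAMPLE_OLD_VALUE"},
    {"ParameterKey": "ExampleOtherParameter", "ParameterValue": "unchanged"},
]
desired = override_parameter(current, "ConfigFramework", "TENSORFLOW")
print(desired[0]["ParameterValue"])
```

Working on a copy keeps the original list available for comparison, which is a design choice of this sketch rather than something the notebook requires.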

### Download the MNIST dataset

In [None]:
import utils
from tensorflow.contrib.learn.python.learn.datasets import mnist
import tensorflow as tf

data_sets = mnist.read_data_sets('data', dtype=tf.uint8, reshape=False, validation_size=5000)

utils.convert_to(data_sets.train, 'train', 'data')
utils.convert_to(data_sets.validation, 'validation', 'data')
utils.convert_to(data_sets.test, 'test', 'data')

# Construct a script for distributed training 
The full code for the network model is in `mnist.py`, which is printed at the end of this section.

The script is an adaptation of the [TensorFlow MNIST example](https://github.com/tensorflow/models/tree/master/official/mnist). It provides a ```model_fn(features, labels, mode)```, which is used for training, evaluation and inference. 

## A regular ```model_fn```

A regular **```model_fn```** follows the pattern:
1. [defines a neural network](https://github.com/tensorflow/models/blob/master/official/mnist/mnist.py#L96)
2. [applies the ```features``` in the neural network](https://github.com/tensorflow/models/blob/master/official/mnist/mnist.py#L178)
3. [if the ```mode``` is ```PREDICT```, returns the output from the neural network](https://github.com/tensorflow/models/blob/master/official/mnist/mnist.py#L186)
4. [calculates the loss function comparing the output with the ```labels```](https://github.com/tensorflow/models/blob/master/official/mnist/mnist.py#L188)
5. [creates an optimizer and minimizes the loss function to improve the neural network](https://github.com/tensorflow/models/blob/master/official/mnist/mnist.py#L193)
6. [returns the output, optimizer and loss function](https://github.com/tensorflow/models/blob/master/official/mnist/mnist.py#L205)
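
The control flow of that pattern can be sketched without TensorFlow. The example below is not the code in `mnist.py`: it replaces the convolutional network with a single linear layer in NumPy, uses made-up random data, and adds a `params` dictionary that holds the weights and learning rate. It only illustrates how one function serves prediction, evaluation and training depending on `mode`.

```python
import numpy as np

TRAIN, EVAL, PREDICT = "train", "eval", "predict"  # illustrative mode names


def softmax(logits):
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def model_fn(features, labels, mode, params):
    # Steps 1-2: define the "network" (one linear layer) and apply the features.
    logits = features.dot(params["w"]) + params["b"]
    probs = softmax(logits)
    classes = probs.argmax(axis=1)

    # Step 3: in PREDICT mode, return the network output only.
    if mode == PREDICT:
        return {"classes": classes, "probabilities": probs}

    # Step 4: cross-entropy loss between the output and the labels.
    n = features.shape[0]
    loss = -np.log(probs[np.arange(n), labels]).mean()

    # Step 5: in TRAIN mode, take one gradient-descent step on the loss.
    if mode == TRAIN:
        grad_logits = probs.copy()
        grad_logits[np.arange(n), labels] -= 1.0
        grad_logits /= n
        params["w"] -= params["lr"] * features.T.dot(grad_logits)
        params["b"] -= params["lr"] * grad_logits.sum(axis=0)

    # Step 6: return the output and the loss.
    return {"classes": classes, "loss": loss}


rng = np.random.RandomState(0)
features = rng.randn(64, 16)             # 64 made-up examples, 16 features each
labels = rng.randint(0, 10, size=64)     # made-up class ids in [0, 10)
params = {"w": np.zeros((16, 10)), "b": np.zeros(10), "lr": 0.1}

first_loss = model_fn(features, labels, TRAIN, params)["loss"]
for _ in range(50):
    model_fn(features, labels, TRAIN, params)
final_loss = model_fn(features, labels, EVAL, params)["loss"]
print(first_loss, final_loss)
```

In the TensorFlow version, the optimizer performs step 5 and the results are packaged in the objects the Estimator API expects, but the branching on `mode` follows the same order.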

## Writing a ```model_fn``` for distributed training
In distributed training, the same neural network is sent to multiple training instances. Each instance runs the network on a batch of the dataset, calculates the loss, and uses the optimizer to minimize it. One entire loop of this process is called a **training step**.

### Synchronizing training steps
A [global step](https://www.tensorflow.org/api_docs/python/tf/train/global_step) is a global variable shared between the instances. It's necessary for distributed training, so the optimizer will keep track of the number of **training steps** between runs: 

```python
train_op = optimizer.minimize(loss, tf.train.get_or_create_global_step())
```

That is the only required change for distributed training!
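
To make the role of the shared counter concrete, here is a framework-free sketch. It is not how TensorFlow implements distribution: two simulated workers take turns updating one shared weight, and every update increments a single `global_step`, so the stopping condition counts steps across all workers rather than per worker. The class and function names, the toy loss and the batches are all made up.

```python
class SharedState(object):
    """Stand-in for the variables that every worker reads and writes."""

    def __init__(self):
        self.global_step = 0
        self.weight = 0.0


def train_step(state, batch, learning_rate=0.1):
    """One training step: move `weight` toward the batch mean, then count the step."""
    target = sum(batch) / float(len(batch))
    gradient = state.weight - target      # derivative of 0.5 * (w - target)^2
    state.weight -= learning_rate * gradient
    state.global_step += 1                # the increment that passing a global step to minimize() provides
    return state.global_step


def run(max_steps, worker_batches):
    """Alternate between workers until the shared step count reaches max_steps."""
    state = SharedState()
    log = []
    while state.global_step < max_steps:
        for worker, batch in enumerate(worker_batches):
            if state.global_step >= max_steps:
                break
            log.append((worker, train_step(state, batch)))
    return state, log


state, log = run(max_steps=5, worker_batches=[[1.0, 3.0], [2.0, 4.0]])
print(state.global_step, log)
```

With a per-worker counter instead, each worker would run `max_steps` updates on its own and the job would perform twice as many steps as requested.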

In [None]:
!cat 'mnist.py'

## Set Parameters for SageBuild

In [None]:
store = outputs["ParameterStore"]
result = ssm.get_parameter(Name=store)

# The build configuration is stored as a JSON string
params = json.loads(result["Parameter"]["Value"])
params["hyperparameters"] = {
    'training_steps': 1000,
    'evaluation_steps': 100,
    'sagemaker_requirements': ""
}
params["trainentrypoint"] = "mnist.py"
params["traininstancecount"] = 2
params["traininstancetype"] = "ml.c4.xlarge"
params["trainsourcefile"] = "s3://{}/script.tar.gz".format(outputs["CodeBucket"])
params["hostentrypoint"] = "mnist.py"
params["hostsourcefile"] = "s3://{}/script.tar.gz".format(outputs["CodeBucket"])
params["pyversion"] = "py2"
params["channels"] = {
    "training": {
        "path": "train/mnist-dist",
        "dist": False
    }
}

# Write the updated configuration back to the parameter store
ssm.put_parameter(
    Name=store,
    Type="String",
    Overwrite=True,
    Value=json.dumps(params)
)
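
The cell above follows a read-modify-write cycle: decode the stored JSON string, change some keys, and encode the result again. A minimal sketch of that cycle, with the parameter store replaced by a plain string so it runs offline, is shown below. The helper name `update_build_params` and the stored value are made up.

```python
import json


def update_build_params(stored_value, overrides):
    """Decode the stored JSON string, apply top-level overrides, re-encode it."""
    params = json.loads(stored_value)
    params.update(overrides)
    return json.dumps(params, sort_keys=True)


# Made-up stored value; in the notebook it comes from ssm.get_parameter.
stored = json.dumps({"traininstancecount": 1, "pyversion": "py2"})
new_value = update_build_params(stored, {
    "traininstancecount": 2,
    "trainentrypoint": "mnist.py",
})
print(new_value)
```

Keys that are not overridden keep their stored values, which is why the notebook reads the existing parameter before writing it back.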

### Upload the data
We use the AWS CLI (```aws s3 cp```) to upload our datasets to the data bucket under the ```train/mnist-dist``` prefix. This is the same path we configured for the ```training``` channel above, so the training job will read its data from there.

In [None]:
dataBucket=outputs["DataBucket"]
!aws s3 cp ./data s3://$dataBucket/train/mnist-dist --recursive

In [None]:
%%time
import sys

result = sns.publish(
    TopicArn=outputs['LaunchTopic'],
    Message="{}"  # message is not important, just publishing to topic starts build
)
print(result)
time.sleep(5)
# list all executions for our StateMachine to get our current running one
result = step.list_executions(
    stateMachineArn=outputs['StateMachine'],
    statusFilter="RUNNING"
)['executions']

if len(result) > 0:
    response = step.describe_execution(
        executionArn=result[0]['executionArn']
    )
    status = response['status']
    print("{} {}".format(status, response['name']))
    # poll status till execution finishes
    while status == "RUNNING":
        sys.stdout.write('.')  # progress marker that works on Python 2 and 3
        sys.stdout.flush()
        time.sleep(5)
        status = step.describe_execution(executionArn=result[0]['executionArn'])['status']
    print("")
    print(status)
else:
    print("no running tasks")
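
The polling loop above runs until the execution leaves the `RUNNING` state, with no upper bound. One possible variation, sketched here under the assumption that a bounded wait is preferable, caps the number of polls and raises if the execution is still running. `wait_for_execution` is a hypothetical helper, and the status source is a fake iterator so the example runs without AWS access.

```python
import time


def wait_for_execution(get_status, delay=5.0, max_attempts=720, sleep=time.sleep):
    """Poll get_status() until it stops returning "RUNNING" or attempts run out."""
    for _ in range(max_attempts):
        status = get_status()
        if status != "RUNNING":
            return status
        sleep(delay)
    raise RuntimeError("execution still RUNNING after {} polls".format(max_attempts))


# Fake status source standing in for step.describe_execution(...)['status'].
fake_statuses = iter(["RUNNING", "RUNNING", "SUCCEEDED"])
final_status = wait_for_execution(
    lambda: next(fake_statuses),
    delay=0,
    sleep=lambda seconds: None,  # skip real sleeping in this demonstration
)
print(final_status)
```

In the notebook, the status source would be `lambda: step.describe_execution(executionArn=result[0]['executionArn'])['status']`.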

# Invoking the endpoint

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from sagemaker.tensorflow.predictor import tf_json_serializer, tf_json_deserializer
from sagemaker.predictor import RealTimePredictor

# Predictor for the endpoint named in the stack outputs
predict = RealTimePredictor(outputs["SageMakerEndpoint"], False, tf_json_serializer, tf_json_deserializer)
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)

for i in range(10):
    data = mnist.test.images[i].tolist()
    # Wrap the flattened image in a TensorProto of shape [1, len(data)]
    tensor_proto = tf.make_tensor_proto(
        values=np.asarray(data),
        shape=[1, len(data)],
        dtype=tf.float32)

    predict_response = predict.predict(tensor_proto)
    print("========================================")
    # Labels are one-hot encoded, so argmax recovers the digit
    label = np.argmax(mnist.test.labels[i])
    print("label is {}".format(label))
    prediction = predict_response['outputs']['classes']['int64Val'][0]
    print("prediction is {}".format(prediction))
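
Rather than comparing labels and predictions by eye, the responses can be reduced to an accuracy figure. The sketch below assumes every response has the `['outputs']['classes']['int64Val']` shape used in the loop above, and that the value may arrive either as a string or as an integer. The responses and labels are hand-written, not real endpoint output, and the helper names are hypothetical.

```python
def predicted_class(predict_response):
    """Pull the first predicted class out of the response shape used above."""
    value = predict_response["outputs"]["classes"]["int64Val"][0]
    # Assumption: the value may be a string such as "7", so normalize to int.
    return int(value)


def accuracy(responses, labels):
    """Fraction of responses whose predicted class equals the label."""
    hits = sum(predicted_class(r) == int(l) for r, l in zip(responses, labels))
    return hits / float(len(labels))


# Hand-written responses in the assumed shape; not real endpoint output.
fake_responses = [
    {"outputs": {"classes": {"int64Val": ["7"]}}},
    {"outputs": {"classes": {"int64Val": ["2"]}}},
    {"outputs": {"classes": {"int64Val": [1]}}},
    {"outputs": {"classes": {"int64Val": ["4"]}}},
]
fake_labels = [7, 2, 1, 0]
print(accuracy(fake_responses, fake_labels))
```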