# Train an MNIST model with TensorFlow


---

This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook. 

![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-west-2/frameworks|tensorflow|get_started_mnist_train.ipynb)

---


MNIST is a widely used dataset for handwritten digit classification. It consists of 70,000 labeled 28x28 pixel grayscale images of handwritten digits. The dataset is split into 60,000 training images and 10,000 test images. There are 10 classes (one for each of the 10 digits). This tutorial shows how to train a TensorFlow V2 model on MNIST using SageMaker.

## Runtime

This notebook takes approximately 5 minutes to run.

## Contents

1. [TensorFlow Estimator](#TensorFlow-Estimator)
1. [Implement the training entry point](#Implement-the-training-entry-point)
1. [Set hyperparameters](#Set-hyperparameters)
1. [Set up channels for training and testing data](#Set-up-channels-for-training-and-testing-data)
1. [Run the training script on SageMaker](#Run-the-training-script-on-SageMaker)
1. [Inspect and store model data](#Inspect-and-store-model-data)
1. [Test and debug the entry point before running the training container](#Test-and-debug-the-entry-point-before-running-the-training-container)

In [None]:
import os
import json

import sagemaker
from sagemaker.tensorflow import TensorFlow
from sagemaker import get_execution_role

sess = sagemaker.Session()

role = get_execution_role()

output_path = "s3://" + sess.default_bucket() + "/DEMO-tensorflow/mnist"

## TensorFlow Estimator

The `TensorFlow` class allows you to run your training script on SageMaker
infrastructure in a containerized environment. In this notebook, we
refer to this container as the "training container."

Configure it with the following parameters to set up the environment:

- `entry_point`: A user-defined Python file used by the training container as the instructions for training. We will further discuss this file in the next subsection.

- `role`: An IAM role to make AWS service requests

- `instance_type`: The type of SageMaker instance to run your training script. Set it to `local` if you want to run the training job on the SageMaker instance you are using to run this notebook.

- `model_dir`: S3 bucket URI where the checkpoint data and models can be exported to during training (default: None). To prevent `model_dir` from being passed to your training script, set `model_dir=False`.

- `instance_count`: The number of instances to run your training job on. Multiple instances are needed for distributed training.

- `output_path`: The S3 bucket URI where training output (model artifacts and output files) is saved.

- `framework_version`: The TensorFlow version to use.

- `py_version`: The Python version to use.

For more information, see the [EstimatorBase API reference](https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html#sagemaker.estimator.EstimatorBase).



## Implement the training entry point

The entry point for training is a Python script that provides all 
the code for training a TensorFlow model. It is used by the SageMaker 
TensorFlow Estimator (`TensorFlow` class above) as the entry point for running the training job.

Under the hood, the SageMaker TensorFlow Estimator downloads a Docker image
with the runtime environment specified by the parameters passed to the
estimator class, and it injects the training script into the
Docker image as the entry point to run the container.

In the rest of the notebook, we use *training image* to refer to the 
Docker image specified by the TensorFlow Estimator and *training container*
to refer to the container that runs the training image. 

This means your training script is very similar to a training script
you might run outside Amazon SageMaker, but it can access the useful environment 
variables provided by the training image. See [the complete list of environment variables](https://github.com/aws/sagemaker-training-toolkit/blob/master/ENVIRONMENT_VARIABLES.md) for a
description of every environment variable your training script
can access. 
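
As a rough illustration of the idea above, an entry point might read the container-provided locations from environment variables and fall back to local paths when run elsewhere. This is a minimal sketch, not the contents of `code/train.py`: the argument names (`--sm-model-dir`, `--train`, `--test`), the defaults, and the fallback paths are assumptions made for illustration. Only the variable names `SM_MODEL_DIR`, `SM_CHANNEL_TRAINING`, and `SM_CHANNEL_TESTING` come from the SageMaker training toolkit.

```python
import argparse
import os


def parse_args(argv=None):
    """Parse hyperparameters and SageMaker-provided locations."""
    parser = argparse.ArgumentParser()

    # Hyperparameters (hypothetical names and defaults, for illustration only)
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--batch-size", type=int, default=64)

    # Locations injected by the training container. The fallback paths are
    # assumptions that only matter when the script runs outside SageMaker.
    parser.add_argument("--sm-model-dir", default=os.environ.get("SM_MODEL_DIR", "./model"))
    parser.add_argument("--train", default=os.environ.get("SM_CHANNEL_TRAINING", "./data"))
    parser.add_argument("--test", default=os.environ.get("SM_CHANNEL_TESTING", "./data"))

    return parser.parse_args(argv)


# Passing an explicit list keeps the example independent of sys.argv
args = parse_args(["--epochs", "2"])
print(f"epochs={args.epochs}, model dir={args.sm_model_dir}, training data={args.train}")
```

Reading the environment inside `parse_args` rather than at import time makes the function easy to exercise with different settings in a test.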

In this example, we use the training script `code/train.py`
as the entry point for our TensorFlow Estimator. 

In [None]:
!pygmentize 'code/train.py'

## Set hyperparameters

In addition, the TensorFlow estimator allows you to pass command-line arguments
to your training script via `hyperparameters`.

<span style="color:red"> Note: local mode is not supported in SageMaker Studio. </span>

In [None]:
# Set local_mode to be True if you want to run the training script on the machine that runs this notebook

local_mode = False

if local_mode:
    instance_type = "local"
else:
    instance_type = "ml.c4.xlarge"

est = TensorFlow(
    entry_point="train.py",
    source_dir="code",  # directory of your training script
    role=role,
    framework_version="2.3.1",
    model_dir=False,  # don't pass --model_dir to your training script
    py_version="py37",
    instance_type=instance_type,
    instance_count=1,
    volume_size=250,
    output_path=output_path,
    hyperparameters={
        "batch-size": 512,
        "epochs": 1,
        "learning-rate": 1e-3,
        "beta_1": 0.9,
        "beta_2": 0.999,
    },
)

The training container runs your training script like:

```
python train.py --batch-size 512 --epochs 1 --learning-rate 0.001 --beta_1 0.9 --beta_2 0.999
```
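
To make the mapping concrete, the sketch below flattens the same `hyperparameters` dictionary into a command line and parses it with `argparse`. The helper `to_argv` is hypothetical and only approximates what the container does; the parser is likewise an assumption about how a training script could receive these values, not a copy of `code/train.py`.

```python
import argparse

hyperparameters = {
    "batch-size": 512,
    "epochs": 1,
    "learning-rate": 1e-3,
    "beta_1": 0.9,
    "beta_2": 0.999,
}


def to_argv(params):
    """Flatten a dict into ['--key', 'value', ...] (illustrative helper)."""
    argv = []
    for key, value in params.items():
        argv += [f"--{key}", str(value)]
    return argv


parser = argparse.ArgumentParser()
parser.add_argument("--batch-size", type=int)
parser.add_argument("--epochs", type=int)
parser.add_argument("--learning-rate", type=float)
parser.add_argument("--beta_1", type=float)
parser.add_argument("--beta_2", type=float)

args = parser.parse_args(to_argv(hyperparameters))

# argparse converts hyphens in option names to underscores in attribute names
print(args.batch_size, args.learning_rate, args.beta_1)
```

Note that every value arrives as a string, so the `type=` converters in the parser are what restore integers and floats.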

## Set up channels for training and testing data

Tell the `TensorFlow` estimator where to find the training and 
testing data. Each location can be an S3 URI, or a path
in your local file system if you use local mode. In this example,
we download the MNIST data from a public S3 bucket and upload it 
to your default bucket. 

In [None]:
import logging
import boto3
from botocore.exceptions import ClientError

# Download training and testing data from a public S3 bucket


def download_from_s3(data_dir="./data", train=True):
    """Download the gzipped MNIST image and label files from a public S3 bucket

    Args:
        data_dir (str): directory to save the data
        train (bool): download training set

    Returns:
        None
    """

    if not os.path.exists(data_dir):
        os.makedirs(data_dir)

    if train:
        images_file = "train-images-idx3-ubyte.gz"
        labels_file = "train-labels-idx1-ubyte.gz"
    else:
        images_file = "t10k-images-idx3-ubyte.gz"
        labels_file = "t10k-labels-idx1-ubyte.gz"

    # download objects
    s3 = boto3.client("s3")
    bucket = f"sagemaker-example-files-prod-{boto3.session.Session().region_name}"
    for obj in [images_file, labels_file]:
        key = f"datasets/image/MNIST/{obj}"  # S3 keys always use forward slashes
        dest = os.path.join(data_dir, obj)
        if not os.path.exists(dest):
            s3.download_file(bucket, key, dest)
    return


download_from_s3("./data", True)
download_from_s3("./data", False)
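
The downloaded files are gzipped IDX files rather than arrays, so a training script has to decode them before use. The sketch below shows one way to do that with only the standard library. The function name `read_idx` is hypothetical, and this is not necessarily how `code/train.py` loads the data.

```python
import gzip
import struct


def read_idx(path):
    """Return (shape, payload) for a gzipped IDX file of unsigned bytes."""
    with gzip.open(path, "rb") as f:
        raw = f.read()

    # Header: two zero bytes, a data-type code, and the number of dimensions
    zero, type_code, ndim = struct.unpack(">HBB", raw[:4])
    if zero != 0 or type_code != 0x08:
        raise ValueError(f"{path} is not an unsigned-byte IDX file")

    # Each dimension size is a big-endian 32-bit unsigned integer
    shape = struct.unpack(f">{ndim}I", raw[4 : 4 + 4 * ndim])
    payload = raw[4 + 4 * ndim :]

    # Check that the payload holds exactly one byte per element
    expected = 1
    for dim in shape:
        expected *= dim
    if len(payload) != expected:
        raise ValueError(f"{path} is truncated or corrupt")
    return shape, payload
```

If NumPy is available, `np.frombuffer(payload, dtype=np.uint8).reshape(shape)` turns the result into an array.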

In [None]:
# Upload to the default bucket

prefix = "DEMO-mnist"
bucket = sess.default_bucket()
loc = sess.upload_data(path="./data", bucket=bucket, key_prefix=prefix)

channels = {"training": loc, "testing": loc}

The keys of the `channels` dictionary are passed to the training image,
which creates an environment variable `SM_CHANNEL_<KEY NAME>` (the key in upper case) for each one. 

In this example, `SM_CHANNEL_TRAINING` and `SM_CHANNEL_TESTING` are created in the training image (see 
how `code/train.py` accesses these variables). For more information,
see: [SM_CHANNEL_{channel_name}](https://github.com/aws/sagemaker-training-toolkit/blob/master/ENVIRONMENT_VARIABLES.md#sm_channel_channel_name).

If you want, you can create a channel for validation:
```
channels = {
    'training': train_data_loc,
    'validation': val_data_loc,
    'test': test_data_loc
}
```
You can then access this channel within your training script via
`SM_CHANNEL_VALIDATION`.
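
The naming rule described above can be sketched in a few lines. The helper names `channel_env_var` and `channel_dir` are hypothetical, and the directory used in the demonstration is set by hand only to imitate what the training container would provide.

```python
import os


def channel_env_var(channel_name):
    """Return the environment variable name for a channel key."""
    return f"SM_CHANNEL_{channel_name.upper()}"


def channel_dir(channel_name, default=None):
    """Look up the directory for a channel, or return `default` if unset."""
    return os.environ.get(channel_env_var(channel_name), default)


# Imitate the container by setting the variable manually (illustration only)
os.environ["SM_CHANNEL_VALIDATION"] = "/opt/ml/input/data/validation"

print(channel_env_var("validation"))
print(channel_dir("validation"))
```

Returning a default instead of raising keeps the same lookup usable when the script is run outside the container.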

## Run the training script on SageMaker
The training container now has everything it needs to run your training
script. Start the container by calling the `fit()` method.

In [None]:
est.fit(inputs=channels)

## Inspect and store model data

Training is now finished, and the model artifact has been saved under
the `output_path`.

In [None]:
tf_mnist_model_data = est.model_data
print("Model artifact saved at:\n", tf_mnist_model_data)

We use the `%store` magic to persist the variable `tf_mnist_model_data` so that another notebook can restore it with `%store -r tf_mnist_model_data`.

In [None]:
%store tf_mnist_model_data
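
The stored value is an S3 URI, so inspecting the artifact usually starts by splitting it into a bucket and a key. The sketch below does that with the standard library. The helper `split_s3_uri` and the example URI are hypothetical; the real value of `est.model_data` depends on your bucket and training job name.

```python
from urllib.parse import urlparse


def split_s3_uri(uri):
    """Split 's3://bucket/key' into (bucket, key)."""
    parsed = urlparse(uri)
    if parsed.scheme != "s3" or not parsed.netloc:
        raise ValueError(f"not an S3 URI: {uri!r}")
    return parsed.netloc, parsed.path.lstrip("/")


# Hypothetical URI in the shape a training job typically produces
example_uri = "s3://example-bucket/DEMO-tensorflow/mnist/example-job/output/model.tar.gz"
bucket_name, key = split_s3_uri(example_uri)
print(bucket_name)
print(key)
```

With the bucket and key in hand, `boto3.client("s3").download_file(bucket_name, key, "model.tar.gz")` fetches the archive, which the standard `tarfile` module can then open.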

## Test and debug the entry point before running the training container

The entry point `code/train.py` provided here has been tested and can be run in the training container. 
When you develop your own training script, it is good practice to simulate the container environment 
in a local shell and test the script before sending it to SageMaker, because debugging in a containerized environment
is rather cumbersome. The following script shows how you can test your training script:

In [None]:
!pygmentize code/test_train.py
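
One way to simulate the container, sketched here under stated assumptions, is to set the environment variables the script expects and launch it as a subprocess. This is not the contents of `code/test_train.py`; the helpers `simulated_env` and `run_entry_point` are hypothetical, and the set of variables is limited to the model directory and data channels discussed in this notebook.

```python
import os
import subprocess
import sys
import tempfile


def simulated_env(root, channels):
    """Build environment variables that mimic the training container."""
    env = dict(os.environ)
    env["SM_MODEL_DIR"] = os.path.join(root, "model")
    os.makedirs(env["SM_MODEL_DIR"], exist_ok=True)
    for name, path in channels.items():
        env[f"SM_CHANNEL_{name.upper()}"] = path
    return env


def run_entry_point(script, args, channels):
    """Run a training script locally with a container-like environment."""
    with tempfile.TemporaryDirectory() as root:
        env = simulated_env(root, channels)
        # check=True raises CalledProcessError if the script exits non-zero
        return subprocess.run([sys.executable, script, *args], env=env, check=True)


# Example call (commented out because it needs the script and data on disk):
# run_entry_point("code/train.py", ["--epochs", "1"],
#                 {"training": "./data", "testing": "./data"})
```

Because the model directory lives in a temporary folder, anything the script saves is discarded when the run finishes, which suits a quick smoke test.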

## Conclusion

In this notebook, we trained a TensorFlow model on the MNIST dataset by fitting a SageMaker estimator. For next steps on how to deploy the trained model and perform inference, see [Deploy a Trained TensorFlow V2 Model](https://sagemaker-examples.readthedocs.io/en/latest/frameworks/tensorflow/get_started_mnist_deploy.html).

## Notebook CI Test Results

This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.

![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-east-1/frameworks|tensorflow|get_started_mnist_train.ipynb)

![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-east-2/frameworks|tensorflow|get_started_mnist_train.ipynb)

![This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-west-1/frameworks|tensorflow|get_started_mnist_train.ipynb)

![This ca-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ca-central-1/frameworks|tensorflow|get_started_mnist_train.ipynb)

![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/sa-east-1/frameworks|tensorflow|get_started_mnist_train.ipynb)

![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-1/frameworks|tensorflow|get_started_mnist_train.ipynb)

![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-2/frameworks|tensorflow|get_started_mnist_train.ipynb)

![This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-3/frameworks|tensorflow|get_started_mnist_train.ipynb)

![This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-central-1/frameworks|tensorflow|get_started_mnist_train.ipynb)

![This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-north-1/frameworks|tensorflow|get_started_mnist_train.ipynb)

![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-southeast-1/frameworks|tensorflow|get_started_mnist_train.ipynb)

![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-southeast-2/frameworks|tensorflow|get_started_mnist_train.ipynb)

![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-northeast-1/frameworks|tensorflow|get_started_mnist_train.ipynb)

![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-northeast-2/frameworks|tensorflow|get_started_mnist_train.ipynb)

![This ap-south-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-south-1/frameworks|tensorflow|get_started_mnist_train.ipynb)
