# TensorFlow Training with Checkpointing on SageMaker Managed Spot Training

The example here is almost the same as [Train and Host a Keras Model with Pipe Mode and Horovod on Amazon SageMaker](https://github.com/aws/amazon-sagemaker-examples/blob/master/sagemaker-python-sdk/keras_script_mode_pipe_mode_horovod/tensorflow_keras_CIFAR10.ipynb).

This notebook tackles the exact same problem with the same solution, but it has been modified to run on SageMaker Managed Spot infrastructure. SageMaker Managed Spot uses [EC2 Spot Instances](https://aws.amazon.com/ec2/spot/) to run training at a lower cost.

Please read the original notebook and try it out to understand the ML use case and how it is solved. We will not delve into that in this notebook.


## Setup

First, we define a few variables that are needed later in the example.

In [None]:
import sagemaker
from sagemaker import get_execution_role

sagemaker_session = sagemaker.Session()

role = get_execution_role()
print('SageMaker version: ' + sagemaker.__version__)

## The CIFAR-10 dataset

The [CIFAR-10 dataset](https://www.cs.toronto.edu/~kriz/cifar.html) is one of the most popular machine learning datasets. It consists of 60,000 32x32 images belonging to 10 different classes (6,000 images per class). Here are the classes in the dataset, as well as 10 random images from each:

![cifar10](https://maet3608.github.io/nuts-ml/_images/cifar10.png)

### Prepare the dataset for training

To use the CIFAR-10 dataset, we first download it and convert it to TFRecords. This step takes around 5 minutes.

In [None]:
!python generate_cifar10_tfrecords.py --data-dir ./data

Next, we upload the data to Amazon S3:

In [None]:
from sagemaker.s3 import S3Uploader

bucket = sagemaker_session.default_bucket()
prefix = 'tf-cifar10-example'

dataset_uri = S3Uploader.upload('data', 's3://{}/{}/data'.format(bucket,prefix))
print('Training Dataset location: {}'.format(dataset_uri))

## Train the model

In this tutorial, we train a deep CNN to learn a classification task with the CIFAR-10 dataset. We first run a baseline training job on regular (On-Demand) infrastructure, and then train the same model with Managed Spot Training, using checkpoints to resume after a simulated Spot interruption.

### Configure the TensorFlow estimator

The SageMaker Python SDK's `sagemaker.tensorflow.TensorFlow` estimator class makes it easy for us to interact with SageMaker. We create one for each of the training jobs we run in this example. A couple of parameters are worth noting:

* `entry_point`: our training script (adapted from [this Keras example](https://github.com/keras-team/keras/blob/master/examples/cifar10_cnn.py)).
* `instance_count`: the number of training instances. Here, we set it to 1 for our baseline training job.

For each training job we run, we adjust these and other parameters to configure it.

For more details about the TensorFlow estimator class, see the [API documentation](https://sagemaker.readthedocs.io/en/stable/sagemaker.tensorflow.html).

### Verify the training code

Before running the baseline training job, we first use [the SageMaker Python SDK's Local Mode feature](https://sagemaker.readthedocs.io/en/stable/overview.html#local-mode) to check that our code works with SageMaker's TensorFlow environment. Local Mode downloads the [prebuilt Docker image for TensorFlow](https://docs.aws.amazon.com/deep-learning-containers/latest/devguide/deep-learning-containers-images.html) and runs a Docker container locally for a training job. In other words, it simulates the SageMaker environment for a quicker development cycle, so we use it here just to test out our code.

We create a TensorFlow estimator and set `instance_type` to `'local'` or `'local_gpu'`, depending on whether a GPU is available locally. This tells the estimator to run our training job locally (as opposed to on SageMaker). We also run the training code for only one epoch, because our intent here is to verify the code, not to train an accurate model.

In [None]:
import shutil
import subprocess

from sagemaker.tensorflow import TensorFlow

instance_type = 'local'

# Use the GPU instance type only if nvidia-smi is installed and reports a GPU
# (subprocess.call raises FileNotFoundError when the binary is missing)
if shutil.which('nvidia-smi') and subprocess.call('nvidia-smi') == 0:
    instance_type = 'local_gpu'

local_hyperparameters = {'epochs': 1, 'batch-size': 64}

estimator = TensorFlow(entry_point='cifar10_keras_main.py',
                       source_dir='source_dir',
                       role=role,
                       framework_version='1.15.2',
                       py_version='py3',
                       hyperparameters=local_hyperparameters,
                       instance_count=1,
                       instance_type=instance_type)

Once we have our estimator, we call `fit()` to start the training job and pass the inputs that we downloaded earlier. We pass the inputs as a dictionary to define different data channels for training.

In [None]:
import os

data_path = os.path.join(os.getcwd(), 'data')

local_inputs = {
 'train': 'file://{}/train'.format(data_path),
 'validation': 'file://{}/validation'.format(data_path),
 'eval': 'file://{}/eval'.format(data_path),
}
estimator.fit(local_inputs)

### Run a baseline training job on SageMaker

Now we run training jobs on SageMaker, starting with our baseline training job.

### Configure metrics

In addition to running the training job, Amazon SageMaker can retrieve training metrics directly from the logs and send them to Amazon CloudWatch as metrics. Here, we define the metrics we would like to observe:

In [None]:
metric_definitions = [
    {'Name': 'train:loss', 'Regex': '.*loss: ([0-9\\.]+) - accuracy: [0-9\\.]+.*'},
    {'Name': 'train:accuracy', 'Regex': '.*loss: [0-9\\.]+ - accuracy: ([0-9\\.]+).*'},
    {'Name': 'validation:accuracy', 'Regex': '.*step - loss: [0-9\\.]+ - accuracy: [0-9\\.]+ - val_loss: [0-9\\.]+ - val_accuracy: ([0-9\\.]+).*'},
    {'Name': 'validation:loss', 'Regex': '.*step - loss: [0-9\\.]+ - accuracy: [0-9\\.]+ - val_loss: ([0-9\\.]+) - val_accuracy: [0-9\\.]+.*'},
    # Captures the per-step time that Keras prints before "ms/step" or "us/step" (not seconds)
    {'Name': 'sec/steps', 'Regex': '.* - \\d+s (\\d+)[mu]s/step - loss: [0-9\\.]+ - accuracy: [0-9\\.]+ - val_loss: [0-9\\.]+ - val_accuracy: [0-9\\.]+'}
]
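Before launching a job, it can help to sanity-check these regular expressions locally. The sketch below applies each `Regex` with Python's `re.search` to a hypothetical log line in the Keras progress format (the line and its numbers are made up for illustration, not captured from this job). `extract_metrics` is a hypothetical helper, and SageMaker's own regex matching may differ from Python's in edge cases.

```python
import re

# Same definitions as in the cell above, repeated (as raw strings) so this snippet runs on its own
metric_definitions = [
    {'Name': 'train:loss', 'Regex': r'.*loss: ([0-9\.]+) - accuracy: [0-9\.]+.*'},
    {'Name': 'train:accuracy', 'Regex': r'.*loss: [0-9\.]+ - accuracy: ([0-9\.]+).*'},
    {'Name': 'validation:accuracy', 'Regex': r'.*step - loss: [0-9\.]+ - accuracy: [0-9\.]+ - val_loss: [0-9\.]+ - val_accuracy: ([0-9\.]+).*'},
    {'Name': 'validation:loss', 'Regex': r'.*step - loss: [0-9\.]+ - accuracy: [0-9\.]+ - val_loss: ([0-9\.]+) - val_accuracy: [0-9\.]+.*'},
    {'Name': 'sec/steps', 'Regex': r'.* - \d+s (\d+)[mu]s/step - loss: [0-9\.]+ - accuracy: [0-9\.]+ - val_loss: [0-9\.]+ - val_accuracy: [0-9\.]+'},
]

# Hypothetical end-of-epoch line in the Keras progress format, made up for this check
sample_line = ('196/196 - 15s 77ms/step - loss: 1.2345 - accuracy: 0.5678 '
               '- val_loss: 1.1000 - val_accuracy: 0.6000')


def extract_metrics(line, definitions):
    """Return {metric name: float value} for each definition whose regex matches the line."""
    values = {}
    for definition in definitions:
        match = re.search(definition['Regex'], line)
        if match:
            values[definition['Name']] = float(match.group(1))
    return values


print(extract_metrics(sample_line, metric_definitions))
```

If a metric is missing from the printed dictionary, its regex does not match your script's actual log format and CloudWatch will not receive that metric.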

Once again, we create a TensorFlow estimator, with a couple of key modifications from last time:

* `instance_type`: the instance type for training. We set this to `ml.p3.2xlarge` because we are training on SageMaker now. For a list of available instance types, see [the AWS documentation](https://aws.amazon.com/sagemaker/pricing/instance-types).
* `metric_definitions`: the metrics (defined above) that we want sent to CloudWatch.

In [None]:
from sagemaker.tensorflow import TensorFlow

hyperparameters = {'epochs': 10, 'batch-size': 256}

estimator = TensorFlow(entry_point='cifar10_keras_main.py',
 source_dir='source_dir',
 metric_definitions=metric_definitions,
 hyperparameters=hyperparameters,
 role=role,
 framework_version='1.15.2',
 py_version='py3',
 instance_count=1,
 instance_type='ml.p3.2xlarge',
 base_job_name='cifar10-tf-on-demand')

Like before, we call `fit()` to start the SageMaker training job and pass the inputs in a dictionary to define different data channels for training. This time, we use the S3 URI from uploading our data.

In [None]:
inputs = {
 'train': '{}/train'.format(dataset_uri),
 'validation': '{}/validation'.format(dataset_uri),
 'eval': '{}/eval'.format(dataset_uri),
}

estimator.fit(inputs)

## Managed Spot Training with a TensorFlow Estimator

For Managed Spot Training using a TensorFlow Estimator, we need to configure two things:
1. Enable the `use_spot_instances` constructor arg: a simple, self-explanatory boolean.
2. Set the `max_wait` constructor arg: an int representing the maximum total time, in seconds, that the job may take, including both the time spent waiting for Spot capacity and the time spent training. Some instance types are harder to get at Spot prices, so you may have to wait longer. You are not charged for time spent waiting for Spot capacity; you're only charged for compute time once Spot instances have been successfully procured.

Normally, a third requirement is also necessary: modifying your training code to save checkpoints at a regular cadence and to resume from the latest one. The training script used in this example already does this, so no changes are needed here. Checkpointing is highly recommended for Managed Spot Training jobs because Spot instances can be interrupted at short notice; resuming from the last checkpoint ensures you don't lose the progress made before the interruption.

Feel free to toggle the `use_spot_instances` variable to see the effect of running the same job using regular (a.k.a. "On Demand") infrastructure.

Note that `max_wait` can be set only if `use_spot_instances` is enabled, and it **must** be greater than or equal to `max_run`, the maximum time in seconds that the training itself is allowed to run.

In [None]:
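use_spot_instances = True
max_run = 600    # maximum training time, in seconds
max_wait = 1200  # maximum total time (waiting for Spot capacity plus training), in seconds; must be >= max_run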
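The two rules for these values can be checked before a job is submitted. The sketch below uses a hypothetical helper, `check_spot_config`, that is not part of the SageMaker SDK; SageMaker is expected to reject invalid combinations itself when the job is created, so this only makes the failure happen earlier and locally.

```python
def check_spot_config(use_spot_instances, max_run, max_wait=None):
    """Raise ValueError if the Spot settings break the max_wait rules (hypothetical helper)."""
    if max_wait is None:
        return
    if not use_spot_instances:
        raise ValueError('max_wait can only be set when use_spot_instances is True')
    if max_wait < max_run:
        raise ValueError('max_wait ({}s) must be >= max_run ({}s)'.format(max_wait, max_run))


# In the notebook you would pass the variables above:
# check_spot_config(use_spot_instances, max_run, max_wait)
check_spot_config(use_spot_instances=True, max_run=600, max_wait=1200)
```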

### Simulating Spot interruption after 5 epochs

Our training job should run for 10 epochs.

However, we simulate a Spot interruption after 5 epochs by first running the job for only 5 epochs.

Note the `checkpoint_s3_uri` variable, which stores the S3 URI where SageMaker persists any checkpoints that the training script saves during training.

In [None]:
import uuid

checkpoint_suffix = str(uuid.uuid4())[:8]
checkpoint_s3_uri = 's3://{}/checkpoint-{}'.format(bucket, checkpoint_suffix)

print('Checkpointing location: {}'.format(checkpoint_s3_uri))

The goal is that the checkpoint data is copied to S3, so that when Spot capacity is available again, the training job can resume from the 6th epoch.

In [None]:
hyperparameters = {'epochs': 5, 'batch-size': 256}

spot_estimator = TensorFlow(entry_point='cifar10_keras_main.py',
 source_dir='source_dir',
 metric_definitions=metric_definitions,
 hyperparameters=hyperparameters,
 role=role,
 framework_version='1.15.2',
 py_version='py3',
 instance_count=1,
 instance_type='ml.p3.2xlarge',
 base_job_name='cifar10-tf-spot-1st-run',
 checkpoint_s3_uri=checkpoint_s3_uri,
 use_spot_instances=use_spot_instances,
 max_run=max_run,
 max_wait=max_wait)

In [None]:
spot_estimator.fit(inputs)

### Savings
Towards the end of the job you should see two lines of output printed:

- `Training seconds: X` : This is the actual compute-time your training job spent
- `Billable seconds: Y` : This is the time you will be billed for after Spot discounting is applied.

If you enabled the `use_spot_instances` variable, you should see a notable difference between `X` and `Y`, signifying the cost savings you get for having chosen Managed Spot Training. This should be reflected in an additional line:

- `Managed Spot Training savings: (1-Y/X)*100 %`
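The same percentage can be computed after the fact from the job description. The sketch below assumes the `TrainingTimeInSeconds` and `BillableTimeInSeconds` fields returned by the SageMaker `DescribeTrainingJob` API; `spot_savings_percent` and `savings_from_job` are hypothetical helpers, and the client is passed in so the functions can be defined without AWS access.

```python
def spot_savings_percent(training_seconds, billable_seconds):
    """Managed Spot savings as a percentage, (1 - Y/X) * 100, with X = training and Y = billable seconds."""
    if training_seconds <= 0:
        raise ValueError('training_seconds must be positive')
    # Equivalent to (1 - Y/X) * 100, written to avoid an extra rounding step
    return 100.0 * (training_seconds - billable_seconds) / training_seconds


def savings_from_job(sagemaker_client, job_name):
    """Read both durations from DescribeTrainingJob and compute the savings."""
    description = sagemaker_client.describe_training_job(TrainingJobName=job_name)
    return spot_savings_percent(description['TrainingTimeInSeconds'],
                                description['BillableTimeInSeconds'])


# In the notebook, for example:
# savings_from_job(sagemaker_session.sagemaker_client, spot_estimator.latest_training_job.name)
```

For an On-Demand job the two durations should be equal, giving 0% savings.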

### Analyze training job logs

Analyzing the training job logs, we can see that the training job starts from the 1st epoch:

```
INFO:root:Starting training from epoch: 1
```

### View the training job's Checkpoint configuration

We can now view the Checkpoint configuration from the training job directly in the SageMaker console. 

Log into the [SageMaker console](https://console.aws.amazon.com/sagemaker/home), choose the latest training job, and scroll down to the Checkpoint configuration section. 

Choose the S3 output path link and you'll be directed to the S3 bucket where checkpoint data is saved.

You can see there are 5 files there:

```
checkpoint-1.h5
checkpoint-2.h5
checkpoint-3.h5
checkpoint-4.h5
checkpoint-5.h5
```

### Continue training after Spot capacity is resumed

Now we simulate a situation where Spot capacity becomes available again.

We will start a training job again, this time with 10 epochs.

We expect the training job to start from the 6th epoch.

This works because, when a training job starts, SageMaker checks the checkpoint S3 location for checkpoint data. If any exists, it is copied to `/opt/ml/checkpoints` in the training container.

In the training script, the `load_model_from_checkpoints` function loads the latest checkpoint data:

```python
import logging
import os
import re

# Assumption: the script uses tf.keras; adjust this import if it uses standalone Keras
from tensorflow.keras.models import load_model


def load_model_from_checkpoints(checkpoint_path):
    checkpoint_files = [file for file in os.listdir(checkpoint_path) if file.endswith('.h5')]
    logging.info('------------------------------------------------------')
    logging.info('Available checkpoint files: {}'.format(checkpoint_files))
    # Parse the full epoch number from checkpoint-<epoch>.h5 as an int, so that
    # checkpoint-10.h5 is correctly treated as newer than checkpoint-9.h5
    epoch_numbers = [int(re.search(r'(\d+)\.h5$', file).group(1)) for file in checkpoint_files]

    max_epoch_number = max(epoch_numbers)
    max_epoch_index = epoch_numbers.index(max_epoch_number)
    max_epoch_filename = checkpoint_files[max_epoch_index]

    logging.info('Latest epoch checkpoint file name: {}'.format(max_epoch_filename))
    logging.info('Resuming training from epoch: {}'.format(max_epoch_number + 1))
    logging.info('------------------------------------------------------')

    resumed_model_from_checkpoints = load_model(os.path.join(checkpoint_path, max_epoch_filename))
    return resumed_model_from_checkpoints, max_epoch_number
```
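The `load_model_from_checkpoints` function above returns the restored model along with the last completed epoch. How the script decides whether to resume is not shown here; a minimal sketch, assuming it passes the epoch to Keras `Model.fit` through its `initial_epoch` argument, could look like this (`find_initial_epoch` is a hypothetical helper, not taken from the training script):

```python
import os
import re

# Local directory that SageMaker syncs with checkpoint_s3_uri
CHECKPOINT_DIR = '/opt/ml/checkpoints'


def find_initial_epoch(checkpoint_dir=CHECKPOINT_DIR):
    """Return the highest epoch among checkpoint-<epoch>.h5 files, or 0 if there are none."""
    if not os.path.isdir(checkpoint_dir):
        return 0
    epochs = []
    for name in os.listdir(checkpoint_dir):
        match = re.fullmatch(r'checkpoint-(\d+)\.h5', name)
        if match:
            epochs.append(int(match.group(1)))
    return max(epochs, default=0)


initial_epoch = find_initial_epoch()
print('Starting training from epoch: {}'.format(initial_epoch + 1))
# Keras counts initial_epoch from 0, so model.fit(..., epochs=10, initial_epoch=5)
# runs only epochs 6 through 10.
```

Returning 0 when no checkpoints exist lets the same code path handle both a fresh start and a resumed run.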


In [None]:
hyperparameters = {'epochs': 10, 'batch-size': 256}

spot_estimator = TensorFlow(entry_point='cifar10_keras_main.py',
 source_dir='source_dir',
 metric_definitions=metric_definitions,
 hyperparameters=hyperparameters,
 role=role,
 framework_version='1.15.2',
 py_version='py3',
 instance_count=1,
 instance_type='ml.p3.2xlarge',
 base_job_name='cifar10-tf-spot-2nd-run',
 checkpoint_s3_uri=checkpoint_s3_uri,
 use_spot_instances=use_spot_instances,
 max_run=max_run,
 max_wait=max_wait)

In [None]:
spot_estimator.fit(inputs)

### Analyze training job logs

Analyzing the training job logs, we can see that the training job now starts from the 6th epoch.

We can see the output of `load_model_from_checkpoints` function:

```
INFO:root:------------------------------------------------------
INFO:root:Available checkpoint files: ['checkpoint-1.h5', 'checkpoint-4.h5', 'checkpoint-3.h5', 'checkpoint-2.h5', 'checkpoint-5.h5']
INFO:root:Latest epoch checkpoint file name: checkpoint-5.h5
INFO:root:Resuming training from epoch: 6
INFO:root:------------------------------------------------------
```

Further down in the logs, the following line indicates that the training job starts from the 6th epoch:
```
INFO:root:Starting training from epoch: 6
```

### View the training job's Checkpoint configuration after completing 10 epochs

We can now view the Checkpoint configuration from the training job directly in the SageMaker console. 

Log into the [SageMaker console](https://console.aws.amazon.com/sagemaker/home), choose the latest training job, and scroll down to the Checkpoint configuration section. 

Choose the S3 output path link and you'll be directed to the S3 bucket where checkpoint data is saved.

You can see there are 10 files there:

```
checkpoint-1.h5
checkpoint-2.h5
checkpoint-3.h5
checkpoint-4.h5
checkpoint-5.h5
checkpoint-6.h5
checkpoint-7.h5
checkpoint-8.h5
checkpoint-9.h5
checkpoint-10.h5
```

The last-modified timestamps of the first five checkpoint files (1-5) and of the second five (6-10) fall into two separate groups, reflecting the two separate times the training job was run.
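Instead of browsing the console, you can list the checkpoint objects and their timestamps programmatically. The sketch below assumes a boto3 S3 client and its `list_objects_v2` call; `list_checkpoints` is a hypothetical helper, and the client is passed in so the function can be defined without AWS access.

```python
from urllib.parse import urlparse


def list_checkpoints(s3_client, s3_uri):
    """Return (key, last_modified) pairs under an s3://bucket/prefix URI, oldest first."""
    parsed = urlparse(s3_uri)
    bucket = parsed.netloc
    prefix = parsed.path.strip('/')
    if prefix:
        # Trailing slash so that checkpoint-abc does not also match checkpoint-abcdef
        prefix += '/'
    # A single list_objects_v2 call returns at most 1,000 keys, plenty for ten checkpoints
    response = s3_client.list_objects_v2(Bucket=bucket, Prefix=prefix)
    objects = response.get('Contents', [])
    return sorted(((obj['Key'], obj['LastModified']) for obj in objects),
                  key=lambda item: item[1])


# In the notebook, for example:
# import boto3
# for key, modified in list_checkpoints(boto3.client('s3'), checkpoint_s3_uri):
#     print(modified, key)
```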

## Deploy the trained model

After we train our model, we can deploy it to a SageMaker Endpoint, which serves prediction requests in real time. To do so, we simply call `deploy()` on our estimator, passing in the desired number of instances and instance type for the endpoint.

Because we're using TensorFlow Serving for deployment, our training script saves the model in TensorFlow's SavedModel format. For more details, see [this blog post on deploying Keras and TF models in SageMaker](https://aws.amazon.com/blogs/machine-learning/deploy-trained-keras-or-tensorflow-models-using-amazon-sagemaker).

In [None]:
predictor = spot_estimator.deploy(initial_instance_count=1, instance_type='ml.m4.xlarge')

### Invoke the endpoint

To verify that the endpoint is in service, we generate some random data in the correct shape and get a prediction.

In [None]:
import numpy as np

data = np.random.randn(1, 32, 32, 3)
print('Predicted class: {}'.format(np.argmax(predictor.predict(data)['predictions'])))

Now let's use the test dataset for predictions.

In [None]:
from keras.datasets import cifar10

(x_train, y_train), (x_test, y_test) = cifar10.load_data()

With the data loaded, we can use it for predictions:

In [None]:
from keras.preprocessing.image import ImageDataGenerator

def predict(data):
    predictions = predictor.predict(data)['predictions']
    return predictions


predicted = []
actual = []
batches = 0
batch_size = 128

datagen = ImageDataGenerator()
# datagen.flow() cycles through the data indefinitely, so stop after one full pass over x_test
for data in datagen.flow(x_test, y_test, batch_size=batch_size):
    for i, prediction in enumerate(predict(data[0])):
        predicted.append(np.argmax(prediction))
        actual.append(data[1][i][0])

    batches += 1
    if batches >= len(x_test) / batch_size:
        break
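Because `ImageDataGenerator()` is created here without any augmentation options, it only batches the data (and converts it to floats). A simpler alternative, sketched below, slices the test set directly and needs no stopping condition. `predict_in_batches` is a hypothetical helper; it assumes, as above, that the prediction function returns one score vector per image.

```python
import numpy as np


def predict_in_batches(predict_fn, x, y, batch_size=128):
    """Return (predicted, actual) class lists, sending x to predict_fn in slices of batch_size."""
    predicted, actual = [], []
    for start in range(0, len(x), batch_size):
        batch = x[start:start + batch_size].astype('float32')
        for prediction in predict_fn(batch):
            predicted.append(int(np.argmax(prediction)))
        # CIFAR-10 labels have shape (n, 1), so flatten them to plain ints
        actual.extend(int(label) for label in y[start:start + batch_size].ravel())
    return predicted, actual


# In the notebook: predicted, actual = predict_in_batches(predict, x_test, y_test)
```

Slicing also keeps the test images in their original order, which makes it easier to look up individual misclassified images later.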

With the predictions, we calculate our model accuracy and create a confusion matrix.

In [None]:
from sklearn.metrics import accuracy_score

accuracy = accuracy_score(y_pred=predicted, y_true=actual)
print('Average accuracy: {}%'.format(round(accuracy * 100, 2)))

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sn
from sklearn.metrics import confusion_matrix

cm = confusion_matrix(y_pred=predicted, y_true=actual)
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
sn.set(rc={'figure.figsize': (11.7,8.27)})
sn.set(font_scale=1.4) # for label size
sn.heatmap(cm, annot=True, annot_kws={"size": 10}) # font size

Aided by the colors of the heatmap, we can use this confusion matrix to understand how well the model performed for each label.
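The normalization line `cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]` divides each row by its total, so entry (i, j) becomes the fraction of images of true class i that the model labeled as class j, and each row sums to 1. A tiny hand-made example with three classes (the counts are invented for illustration):

```python
import numpy as np

# Invented counts: rows are true classes, columns are predicted classes
counts = np.array([[8, 2, 0],
                   [1, 6, 3],
                   [0, 0, 5]])

# Divide each row by its total; [:, np.newaxis] turns the row sums into a column for broadcasting
normalized = counts.astype('float') / counts.sum(axis=1)[:, np.newaxis]

print(normalized.round(2))
# The diagonal holds the per-class recall: correct predictions / all images of that class
print(np.diag(normalized))
```

Reading the diagonal of the heatmap this way gives the per-class recall; off-diagonal cells show which classes a given class is most often confused with.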

## Cleanup

To avoid incurring extra charges to your AWS account, let's delete the endpoint we created:

In [None]:
predictor.delete_endpoint()