# Apache MXNet Training with Checkpointing on SageMaker Managed Spot Training

The example here is almost the same as [Training and hosting SageMaker Models using the Apache MXNet Module API](https://github.com/awslabs/amazon-sagemaker-examples/blob/master/sagemaker-python-sdk/mxnet_mnist/mxnet_mnist.ipynb).

This notebook tackles the exact same problem with the same solution, but it has been modified to run on SageMaker Managed Spot infrastructure. SageMaker Managed Spot uses [EC2 Spot Instances](https://aws.amazon.com/ec2/spot/) to run training jobs at a lower cost.

Please read the original notebook and try it out to gain an understanding of the ML use case and how it is solved. We will not delve into that in this notebook.

## First, set up variables and define functions

Again, we won't go into detail explaining the code below; it has been lifted verbatim from [Training and hosting SageMaker Models using the Apache MXNet Module API](https://github.com/awslabs/amazon-sagemaker-examples/blob/master/sagemaker-python-sdk/mxnet_mnist/mxnet_mnist.ipynb).

In [None]:
!pip install sagemaker -U

In [None]:
import sagemaker
import boto3
from sagemaker import get_execution_role
from sagemaker.session import Session
import uuid

# S3 bucket for saving code and model artifacts.
# Feel free to specify a different bucket here if you wish.
bucket = Session().default_bucket()

# Location to save your custom code in tar.gz format.
custom_code_upload_location = 's3://{}/customcode/mxnet'.format(bucket)

# Location where results of model training are saved.
model_artifacts_location = 's3://{}/artifacts'.format(bucket)

# IAM execution role that gives SageMaker access to resources in your AWS account.
# We can use the SageMaker Python SDK to get the role from our notebook environment. 
role = get_execution_role()

region = boto3.Session().region_name
train_data_location = 's3://sagemaker-sample-data-{}/mxnet/mnist/train'.format(region)
test_data_location = 's3://sagemaker-sample-data-{}/mxnet/mnist/test'.format(region)

checkpoint_suffix = str(uuid.uuid4())[:8]
checkpoint_s3_uri = 's3://{}/artifacts/mxnet-checkpoint-{}/'.format(bucket, checkpoint_suffix)

print('SageMaker version: ' + sagemaker.__version__)
print('Checkpointing Path: {}'.format(checkpoint_s3_uri))

# Managed Spot Training with MXNet

For Managed Spot Training using MXNet, we need to configure three things:
1. Enable the `use_spot_instances` constructor argument, a self-explanatory boolean.
2. Set the `max_wait` constructor argument, an integer number of seconds that bounds the total time you are willing to let the job take, including both the time spent waiting for Spot infrastructure to become available and the training time itself. Some instance types are harder to get at Spot prices, and you may have to wait longer. You are not charged for time spent waiting for Spot infrastructure; you are charged only for the compute time used once Spot instances have been procured.
3. Set the `checkpoint_s3_uri` constructor argument. This tells SageMaker the S3 location where checkpoints are saved (assuming your algorithm has been modified to save checkpoints periodically). While not strictly necessary, checkpointing is highly recommended for Managed Spot Training jobs: Spot instances can be interrupted on short notice, and resuming from the last checkpoint means you don't lose progress made before the interruption.

Feel free to toggle the `use_spot_instances` variable to see the effect of running the same job using regular (a.k.a. "On Demand") infrastructure.

Note that `max_wait` can be set if and only if `use_spot_instances` is enabled and **must** be greater than or equal to `max_run`.
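
The rule above can be checked locally before a job is submitted. The following is a minimal sketch and not part of the original notebook: `validate_spot_settings` is a hypothetical helper name, and it assumes both values are expressed in seconds.

```python
def validate_spot_settings(use_spot_instances, max_run, max_wait):
    """Raise ValueError if the Spot-related settings are inconsistent.

    Hypothetical helper for illustration; it mirrors the rule stated above.
    """
    if not use_spot_instances:
        if max_wait is not None:
            raise ValueError("max_wait may only be set when use_spot_instances is enabled")
        return
    if max_wait is None:
        raise ValueError("max_wait must be set when use_spot_instances is enabled")
    if max_wait < max_run:
        raise ValueError("max_wait must be greater than or equal to max_run")


# Values used in this notebook: up to 600 seconds of training, 1200 seconds overall.
validate_spot_settings(use_spot_instances=True, max_run=600, max_wait=1200)
```

Running the check on the driver side only surfaces a misconfiguration earlier; it does not replace the validation that SageMaker performs.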

In [None]:
use_spot_instances = True
max_run = 600
max_wait = 1200 if use_spot_instances else None

### Simulating Spot interruption after 5 epochs

Our training job should run for 10 epochs.

However, we will simulate a Spot interruption after 5 epochs by running the first job for only 5 epochs.

The goal is for the checkpoint data to be copied to S3 so that, when Spot capacity becomes available again, the training job can resume from the 6th epoch.

Note the `checkpoint_s3_uri` variable, which holds the S3 URI where any checkpoints the algorithm writes during training are persisted.

In [None]:
from sagemaker.mxnet import MXNet

mnist_estimator = MXNet(entry_point='source_dir/mnist.py',
                        role=role,
                        output_path=model_artifacts_location,
                        code_location=custom_code_upload_location,
                        instance_count=1,
                        instance_type='ml.m4.xlarge',
                        framework_version='1.6.0',
                        py_version='py3',
                        distribution={'parameter_server': {'enabled': True}},
                        hyperparameters={'learning-rate': 0.1, 'epochs': 5},
                        use_spot_instances=use_spot_instances,
                        max_run=max_run,
                        max_wait=max_wait,
                        checkpoint_s3_uri=checkpoint_s3_uri)
mnist_estimator.fit({'train': train_data_location, 'test': test_data_location})

# Savings
Towards the end of the job you should see two lines of output printed:

- `Training seconds: X` : This is the actual compute-time your training job spent
- `Billable seconds: Y` : This is the time you will be billed for after Spot discounting is applied.

If you enabled the `use_spot_instances` variable, you should see a notable difference between `X` and `Y`, signifying the cost savings from Managed Spot Training. This is reflected in an additional line:
- `Managed Spot Training savings: (1-Y/X)*100 %`
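
The savings formula can be reproduced by hand from the two values. The sketch below is illustrative: `spot_savings_percent` is a hypothetical helper name, and the numbers passed to it are made up rather than taken from a real job. The same two figures should also be available as `TrainingTimeInSeconds` and `BillableTimeInSeconds` in the `DescribeTrainingJob` response.

```python
def spot_savings_percent(training_seconds, billable_seconds):
    """Return the Managed Spot Training savings as a percentage: (1 - Y/X) * 100."""
    if training_seconds <= 0:
        raise ValueError("training_seconds must be positive")
    return (1 - billable_seconds / training_seconds) * 100


# Illustrative numbers only, not output from a real training job.
print("Managed Spot Training savings: {:.1f}%".format(spot_savings_percent(100, 30)))
```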

### Analyze training job logs

Analyzing the training job logs, we can see that the training job starts from the 1st epoch (shown as epoch 0 in the MXNet logs):

```
INFO:root:Starting training from epoch: 0
```

### View the job training Checkpoint configuration

We can now view the Checkpoint configuration from the training job directly in the SageMaker console.  

Log into the [SageMaker console](https://console.aws.amazon.com/sagemaker/home), choose the latest training job, and scroll down to the Checkpoint configuration section. 

Choose the S3 output path link and you'll be directed to the S3 bucket where checkpointing data is saved.

You can see there are 11 files there:

```
mnist-symbol.json 
mnist-0005.states 
mnist-0005.params 
mnist-0004.states 
mnist-0004.params 
mnist-0003.states 
mnist-0003.params 
mnist-0002.states 
mnist-0002.params 
mnist-0001.states 
mnist-0001.params 
```

Those files store the trainer states, model parameters, and model architecture.
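
The file count follows from the naming pattern in the listing: one symbol file, plus one `.params` and one `.states` file per epoch. A small sketch, assuming that pattern holds (the helper name `expected_checkpoint_files` is hypothetical):

```python
def expected_checkpoint_files(epochs, prefix="mnist"):
    """List the checkpoint file names expected after the given number of epochs."""
    files = ["{}-symbol.json".format(prefix)]
    for epoch in range(1, epochs + 1):
        # Epoch numbers are zero-padded to four digits, e.g. mnist-0005.params.
        files.append("{}-{:04d}.params".format(prefix, epoch))
        files.append("{}-{:04d}.states".format(prefix, epoch))
    return files


# 2 files per epoch plus the symbol file.
print(len(expected_checkpoint_files(5)))
```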

### Continue training after Spot capacity is resumed

Now we simulate a situation where Spot capacity becomes available again.

We will start a training job again, this time with 10 epochs.

We expect the training job to start from the 6th epoch (shown as epoch 5 in the MXNet logs).

This works because, when a training job starts, SageMaker checks the checkpoint S3 location for checkpoint data. If any exists, it is copied to `/opt/ml/checkpoints` on the training container.

In the code you can see the `load_model_from_checkpoints` function, which loads the checkpoint data:

```python
import logging
import os
import re

import mxnet as mx


def load_model_from_checkpoints(checkpoint_path):
    # Each saved epoch produces a file such as mnist-0005.params.
    checkpoint_files = [file for file in os.listdir(checkpoint_path) if file.endswith('.params')]
    logging.info('------------------------------------------------------')
    logging.info("Available checkpoint files: {}".format(checkpoint_files))
    # Parse the whole zero-padded epoch number as an int. Matching only the last
    # digit would read mnist-0010.params as epoch 0 instead of epoch 10.
    epoch_numbers = [int(re.search(r'(\d+)(?=\.params$)', file).group()) for file in checkpoint_files]

    max_epoch_number = max(epoch_numbers)
    max_epoch_index = epoch_numbers.index(max_epoch_number)
    max_epoch_filename = checkpoint_files[max_epoch_index]

    logging.info('Latest epoch checkpoint file name: {}'.format(max_epoch_filename))
    logging.info('Resuming training from epoch: {}'.format(max_epoch_number))
    logging.info('------------------------------------------------------')

    # load_checkpoint expects the file prefix (directory plus "mnist") and the epoch number.
    sym, arg_params, aux_params = mx.model.load_checkpoint(os.path.join(checkpoint_path, "mnist"), max_epoch_number)
    mlp_model = mx.mod.Module(symbol=sym)
    return mlp_model, max_epoch_number
```
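
The epoch lookup can be exercised without MXNet or SageMaker. The following is a minimal sketch, not the notebook's own code: `latest_checkpoint_epoch` and `CHECKPOINT_PATTERN` are hypothetical names, and a temporary directory stands in for `/opt/ml/checkpoints`. It returns `None` when no checkpoint exists, which is the case in which training would start from scratch.

```python
import os
import re
import tempfile

# Matches names such as mnist-0005.params and captures the epoch digits.
CHECKPOINT_PATTERN = re.compile(r"^mnist-(\d+)\.params$")


def latest_checkpoint_epoch(checkpoint_dir):
    """Return the highest epoch that has a .params file, or None if there is none."""
    if not os.path.isdir(checkpoint_dir):
        return None
    epochs = []
    for name in os.listdir(checkpoint_dir):
        match = CHECKPOINT_PATTERN.match(name)
        if match:
            # Compare epochs numerically, not as strings.
            epochs.append(int(match.group(1)))
    return max(epochs) if epochs else None


# Simulate the directory contents after a job that saved 10 epochs.
with tempfile.TemporaryDirectory() as tmp:
    for epoch in range(1, 11):
        open(os.path.join(tmp, "mnist-{:04d}.params".format(epoch)), "w").close()
    open(os.path.join(tmp, "mnist-symbol.json"), "w").close()
    print(latest_checkpoint_epoch(tmp))
```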


In [None]:
mnist_estimator = MXNet(entry_point='source_dir/mnist.py',
                        role=role,
                        output_path=model_artifacts_location,
                        code_location=custom_code_upload_location,
                        instance_count=1,
                        instance_type='ml.m4.xlarge',
                        framework_version='1.6.0',
                        py_version='py3',
                        distribution={'parameter_server': {'enabled': True}},
                        hyperparameters={'learning-rate': 0.1, 'epochs': 10},
                        use_spot_instances=use_spot_instances,
                        max_run=max_run,
                        max_wait=max_wait,
                        checkpoint_s3_uri=checkpoint_s3_uri)
mnist_estimator.fit({'train': train_data_location, 'test': test_data_location})

### Analyze training job logs

Analyzing the training job logs, we can see that the training job now starts from the 6th epoch.

We can see the output of the `load_model_from_checkpoints` function:

```
INFO:root:------------------------------------------------------
INFO:root:Available checkpoint files: ['mnist-0005.params', 'mnist-0001.params', 'mnist-0003.params', 'mnist-0004.params', 'mnist-0002.params']
INFO:root:Latest epoch checkpoint file name: mnist-0005.params
INFO:root:Resuming training from epoch: 5
INFO:root:------------------------------------------------------
```

Going further down in the logs, we can now see the following line, indicating that the training job starts from the 6th epoch (shown as epoch 5 in the MXNet logs):
```
INFO:root:Starting training from epoch: 5
```

### View the job training Checkpoint configuration after the job completed 10 epochs

We can now view the Checkpoint configuration from the training job directly in the SageMaker console.  

Log into the [SageMaker console](https://console.aws.amazon.com/sagemaker/home), choose the latest training job, and scroll down to the Checkpoint configuration section. 

Choose the S3 output path link and you'll be directed to the S3 bucket where checkpointing data is saved.

You can see there are 21 files there:

```
mnist-symbol.json 
mnist-0010.states
mnist-0010.params
mnist-0009.states 
mnist-0009.params 
mnist-0008.states 
mnist-0008.params 
mnist-0007.states 
mnist-0007.params 
mnist-0006.states 
mnist-0006.params 
mnist-0005.states 
mnist-0005.params 
mnist-0004.states 
mnist-0004.params 
mnist-0003.states 
mnist-0003.params 
mnist-0002.states 
mnist-0002.params 
mnist-0001.states 
mnist-0001.params 
```

You'll be able to see that the file timestamps fall into two groups: the first ten checkpoint files (epochs 1-5), and the second ten (epochs 6-10) together with `mnist-symbol.json`. The two groups correspond to the two separate times the training job was run.
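
The same grouping can be done programmatically once the object keys and modification times are available, for example from the `Key` and `LastModified` fields returned by boto3's `list_objects_v2`. The sketch below is illustrative only: `group_by_run` is a hypothetical helper, the 5-minute gap threshold is an assumption, and the timestamps are made up rather than read from S3.

```python
from datetime import datetime, timedelta


def group_by_run(objects, max_gap=timedelta(minutes=5)):
    """Split (key, last_modified) pairs into runs separated by more than max_gap."""
    groups = []
    for key, modified in sorted(objects, key=lambda item: item[1]):
        # Extend the current group if this file closely follows the previous one.
        if groups and modified - groups[-1][-1][1] <= max_gap:
            groups[-1].append((key, modified))
        else:
            groups.append([(key, modified)])
    return groups


# Made-up timestamps for illustration; real values would come from S3.
first_run = datetime(2020, 1, 1, 10, 0, 0)
second_run = datetime(2020, 1, 1, 11, 0, 0)
objects = [("mnist-{:04d}.params".format(e), first_run + timedelta(seconds=10 * e)) for e in range(1, 6)]
objects += [("mnist-{:04d}.params".format(e), second_run + timedelta(seconds=10 * e)) for e in range(6, 11)]
objects.append(("mnist-symbol.json", second_run))

for index, group in enumerate(group_by_run(objects), start=1):
    print("Run {}: {}".format(index, [key for key, _ in group]))
```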