# Data Parallel Training

This notebook trains the ENet model on the GPUs of multiple `ml.p3.16xlarge` instances
using [SageMaker's Distributed Data Parallel](https://docs.aws.amazon.com/sagemaker/latest/dg/data-parallel.html) library.

Model training requires a preprocessed dataset, which is produced in a [separate notebook](preprocess-camvid.ipynb).

## Imports and Paths

The next cell imports modules from the [Amazon SageMaker Python SDK](https://sagemaker.readthedocs.io/en/stable/)
that we need for training the model, sets up a SageMaker session,
and then defines the S3 URIs for the preprocessed data.

In [1]:
%reload_ext autoreload
%autoreload 2
%reload_ext dotenv
%dotenv

import sagemaker
from sagemaker.tensorflow import TensorFlow
from sagemaker.inputs import TrainingInput

session = sagemaker.Session()
bucket = session.default_bucket()
role = sagemaker.get_execution_role()
training_role = role

prefix = 'enet-tensorflow-distributed'
train_path = f's3://{bucket}/{prefix}/preprocessed-data/camvid/train/'
train_labels_path = f's3://{bucket}/{prefix}/preprocessed-data/camvid/train_labels/'
val_path = f's3://{bucket}/{prefix}/preprocessed-data/camvid/val/'
val_labels_path = f's3://{bucket}/{prefix}/preprocessed-data/camvid/val_labels/'
test_path = f's3://{bucket}/{prefix}/preprocessed-data/camvid/test/'
test_labels_path = f's3://{bucket}/{prefix}/preprocessed-data/camvid/test_labels/'
report_path = f's3://{bucket}/{prefix}/preprocessed-data/camvid/report/'
preprocessing_report_path = f'{report_path}preprocessing_report.json'
class_dict_path = f'{report_path}class_dict.json'

## Define Training Job

Since the ENet model is implemented in TensorFlow, we're using the [`TensorFlow estimator`](https://sagemaker.readthedocs.io/en/stable/frameworks/tensorflow/sagemaker.tensorflow.html) to train it via Amazon SageMaker
using a custom [training script](../scripts/train_data_parallel.py) (set via the `source_dir` and `entry_point` arguments).

We also set the model's hyperparameters,
as well as metric definitions that allow us to extract training metrics from log output.
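
As a minimal sketch of how such metric definitions pick values out of log output, the snippet below applies three of the regular expressions to made-up log lines. The sample lines and the `extract_metrics` helper are assumptions for illustration only; SageMaker does the matching itself on the job's log stream.

```python
import re

# Three of the metric definitions passed to the estimator.
sample_definitions = [
    {'Name': 'Epoch', 'Regex': r'# epoch = (\d+)'},
    {'Name': 'Loss', 'Regex': r'# loss = ([\d.\-\+e]+)'},
    {'Name': 'Val Loss', 'Regex': r'# val_loss = ([\d.\-\+e]+)'},
]


def extract_metrics(log_line, definitions):
    """Return {metric name: value} for every definition matching the line."""
    found = {}
    for definition in definitions:
        match = re.search(definition['Regex'], log_line)
        if match:
            found[definition['Name']] = float(match.group(1))
    return found


# Hypothetical log lines in the format the regexes expect.
for line in ['# epoch = 3', '# loss = 0.4321', '# val_loss = 1.5e-01']:
    print(line, '->', extract_metrics(line, sample_definitions))
```

Note that the `Loss` pattern does not match a `# val_loss = ...` line, because the pattern requires `loss` to follow `# ` directly.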

For cost efficiency we're using [managed spot training](https://docs.aws.amazon.com/sagemaker/latest/dg/model-managed-spot-training.html) (by setting `use_spot_instances=True` and providing `max_run` and `max_wait`).
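
To our understanding, `max_run` limits the training time itself, while `max_wait` limits the total time including waiting for spot capacity, so `max_wait` must be at least as large as `max_run`. The sketch below checks that relationship for the values used in this notebook; the `spot_wait_budget` helper is hypothetical and not part of the SageMaker SDK.

```python
def spot_wait_budget(max_run, max_wait):
    """Return the seconds left for waiting on spot capacity.

    Raises ValueError if max_wait is smaller than max_run.
    """
    if max_wait < max_run:
        raise ValueError('max_wait must be greater than or equal to max_run')
    return max_wait - max_run


max_run = 10 * 3600   # training time limit in seconds
max_wait = 16 * 3600  # total limit (waiting + training) in seconds

budget = spot_wait_budget(max_run, max_wait)
print(f'Up to {budget / 3600:.0f} hours may be spent waiting for spot capacity.')
```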

For data parallel training we provide a [`distribution`](https://sagemaker.readthedocs.io/en/stable/api/training/smd_data_parallel_use_sm_pysdk.html) argument which configures the distributed training.

Note that the training job runs on two `ml.p3.16xlarge` instances (`instance_count=2`).
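
Each `ml.p3.16xlarge` instance provides 8 GPUs, so two instances give 16 data parallel workers. The sketch below works out the resulting global batch size under the assumption that the `batch-size` hyperparameter is applied per GPU; whether the training script interprets it that way is not confirmed here, and the helper names are hypothetical.

```python
# GPU counts per instance type (only the type used in this notebook).
GPUS_PER_INSTANCE = {'ml.p3.16xlarge': 8}


def world_size(instance_type, instance_count):
    """Total number of GPU workers across all instances."""
    return GPUS_PER_INSTANCE[instance_type] * instance_count


def global_batch_size(per_worker_batch_size, instance_type, instance_count):
    """Samples processed per step, assuming one batch per GPU worker."""
    return per_worker_batch_size * world_size(instance_type, instance_count)


workers = world_size('ml.p3.16xlarge', 2)
print('workers:', workers)
print('global batch size:', global_batch_size(4, 'ml.p3.16xlarge', 2))
```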

In [8]:
hyperparameters = {
    'dropout-rate1': 0.01,
    'dropout-rate2': 0.1,
    'batch-size': 4,
    'learning-rate': 0.001,
    'epochs': 25,
}
metric_definitions = [
    {'Name': 'Epoch', 'Regex': r'# epoch = (\d+)'},
    {'Name': 'Loss', 'Regex': r'# loss = ([\d.\-\+e]+)'},
    {'Name': 'Val Loss', 'Regex': r'# val_loss = ([\d.\-\+e]+)'},
    {'Name': 'Mean IoU', 'Regex': r'# mean_iou = ([\d.\-\+e]+)'},
    {'Name': 'Val Mean IoU', 'Regex': r'# val_mean_iou = ([\d.\-\+e]+)'},
]
estimator = TensorFlow(
    base_job_name='enet-tf-dp-train',
    py_version='py39',
    framework_version='2.8.0',
    model_dir='/opt/ml/model',
    checkpoint_local_path='/opt/ml/checkpoints',
    entry_point='scripts/train_data_parallel.py',
    source_dir='../',
    hyperparameters=hyperparameters,
    metric_definitions=metric_definitions,
    role=training_role,
    sagemaker_session=session,
    instance_count=2,
    instance_type='ml.p3.16xlarge',
    distribution={
        'smdistributed': {
            'dataparallel': {
                'enabled': True,
                'custom_mpi_options': '-verbose -x NCCL_DEBUG=VERSION'
            }
        }
    },
    use_spot_instances=True,
    max_run=10*3600,
    max_wait=16*3600,
)


## Run Training Job

We then run the training job by invoking the TensorFlow estimator's `fit` method.
As its argument, we pass a dictionary of data [inputs](https://sagemaker.readthedocs.io/en/stable/api/utility/inputs.html) that maps each channel name to the S3 location of the corresponding part of the preprocessed dataset.

In [9]:
estimator.fit({
    'train': TrainingInput(train_path),
    'train_labels': TrainingInput(train_labels_path),
    'val': TrainingInput(val_path),
    'val_labels': TrainingInput(val_labels_path),
    'test': TrainingInput(test_path),
    'test_labels': TrainingInput(test_labels_path),
    'report': TrainingInput(report_path),
}, wait=False)
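
Each key of the dictionary passed to `fit` becomes an input channel. Inside the training container, SageMaker exposes the local directory of a channel through an environment variable named `SM_CHANNEL_<NAME>`, with the data placed under `/opt/ml/input/data/<name>`. The sketch below shows one way a training script might resolve these directories; the `channel_dir` helper is hypothetical, and whether the actual training script reads its inputs this way is an assumption.

```python
import os

# Channel names as used in the call to fit.
CHANNELS = ['train', 'train_labels', 'val', 'val_labels',
            'test', 'test_labels', 'report']


def channel_dir(name, environ=os.environ):
    """Return the local directory of an input channel.

    Falls back to the default SageMaker input location if the
    environment variable is not set (for example when run locally).
    """
    return environ.get(f'SM_CHANNEL_{name.upper()}',
                       f'/opt/ml/input/data/{name}')


for channel in CHANNELS:
    print(channel, '->', channel_dir(channel))
```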

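Because `fit` is called with `wait=False`, it returns immediately instead of blocking until the job finishes. During training we can stream the logs of the training job to the notebook to follow its progress.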

In [None]:
estimator.logs()