This notebook was developed on an `ml.t3.medium` instance with the `Python 3 (Data Science)` kernel in SageMaker Studio.

Import SageMaker SDK and Create a Session

In [1]:
import boto3
import sagemaker
import time
from time import gmtime, strftime

# SageMaker session, its region, and the execution role used by the jobs below
session = sagemaker.Session()
aws_region = session.boto_region_name
role = sagemaker.get_execution_role()

# Project bucket and the S3 key prefixes of each dataset variant
bucket = session.default_bucket()
dataset_prefix = "medical-imaging/dataset"
scaled_dataset_prefix = "medical-imaging/scaled_dataset"
scaled_zipped_dataset_prefix = "medical-imaging/scaled_zipped_dataset"

### VPC setup: subnets/SGs options

In [2]:
# Replace the placeholders below with a valid subnet ID and security group ID
# from the VPC you want to train in. Training with Amazon FSx for Lustre
# requires these values.
security_group_ids = ["sg-xxxxxxxxxx"]
vpc_subnet_ids = ["subnet-xxxxxxxxxx"]

### Define a SageMaker `PyTorch Estimator`

In [3]:
from sagemaker.pytorch import PyTorch
def get_pytorch_estimator(entry_point, hyperparameters, instance_type,
                          instance_count, output_prefix,
                          dist_training_config=None, volume_size=10,
                          subnets=None, security_group_ids=None):
    """Build the SageMaker ``PyTorch`` estimator used throughout this notebook.

    Args:
        entry_point: Training script, resolved inside the ``src`` source directory.
        hyperparameters: Hyperparameters forwarded to the estimator.
        instance_type: Training instance type.
        instance_count: Number of training instances.
        output_prefix: Key prefix in the project bucket; both the uploaded code
            and the job output go to ``s3://{bucket}/{output_prefix}/output``.
        dist_training_config: Forwarded as the estimator's ``distribution``.
        volume_size: Forwarded as the estimator's ``volume_size``.
        subnets: Forwarded as the estimator's ``subnets``.
        security_group_ids: Forwarded as the estimator's ``security_group_ids``.

    Returns:
        The configured ``PyTorch`` estimator. No training job is started here.

    Note:
        Reads the notebook-level names ``role``, ``session``, ``bucket`` and
        ``metric_def`` when it is called.
    """
    # Source code and job output share one S3 location
    s3_output_uri = f's3://{bucket}/{output_prefix}/output'

    estimator_kwargs = {
        # identity and networking
        'role': role,
        'sagemaker_session': session,
        'subnets': subnets,
        'security_group_ids': security_group_ids,
        # training code and framework
        'source_dir': 'src',
        'entry_point': entry_point,
        'hyperparameters': hyperparameters,
        'py_version': 'py36',
        'framework_version': '1.6.0',
        # compute resources
        'instance_count': instance_count,
        'instance_type': instance_type,
        'volume_size': volume_size,
        # metrics persisted from the training logs
        'enable_sagemaker_metrics': True,
        'metric_definitions': metric_def,
        # debugger and profiler are switched off
        'debugger_hook_config': False,
        'disable_profiler': True,
        'distribution': dist_training_config,
        # S3 locations and runtime cap
        'code_location': s3_output_uri,
        'output_path': s3_output_uri,
        'max_run': 432000,  # max runtime of 5 days, in seconds
    }
    return PyTorch(**estimator_kwargs)

# Training-loop metrics to persist, declared as (metric name, log regex) pairs
metric_def = [
    {"Name": metric_name, "Regex": log_pattern}
    for metric_name, log_pattern in (
        ("train_loss", "train_loss: (.*?)$"),
        ("average_loss", "average loss: (.*?)$"),
        ("mean_dice", "current mean dice: (.*?) "),
        ("time_per_epoch", "secs_time_per_epoch: (.*?)$"),
        ("dice_tc", "tc: (.*?) "),
        ("dice_wt", "wt: (.*?) "),
        ("dice_et", "et: (.*?)$"),
    )
]

### Single GPU Device Experiments - Original Dataset 484 training pairs (4.65 GB)

Run training for three of MONAI's dataset classes:
1. `Dataset`: standard data loading
2. `PersistentDataset`: persist processed data on disk
3. `CacheDataset`: persist processed data in CPU memory

In [4]:
# S3 location of the original (unscaled) training data
training_data_on_s3 = "s3://{}/{}/Task01_BrainTumour".format(bucket, dataset_prefix)

hyperparameters = dict(
    torch_dataset_type="Dataset",
    lr=5e-3,
    epochs=10,
    batch_size=16,
    num_workers=4,
)

# Launch one non-blocking training job per MONAI dataset class
for dataset_type in ('Dataset', 'PersistentDataset', 'CacheDataset'):
    hyperparameters["torch_dataset_type"] = dataset_type

    # The same timestamp goes into the S3 output prefix and the job name
    WORKFLOW_DATE_TIME = strftime("%Y-%m-%d-%H-%M-%S", gmtime())
    output_prefix = "brats_ebs/{}/{}/sagemaker".format(WORKFLOW_DATE_TIME, dataset_type)

    # Instantiate the estimator for a single ml.p3.2xlarge instance
    pt_estimator = get_pytorch_estimator(
        entry_point='single_gpu_training.py',
        hyperparameters=hyperparameters,
        instance_type='ml.p3.2xlarge',
        instance_count=1,
        output_prefix=output_prefix,
        dist_training_config=None,
        volume_size=100,
    )

    # Launch the training job without waiting for it to finish
    pt_estimator.fit(
        inputs={'train': training_data_on_s3},
        job_name='monai-1gpu-{}-{}'.format(dataset_type, WORKFLOW_DATE_TIME),
        wait=False,
    )
    time.sleep(1)

#### Results:
The above runs should produce 3 training jobs. Visit the SageMaker training jobs page for details on each. The `CacheDataset` run should be the fastest, followed by `PersistentDataset` and then `Dataset`.

### Single GPU Device Experiment - Synthetically Scaled Dataset 48,400 training pairs (~450GB compressed, ~7TB decompressed)
 4. `Dataset`: standard data loading but on large dataset
 
 Please make sure your account has 3,000GB quota for **[Size of EBS volume for an instance](https://docs.aws.amazon.com/general/latest/gr/sagemaker.html)** for a SageMaker training job. Please visit [AWS service quotas](https://docs.aws.amazon.com/general/latest/gr/aws_service_limits.html) page for requesting a quota increase.

In [5]:
# S3 location of the synthetically scaled training data
scaled_training_data_on_s3 = "s3://{}/{}/".format(bucket, scaled_dataset_prefix)

hyperparameters = dict(
    torch_dataset_type="Dataset",
    lr=5e-3,
    epochs=10,
    batch_size=16,
    num_workers=4,
)

# The same timestamp goes into the S3 output prefix and the job name
WORKFLOW_DATE_TIME = strftime("%Y-%m-%d-%H-%M-%S", gmtime())
output_prefix = "brats_ebs/{}/{}/sagemaker".format(WORKFLOW_DATE_TIME, "ScaledDataset")

# Instantiate the estimator for a single ml.p3.2xlarge instance with volume_size=3000
pt_estimator = get_pytorch_estimator(
    entry_point='single_gpu_training.py',
    hyperparameters=hyperparameters,
    instance_type='ml.p3.2xlarge',
    instance_count=1,
    output_prefix=output_prefix,
    dist_training_config=None,
    volume_size=3000,
)

# Launch the training job without waiting for it to finish
pt_estimator.fit(
    inputs={'train': scaled_training_data_on_s3},
    job_name='monai-1gpu-ScaledDataset-{}'.format(WORKFLOW_DATE_TIME),
    wait=False,
)

#### Results:
The above run should produce 1 SageMaker training job. Visit the SageMaker training jobs page for its details.

The average epoch should take around 31,000 seconds or 8 hours 37 minutes. **The entire 10 epoch training job should take close to four days to finish.** Moreover, we run into disk and memory limits if we try to use `PersistentDataset` or `CacheDataset`.

### Solution:
[SageMaker Distributed Data Parallel Training Library](https://docs.aws.amazon.com/sagemaker/latest/dg/data-parallel.html) + SageMaker Processing + FSx

### Step 1: SageMaker managed distributed image pre-processing
Define data processing infrastructure 

In [6]:
from sagemaker.processing import ScriptProcessor
# Container image for the pre-processing job, resolved in the session's region
# rather than a hard-coded us-east-1, so the notebook does not reference an
# image in a different region when it is run elsewhere.
# NOTE(review): registry account 763104351884 and the amazonaws.com domain are
# kept from the original URI; some opt-in/China/GovCloud regions may use other
# values. Confirm, or resolve the URI with sagemaker.image_uris.retrieve.
processing_image_uri = (
    "763104351884.dkr.ecr.{}.amazonaws.com/"
    "pytorch-training:1.8.1-cpu-py36-ubuntu18.04"
).format(aws_region)

# Processing cluster: 20 x ml.c4.8xlarge with volume_size_in_gb=1024 each.
# Reuses the notebook's `role` and `session`, so the image region and the job
# come from the same session.
script_processor = ScriptProcessor(
    image_uri=processing_image_uri,
    instance_count=20,
    instance_type='ml.c4.8xlarge',
    volume_size_in_gb=1024,
    role=role,
    sagemaker_session=session,
    command=['python3']
)

Define input/output paths

In [7]:
from sagemaker.processing import ProcessingInput, ProcessingOutput
# The timestamp keeps the S3 destination of each pre-processing run distinct
PROCESSING_JOB_DATETIME = strftime("%Y-%m-%d-%H-%M-%S", gmtime())
transformed_data_prefix = (
    "medical-imaging/transformed_scaled_dataset/" + PROCESSING_JOB_DATETIME + "/train"
)
processing_s3_output_path = "s3://{}/{}".format(bucket, transformed_data_prefix)

# Input: the zipped, scaled dataset, with ShardedByS3Key distribution requested
ScriptProcessorInput = [
    ProcessingInput(
        s3_data_distribution_type='ShardedByS3Key',
        destination='/opt/ml/processing/input',
        source="s3://{}/{}/".format(bucket, scaled_zipped_dataset_prefix),
    )
]

# Output: whatever the script writes to /opt/ml/processing/train goes to S3
ScriptProcessorOutput = [
    ProcessingOutput(
        source='/opt/ml/processing/train',
        destination=processing_s3_output_path,
        output_name='train',
    )
]

Run processing job

In [None]:
# Name the job with PROCESSING_JOB_DATETIME, the timestamp created for this
# processing run and already embedded in its S3 output path. The previous
# WORKFLOW_DATE_TIME was a leftover from the single-GPU training cell, so this
# cell depended on a different cell having been executed first.
script_processor.run(
    job_name="brats-sharded-preprocessing-{}".format(PROCESSING_JOB_DATETIME),
    code='src/sharded_image_preprocessing.py',
    inputs=ScriptProcessorInput,
    outputs=ScriptProcessorOutput,
    wait=True,  # block until the processing job has finished
)

### Step 2: FSx for Lustre
Upon completion, the SageMaker distributed data processing job writes the transformed data to `processing_s3_output_path`. To expedite data transfer from S3 to the training hosts, we create a high-performance file system using Amazon FSx for Lustre. *Note: make sure the execution role has the proper permissions. Please add [AmazonFSxFullAccess](https://us-east-1.console.aws.amazon.com/iam/home#/policies/arn:aws:iam::aws:policy/AmazonFSxFullAccess) and follow the page to [Add permissions to use data repositories in Amazon S3](https://docs.aws.amazon.com/fsx/latest/LustreGuide/setting-up.html#fsx-adding-permissions-s3).*

In [None]:
# Create the FSx client in the same region as the SageMaker session
fsx_client = boto3.client("fsx", region_name=aws_region)

# Lustre file system linked to the S3 prefix holding the pre-processed data
fsx_response = fsx_client.create_file_system(
    FileSystemType='LUSTRE',
    StorageCapacity=2400,
    StorageType='SSD',
    SubnetIds=[vpc_subnet_ids[0]],
    SecurityGroupIds=security_group_ids,
    LustreConfiguration={
        'ImportPath': processing_s3_output_path+"/",
        'DeploymentType': 'PERSISTENT_1',
        'PerUnitStorageThroughput': 200
    }
)

# Poll once a minute until the file system leaves the CREATING state
fsx_file_system_id = fsx_response["FileSystem"]["FileSystemId"]
fsx_status = "CREATING"
while fsx_status == "CREATING":
    time.sleep(60)
    fsx_describe = fsx_client.describe_file_systems(
        FileSystemIds=[fsx_file_system_id]
    )
    fsx_status = fsx_describe["FileSystems"][0]["Lifecycle"]
    print(fsx_status)

# Leaving CREATING does not imply success (e.g. FAILED). Stop here with a clear
# error instead of letting the training cells fail later on an unusable
# file system.
if fsx_status != "AVAILABLE":
    fsx_failure = fsx_describe["FileSystems"][0].get("FailureDetails", {})
    raise RuntimeError(
        "FSx file system {} ended in state {}: {}".format(
            fsx_file_system_id,
            fsx_status,
            fsx_failure.get("Message", "no failure details reported"),
        )
    )

### Step 3: Launch a distributed data-parallel training job with SageMaker

First, set up the file system as an input for SageMaker training

In [None]:
from sagemaker.inputs import FileSystemInput

# FSx for Lustre file system ID and mount name, taken from the create response
file_system_id = fsx_response["FileSystem"]["FileSystemId"]
fsx_mount_id = fsx_response["FileSystem"]["LustreConfiguration"]["MountName"]

# Input data directory on the file system: /<mount name>/<S3 key prefix>
file_system_directory_path = f'/{fsx_mount_id}/{transformed_data_prefix}'
print(f'FSx file-system data input path:{file_system_directory_path}')

# File system type, and the access mode of the mounted directory.
# The directory must be mounted 'ro' (read-only).
file_system_type = 'FSxLustre'
file_system_access_mode = 'ro'

# Expose the directory to the training job as the 'train' channel
train = FileSystemInput(
    file_system_id=file_system_id,
    file_system_type=file_system_type,
    directory_path=file_system_directory_path,
    file_system_access_mode=file_system_access_mode,
)
data_channels = {'train': train}
print(data_channels)

Create SageMaker PyTorch Estimator

In [None]:
# S3 prefix for the model artifacts and source code of this run
TRAINING_JOB_DATETIME = strftime("%Y-%m-%d-%H-%M-%S", gmtime())
output_prefix = "brats_fsx/{}/sagemaker".format(TRAINING_JOB_DATETIME)

# Compute resources
instance_type = 'ml.p3.16xlarge'
instance_count = 2
world_size = instance_count * 8  # assumes 8 GPUs per instance for ml.p3.16xlarge (TODO confirm)
num_vcpu = 64
num_workers = 16

# Network hyperparameters: learning rate and batch size are scaled by world_size
hyperparameters = dict(
    lr=1e-4 * world_size,
    batch_size=4 * world_size,
    epochs=10,
    num_workers=num_workers,
)

# Enable the SageMaker distributed data-parallel library
dist_config = {'smdistributed': {'dataparallel': {'enabled': True}}}

# The estimator is placed in the configured VPC subnets and security groups
pt_estimator = get_pytorch_estimator(
    entry_point='multi_gpu_training.py',
    hyperparameters=hyperparameters,
    instance_type=instance_type,
    instance_count=instance_count,
    output_prefix=output_prefix,
    dist_training_config=dist_config,
    subnets=vpc_subnet_ids,
    security_group_ids=security_group_ids,
)



Launch training job

In [None]:
# Start distributed training on the FSx-backed channel and block until it ends
pt_estimator.fit(
    inputs=data_channels,
    wait=True,
    job_name='brats-2p316-fsx-64batch-16worker-{}'.format(TRAINING_JOB_DATETIME),
)