In [None]:
import pprint
import warnings

from util_debugger import get_sys_metric
from util_train import aug_exp_train

### Set up the parameters for an experiment comparing training time with data augmentation on CPUs vs. GPUs

In [None]:
# SageMaker execution role (IAM role name or ARN) used to launch the training jobs.
# Please update curr_sm_role with your own SageMaker Execution Role before
# running the training cell below; the original bare `curr_sm_role =` line was
# a syntax error, so a placeholder is used here instead.
SM_ROLE_PLACEHOLDER = '<your-sagemaker-execution-role>'
curr_sm_role = SM_ROLE_PLACEHOLDER
if curr_sm_role == SM_ROLE_PLACEHOLDER:
    warnings.warn('curr_sm_role is still a placeholder; update it with your '
                  'SageMaker execution role before launching training jobs.')

# Data pre-processing approaches to compare:
#   'pytorch-cpu': JPEG decoding and augmentation on CPUs using the PyTorch DataLoader
#   'dali-cpu':    JPEG decoding and augmentation on CPUs using NVIDIA DALI
#   'dali-gpu':    JPEG decoding and augmentation on GPUs using NVIDIA DALI
AUGMENTATION_APPROACHES = ['pytorch-cpu', 'dali-gpu']

instance_type = 'ml.p3.2xlarge'
# Required for plotting system utilization. These must match the resources of
# instance_type (ml.p3.2xlarge provides 8 vCPUs and 1 GPU); update them together.
num_cpu = 8
num_gpu = 1

# Training script supports: 'RESNET50', 'RESNET18', and 'RESNET152'
model_arch = 'RESNET18'

batch_size = 32

# Factor by which augmentation operations are repeated to increase the data pre-processing load
aug_load_factor = 12

# Other parameters such as the training data, S3 bucket, number of epochs and
# training hyperparameters can be changed in the util_train module.

### Launch training jobs and fetch system utilization for data pre-processing on CPUs vs. GPUs

In [None]:
# Run one SageMaker training job per augmentation approach and collect, for each
# run, its parameters and the system-utilization DataFrame from SageMaker Debugger.
# Results are stored in exp_data under the keys 'trial-1', 'trial-2', ...
exp_data = {}
trial = 0
pp = pprint.PrettyPrinter()

for trial, aug_operator in enumerate(AUGMENTATION_APPROACHES, start=1):

    # Launch an Amazon SageMaker PyTorch training job with the custom training script.
    train_job_id, train_estimator = aug_exp_train(model_arch,
                                                  batch_size,
                                                  aug_operator,
                                                  aug_load_factor,
                                                  instance_type,
                                                  curr_sm_role)

    # Extract system utilization metrics with SageMaker Debugger.
    # heatmap / metric_hist are not used here; the last job's values remain
    # available after the loop.
    heatmap, metric_hist, sys_util_df = get_sys_metric(train_estimator,
                                                       num_cpu,
                                                       num_gpu)

    # Parameters and result summary for the current training job run
    # (same keys, in the same order, as before).
    trial_data = {
        'train_job_id': train_job_id,
        'model_arch': model_arch,
        'instance_type': instance_type,
        'batch_size': batch_size,
        'aug_load_factor': aug_load_factor,
        'aug_operator': aug_operator,
        'sys_util_df': sys_util_df,
    }

    pp.pprint(trial_data)
    exp_data['trial-' + str(trial)] = trial_data

pp.pprint(exp_data)