# Optuna example with PyTorch and MNIST on Amazon SageMaker

## Setup
After you create an AWS environment with the [CloudFormation template](https://github.com/aws-samples/amazon-sagemaker-optuna-hpo-blog/blob/master/template/optuna-template.yaml), install Optuna and the MySQL connector (PyMySQL) in the notebook kernel, read parameters from the CloudFormation stack Outputs, and retrieve the database credentials from AWS Secrets Manager. Set `stack_name` in the cell below to the name of your CloudFormation stack, which you can find in the [AWS Management Console](https://us-east-1.console.aws.amazon.com/cloudformation/home?region=us-east-1#/stacks).
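
The notebook imports `get_secret` from `src/secrets.py` in the sample repository, which is not reproduced here. As a hedged sketch only, such a helper could call the Secrets Manager `GetSecretValue` API through boto3 and parse `SecretString` as JSON. The function body, the optional `client` parameter (added so the helper can be exercised without AWS), and the assumption that the secret stores `username` and `password` keys are illustrative, not the repository's actual code.

```python
import json
from typing import Any, Dict, Optional


def get_secret(secret_name: str, region_name: str, client: Optional[Any] = None) -> Dict[str, Any]:
    """Hypothetical sketch of a Secrets Manager helper (not the repo's src/secrets.py).

    Assumes the secret value is a JSON string such as
    {"username": "...", "password": "..."}.
    """
    if client is None:
        # Imported lazily so the function can be tested with a stand-in client.
        import boto3
        client = boto3.session.Session().client(
            service_name='secretsmanager', region_name=region_name)
    response = client.get_secret_value(SecretId=secret_name)
    # GetSecretValue returns either SecretString or SecretBinary;
    # this sketch handles only the string form.
    return json.loads(response['SecretString'])
```

Called as `get_secret(secret_name, region_name)`, it returns a dict from which the connection cell reads `secret['username']` and `secret['password']`.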

In [None]:
!pip install optuna==1.4.0
!pip install PyMySQL==0.9.3

In [None]:
import boto3 # AWS Python SDK
import numpy as np
import optuna

In [None]:
# obtain parameters from CloudFormation Outputs
stack_name = 'optuna-blog'  # replace with your CloudFormation stack name

client = boto3.client('cloudformation')
stack_outputs = client.describe_stacks(StackName=stack_name)['Stacks'][0]['Outputs']
# map each OutputKey to its OutputValue for easy lookup
outputs = {out['OutputKey']: out['OutputValue'] for out in stack_outputs}

host = outputs['ClusterEndpoint'].split(':')[0]  # drop the port, if present
db_name = outputs['DatabaseName']
# the secret ARN ends in "secret:<name>-<random suffix>"; keep only <name>,
# even if <name> itself contains hyphens
secret_name = outputs['DBSecretArn'].split(':')[-1].rsplit('-', 1)[0]

subnets = outputs['PrivateSubnets'].split(',')
security_group_ids = outputs['SageMakerSecurityGroup'].split(',')

In [None]:
# Call AWS Secrets Manager
from src.secrets import get_secret
region_name = boto3.session.Session().region_name
secret = get_secret(secret_name, region_name)

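# SQLAlchemy URL for the PyMySQL driver: https://docs.sqlalchemy.org/en/13/dialects/mysql.html#module-sqlalchemy.dialects.mysql.pymysql
# URL-encode the credentials in case they contain reserved characters such as '@', ':' or '/'
from urllib.parse import quote_plus
db = 'mysql+pymysql://{}:{}@{}/{}'.format(
    quote_plus(secret['username']), quote_plus(secret['password']), host, db_name)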

In [None]:
# Setup
from sagemaker import get_execution_role
import sagemaker

sagemaker_session = sagemaker.Session()

# Retrieve the IAM execution role attached to this notebook instance; it is passed to the training job and model below.
role = get_execution_role()

## Train
We port the Optuna example [`pytorch_simple.py`](https://github.com/optuna/optuna/blob/master/examples/pytorch_simple.py) to Amazon SageMaker. First, create an Optuna study in the RDS/Aurora database and upload the MNIST data to Amazon S3. Then, create a [PyTorch estimator](https://sagemaker.readthedocs.io/en/stable/sagemaker.pytorch.html#pytorch-estimator) and call its `fit` method to start training. Here we call `fit` several times so that multiple training jobs run in parallel, each running Optuna trials against the same shared study.

In [None]:
# create study in RDS/Aurora
study_name = 'pytorch-simple'
optuna.study.create_study(storage=db, study_name=study_name, direction='maximize', load_if_exists=True)

In [None]:
# data preparation 
import os 
from torchvision import datasets
from torchvision import transforms

dataset = datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor())

In [None]:
# upload the local MNIST files to the default S3 bucket of this SageMaker session
input_data = sagemaker_session.upload_data(path='data', key_prefix='example/pytorch_mnist')

In [None]:
# setup SageMaker PyTorch estimator
from sagemaker.pytorch.estimator import PyTorch

pytorch_estimator = PyTorch(entry_point='pytorch_simple.py',
                            source_dir='src',
                            framework_version='1.5.0',
                            py_version='py3',
                            role=role,
                            # run the job in the private subnets so it can reach the database
                            subnets=subnets,
                            security_group_ids=security_group_ids,
                            instance_count=1,
                            instance_type='ml.m5.xlarge',
                            hyperparameters={
                                'host': host,
                                'db-name': db_name,
                                'db-secret': secret_name,
                                'study-name': study_name,
                                'region-name': region_name,
                                'n-trials': 25,  # Optuna trials run by each training job
                            })
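
In SageMaker script mode, the entries in `hyperparameters` are passed to the entry point as command-line arguments (for example `--db-name <value>`). The sample's `src/pytorch_simple.py` is not shown here; the following is a hypothetical sketch of how a training script could read these arguments with `argparse`. The argument names mirror the keys above; the `--data-dir` default taken from the `SM_CHANNEL_TRAINING` environment variable is an assumption based on the default `training` channel that SageMaker creates when `fit` receives a single S3 URI.

```python
import argparse
import os
from typing import List, Optional


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Hypothetical argument parser for the training entry point."""
    parser = argparse.ArgumentParser()
    # Database connection details forwarded from the notebook's hyperparameters.
    parser.add_argument('--host', type=str, required=True)
    parser.add_argument('--db-name', type=str, required=True)
    parser.add_argument('--db-secret', type=str, required=True)
    parser.add_argument('--region-name', type=str, required=True)
    # Optuna settings: shared study name and number of trials this job runs.
    parser.add_argument('--study-name', type=str, required=True)
    parser.add_argument('--n-trials', type=int, default=25)
    # Local path of the input channel; SageMaker exposes it via SM_CHANNEL_<NAME>.
    parser.add_argument('--data-dir', type=str,
                        default=os.environ.get('SM_CHANNEL_TRAINING', 'data'))
    return parser.parse_args(argv)
```

`argparse` turns `--db-name` into the attribute `args.db_name`, and `type=int` converts the string `'25'` that arrives on the command line back into an integer.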

In [None]:
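# HPO in parallel: each fit() call starts a separate training job,
# and all jobs share the same Optuna study stored in the database
max_parallel_jobs = 4

# launch max_parallel_jobs - 1 jobs without waiting for them ...
for _ in range(max_parallel_jobs - 1):
    pytorch_estimator.fit(input_data, wait=False)
# ... then one more that blocks until it finishes (the others may still be running afterwards)
pytorch_estimator.fit(input_data)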
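
Only the final `fit` call blocks, so the jobs started with `wait=False` may still be running when this cell returns, and the results loaded below could be incomplete. One way to wait for every job is sketched below, assuming you record each job name (for example from `pytorch_estimator.latest_training_job.name` right after each `fit` call) and use the boto3 SageMaker waiter `training_job_completed_or_stopped`. The `wait_for_jobs` helper is hypothetical.

```python
from typing import Any, Iterable, Optional


def wait_for_jobs(job_names: Iterable[str], sm_client: Optional[Any] = None) -> None:
    """Hypothetical helper: block until every listed training job completes or stops.

    The boto3 waiter raises botocore.exceptions.WaiterError if a job fails.
    """
    if sm_client is None:
        # Imported lazily so the helper can be tested with a stand-in client.
        import boto3
        sm_client = boto3.client('sagemaker')
    waiter = sm_client.get_waiter('training_job_completed_or_stopped')
    for name in job_names:
        # Polls DescribeTrainingJob until this job reaches Completed or Stopped.
        waiter.wait(TrainingJobName=name)
```

For example, call `fit(input_data, wait=False)` in a loop of `max_parallel_jobs` iterations, append `pytorch_estimator.latest_training_job.name` to a `job_names` list after each call, and then run `wait_for_jobs(job_names)` before loading the study.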

In [None]:
# obtain results
study = optuna.study.load_study(study_name=study_name, storage=db)

df = study.trials_dataframe()

# optuna.visualization.plot_intermediate_values(study)
# optuna.visualization.plot_optimization_history(study)
ax = df['value'].plot()
ax.set_xlabel('Number of trials')
ax.set_ylabel('Validation accuracy')
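
The plot above shows each trial's validation accuracy, which fluctuates from trial to trial. A best-so-far curve (the "Best Value" line that `optuna.visualization.plot_optimization_history` also draws) is often easier to read. A minimal, dependency-free sketch, assuming pruned or failed trials appear as `None` or `NaN` in the `value` column:

```python
import math
from typing import Iterable, List, Optional


def best_so_far(values: Iterable[Optional[float]]) -> List[Optional[float]]:
    """Running maximum of trial values; None/NaN entries repeat the previous best."""
    best: Optional[float] = None
    out: List[Optional[float]] = []
    for v in values:
        # Skip missing values (e.g. pruned trials) but keep one output per trial.
        if v is not None and not math.isnan(v):
            best = v if best is None else max(best, v)
        out.append(best)
    return out
```

`ax.plot(best_so_far(df['value'].tolist()))` overlays the curve on the existing axes. With pandas alone, `df['value'].cummax()` computes the same running maximum but leaves `NaN` at the positions of missing values instead of carrying the previous best forward.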

## Deploy
Deploy the best model found during the HPO to an API endpoint for inference.
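
The next cell builds the S3 location of the best job's artifact with `os.path.join`, which works on Linux but treats an S3 URI as a local file path. A small hypothetical helper that joins the URI parts with `/` explicitly, following the same `<output_path>/<job_name>/output/model.tar.gz` layout the cell assumes:

```python
def model_artifact_uri(output_path: str, job_name: str) -> str:
    """Hypothetical helper: S3 URI of a training job's model artifact.

    Mirrors the layout used in this notebook:
    <output_path>/<job_name>/output/model.tar.gz
    """
    # Strip a trailing slash so 's3://bucket/' and 's3://bucket' give the same result.
    return '/'.join([output_path.rstrip('/'), job_name, 'output', 'model.tar.gz'])
```

With it, `best_model_data = model_artifact_uri(pytorch_estimator.output_path, study.best_trial.user_attrs['job_name'])`.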

In [None]:
from sagemaker.pytorch import PyTorchModel

# the best trial's 'job_name' user attribute identifies the training job that produced it
best_model_data = os.path.join(pytorch_estimator.output_path,
                               study.best_trial.user_attrs['job_name'],
                               'output/model.tar.gz')
best_model = PyTorchModel(model_data=best_model_data,
                          role=role,
                          entry_point='pytorch_simple.py',
                          source_dir='src',
                          framework_version='1.5.0',
                          py_version='py3')

predictor = best_model.deploy(instance_type='ml.m5.xlarge', initial_instance_count=1)

In [None]:
import torch

# MNIST test split (downloaded together with the training split above)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False, transform=transforms.ToTensor()),
    batch_size=5,
    shuffle=True,
)

In [None]:
# take a single batch from the test set and compare predictions with the true labels
data, target = next(iter(test_loader))
data = data.view(-1, 28 * 28)  # flatten each 28x28 image into a 784-element vector

prediction = predictor.predict(data.numpy())
predicted_label = prediction.argmax(axis=1)
print('Pred label: {}'.format(predicted_label))
print('True label: {}'.format(target.numpy()))
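
The cell above checks a single batch of five images. To get a rough accuracy estimate over more of the test set, you could count matches across several batches. The `batch_accuracy` helper below is a hypothetical sketch; note that each batch sent to the endpoint is a separate request.

```python
from typing import Sequence


def batch_accuracy(predicted: Sequence[int], actual: Sequence[int]) -> float:
    """Hypothetical helper: fraction of positions where the predicted label matches."""
    if len(predicted) != len(actual):
        raise ValueError('predicted and actual must have the same length')
    if len(predicted) == 0:
        raise ValueError('cannot compute accuracy of an empty batch')
    correct = sum(int(p == a) for p, a in zip(predicted, actual))
    return correct / len(predicted)
```

For example, loop over the first few batches of `test_loader`, compute `batch_accuracy(predictor.predict(data.view(-1, 28 * 28).numpy()).argmax(axis=1).tolist(), target.tolist())` for each, and average the results; with `batch_size=5` every batch of the 10,000-image test split has the same size, so a plain mean is appropriate.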

### Cleanup
Delete the API endpoint when you are done, since a deployed endpoint is billed for as long as it is running.

In [None]:
predictor.delete_endpoint()