# Optuna example with Chainer MNIST on SageMaker

## Setup
After creating the AWS environment with the [CloudFormation template](https://github.com/aws-samples/amazon-sagemaker-optuna-hpo-blog/blob/master/template/optuna-template.yaml), install Optuna and the MySQL connector into the notebook kernel, read the parameters from the CloudFormation Outputs, and fetch the database credentials from AWS Secrets Manager. Replace `'<your_cfn_stack_name>'` with the name of your CloudFormation stack, which you can find in the [AWS Management Console](https://us-east-1.console.aws.amazon.com/cloudformation/home?region=us-east-1#/stacks).

In [None]:
!pip install optuna
!pip install mysql-connector-python

In [None]:
import boto3 # AWS Python SDK
import numpy as np
import optuna

In [None]:
# obtain parameters from CloudFormation Outputs
stack_name = '<your_cfn_stack_name>'

client = boto3.client('cloudformation')
outputs = client.describe_stacks(StackName=stack_name)['Stacks'][0]['Outputs']

host = [out['OutputValue'] for out in outputs if out['OutputKey'] == 'ClusterEndpoint'][0].split(':')[0]
db_name = [out['OutputValue'] for out in outputs if out['OutputKey'] == 'DatabaseName'][0]
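# the ARN ends with '-' plus a random suffix; strip only that last segment so hyphenated secret names survive
secret_name = [out['OutputValue'] for out in outputs if out['OutputKey'] == 'DBSecretArn'][0].split(':')[-1].rsplit('-', 1)[0]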

subnets = [out['OutputValue'] for out in outputs if out['OutputKey'] == 'PrivateSubnets'][0].split(',')
security_group_ids = [out['OutputValue'] for out in outputs if out['OutputKey'] == 'SageMakerSecurityGroup'][0].split(',')
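
The repeated list comprehensions above all scan the same `Outputs` list. As a minimal sketch, the outputs could first be turned into a dictionary and then looked up by key. The helper name `outputs_to_dict` and the sample values below are hypothetical placeholders, not part of the original notebook or of any real stack.

```python
def outputs_to_dict(outputs):
    """Map each CloudFormation OutputKey to its OutputValue."""
    return {out['OutputKey']: out['OutputValue'] for out in outputs}


# Hypothetical sample shaped like describe_stacks(...)['Stacks'][0]['Outputs'];
# the values are placeholders, not real resources.
sample_outputs = [
    {'OutputKey': 'ClusterEndpoint', 'OutputValue': 'example-cluster.local:3306'},
    {'OutputKey': 'DatabaseName', 'OutputValue': 'optuna'},
    {'OutputKey': 'PrivateSubnets', 'OutputValue': 'subnet-aaa,subnet-bbb'},
]

params = outputs_to_dict(sample_outputs)
sample_host = params['ClusterEndpoint'].split(':')[0]   # drop the port
sample_subnets = params['PrivateSubnets'].split(',')    # comma-separated list
```

A missing key then raises a `KeyError` that names the output, which is usually easier to diagnose than an `IndexError` from `[0]` on an empty list.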

In [None]:
# Call AWS Secrets Manager
from src.secrets import get_secret
region_name = boto3.session.Session().region_name
secret = get_secret(secret_name, region_name)

# MySQL-connector-python    
db = 'mysql+mysqlconnector://{}:{}@{}/{}'.format(secret['username'], secret['password'], host, db_name)
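
Generated database passwords can contain characters such as `@` or `/` that have a special meaning inside a URL. A minimal sketch of one way to guard against that is to percent-encode the credentials before formatting the connection string. The helper name `build_db_url` and the example credentials are assumptions made for illustration only.

```python
from urllib.parse import quote_plus


def build_db_url(username, password, host, db_name):
    """Build the connection URL, percent-encoding the credentials."""
    return 'mysql+mysqlconnector://{}:{}@{}/{}'.format(
        quote_plus(username), quote_plus(password), host, db_name)


# Placeholder credentials; '@' and '/' would otherwise break URL parsing.
example_url = build_db_url('admin', 'p@ss/word', 'example-host', 'optuna')
print(example_url)
```

With the real values this would be called as `build_db_url(secret['username'], secret['password'], host, db_name)`.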

In [None]:
# Setup
from sagemaker import get_execution_role
import sagemaker

sagemaker_session = sagemaker.Session()

# Retrieve the SageMaker-compatible IAM role used by this notebook instance.
role = get_execution_role()

## Train
We demonstrate the Optuna example [`chainer_simple.py`](https://github.com/pfnet/optuna/blob/master/examples/chainer_simple.py), migrated to Amazon SageMaker. First, upload the data to Amazon S3. Then create a [Chainer estimator](https://sagemaker.readthedocs.io/en/stable/sagemaker.chainer.html#sagemaker.chainer.estimator.Chainer). Training is started by the `fit` method, which is called several times here so that the jobs run in parallel.

In [None]:
# create study in RDS/Aurora
study_name = 'chainer-simple'
optuna.study.create_study(storage=db, study_name=study_name, direction='maximize', load_if_exists=True)

In [None]:
# prepare data
import chainer
import os
import shutil
import numpy as np

N_TRAIN_EXAMPLES = 3000
N_TEST_EXAMPLES = 1000

rng = np.random.RandomState(0)
train, test = chainer.datasets.get_mnist()

train = chainer.datasets.SubDataset(
    train, 0, N_TRAIN_EXAMPLES, order=rng.permutation(len(train)))
test = chainer.datasets.SubDataset(
    test, 0, N_TEST_EXAMPLES, order=rng.permutation(len(test)))

In [None]:
train_data = np.array([element[0] for element in train])
train_labels = np.array([element[1] for element in train])

test_data = np.array([element[0] for element in test])
test_labels = np.array([element[1] for element in test])

In [None]:
# upload to Amazon S3
try:
    os.makedirs('/tmp/data/train_mnist')
    os.makedirs('/tmp/data/test_mnist')
    np.savez('/tmp/data/train_mnist/train.npz', data=train_data, labels=train_labels)
    np.savez('/tmp/data/test_mnist/test.npz', data=test_data, labels=test_labels)
    train_input = sagemaker_session.upload_data(
                      path=os.path.join('/tmp', 'data', 'train_mnist'),
                      key_prefix='notebook/chainer_mnist/train')
    test_input = sagemaker_session.upload_data(
                      path=os.path.join('/tmp', 'data', 'test_mnist'),
                      key_prefix='notebook/chainer_mnist/test')
finally:
    shutil.rmtree('/tmp/data')
print('training data at %s' % train_input)
print('test data at %s' % test_input)

In [None]:
# setup SageMaker Chainer estimator
from sagemaker.chainer.estimator import Chainer

chainer_estimator = Chainer(entry_point='chainer_simple.py',
                            source_dir="src",
                            framework_version='5.0.0', 
                            role=role,
                            sagemaker_session=sagemaker_session,
                            subnets=subnets,
                            security_group_ids=security_group_ids,
                            train_instance_count=1,
                            train_instance_type='ml.c5.xlarge',
                            hyperparameters={
                                'host': host, 
                                'db-name': db_name, 
                                'db-secret': secret_name, 
                                'study-name': study_name, 
                                'n-trials': 25, 
                                'region-name': region_name
                            })
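
The `hyperparameters` dictionary above reaches the entry point as command-line arguments, and the `train` and `test` channels are exposed through environment variables. The script in `src/chainer_simple.py` is not shown here, so the following is only a sketch of how such a script might read them; the function name `parse_args` and the default values are assumptions.

```python
import argparse
import os


def parse_args(argv=None):
    """Parse hyperparameters and SageMaker directories (illustrative sketch)."""
    parser = argparse.ArgumentParser()
    # Keys of the estimator's hyperparameters dict arrive as --key value pairs.
    parser.add_argument('--host', type=str)
    parser.add_argument('--db-name', type=str)
    parser.add_argument('--db-secret', type=str)
    parser.add_argument('--study-name', type=str)
    parser.add_argument('--n-trials', type=int, default=25)
    parser.add_argument('--region-name', type=str)
    # Input channels and the model directory come from environment variables.
    parser.add_argument('--train', type=str, default=os.environ.get('SM_CHANNEL_TRAIN'))
    parser.add_argument('--test', type=str, default=os.environ.get('SM_CHANNEL_TEST'))
    parser.add_argument('--model-dir', type=str, default=os.environ.get('SM_MODEL_DIR'))
    # parse_known_args ignores any extra arguments the container may add.
    args, _ = parser.parse_known_args(argv)
    return args


# Example invocation with placeholder values.
args = parse_args(['--host', 'example-host', '--db-name', 'optuna',
                   '--study-name', 'chainer-simple', '--n-trials', '25'])
print(args.study_name, args.n_trials)
```

Note that `argparse` converts the dashes in option names to underscores, so `'db-name'` is read back as `args.db_name`.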

In [None]:
# HPO in parallel
max_parallel_jobs = 4

for j in range(max_parallel_jobs-1):
    chainer_estimator.fit({'train': train_input, 'test': test_input}, wait=False)
chainer_estimator.fit({'train': train_input, 'test': test_input})
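
Only the last `fit` call blocks, so the jobs started with `wait=False` may still be running when it returns, and the results cell could then read an incomplete study. A minimal sketch of a polling helper is shown below. The name `wait_for_jobs` is hypothetical, and it assumes the job names were recorded after each `fit` call (for example from `chainer_estimator.latest_training_job.name`).

```python
import time

# Training job states after which the status no longer changes.
TERMINAL_STATES = {'Completed', 'Failed', 'Stopped'}


def wait_for_jobs(job_names, get_status, poll_seconds=30, sleep=time.sleep):
    """Poll until every job reaches a terminal state; return the final statuses."""
    statuses = {}
    pending = set(job_names)
    while pending:
        for name in sorted(pending):
            status = get_status(name)
            if status in TERMINAL_STATES:
                statuses[name] = status
        pending -= set(statuses)
        if pending:
            sleep(poll_seconds)
    return statuses


# With boto3, the status lookup could be (not executed here):
#   sm = boto3.client('sagemaker')
#   get_status = lambda name: sm.describe_training_job(
#       TrainingJobName=name)['TrainingJobStatus']
```

Passing the status lookup as a function keeps the helper independent of AWS, so it can be exercised with a stub before being pointed at real training jobs.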

In [None]:
# obtain results
study = optuna.study.load_study(study_name=study_name, storage=db)

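# multi_index=True keeps the two-level columns that df['user_attrs'][...] below relies on
df = study.trials_dataframe(multi_index=True)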

# optuna.visualization.plot_intermediate_values(study)
ax = df['user_attrs']['validation/main/accuracy'].plot()
ax.set_xlabel('Number of trials')
ax.set_ylabel('Validation accuracy')

## Deploy
Create an API endpoint for inference with the best model found during HPO.

In [None]:
from sagemaker.chainer import ChainerModel

best_model_data = os.path.join(chainer_estimator.output_path, study.best_trial.user_attrs['job_name'], 'output/model.tar.gz')
best_model = ChainerModel(model_data=best_model_data, 
                          role=role,
                          entry_point='chainer_simple.py', 
                          source_dir="src")

predictor = best_model.deploy(instance_type="ml.m4.xlarge", initial_instance_count=1)
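
The model artifact location above is built by joining the estimator's output path, the training job name, and `output/model.tar.gz`. `os.path.join` produces the right result on a Linux notebook instance, but S3 URIs always use forward slashes, so a platform-independent sketch could use `posixpath` instead. The helper name, bucket, and job name below are placeholders.

```python
import posixpath


def model_artifact_uri(output_path, job_name):
    """Join an S3 output prefix and a training job name into the artifact URI."""
    return posixpath.join(output_path, job_name, 'output', 'model.tar.gz')


# Placeholder bucket and job name, for illustration only.
example_uri = model_artifact_uri('s3://example-bucket/', 'example-job-0001')
print(example_uri)
```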

In [None]:
import random

import matplotlib.pyplot as plt

num_samples = 5
indices = random.sample(range(test_data.shape[0]), num_samples)  # sample from every test index
images, labels = test_data[indices], test_labels[indices]

for i in range(num_samples):
    plt.subplot(1,num_samples,i+1)
    plt.imshow(images[i].reshape(28, 28), cmap='gray')
    plt.title(labels[i])
    plt.axis('off')

In [None]:
prediction = predictor.predict(images)
predicted_label = prediction.argmax(axis=1)
print('The predicted labels are: {}'.format(predicted_label))

### Cleanup
Delete the API endpoint. 

In [None]:
predictor.delete_endpoint()
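
An endpoint keeps running until it is deleted, so an exception in an earlier inference cell can leave it behind. One possible pattern, sketched here under the assumption that deployment and prediction run in a single cell, is a context manager that always calls `delete_endpoint`. Both `temporary_endpoint` and `FakePredictor` are hypothetical names used only for this illustration.

```python
from contextlib import contextmanager


@contextmanager
def temporary_endpoint(predictor):
    """Yield the predictor and delete its endpoint on exit, even after an error."""
    try:
        yield predictor
    finally:
        predictor.delete_endpoint()


class FakePredictor:
    """Stand-in used for illustration only; not a SageMaker class."""

    def __init__(self):
        self.deleted = False

    def delete_endpoint(self):
        self.deleted = True


fake = FakePredictor()
try:
    with temporary_endpoint(fake):
        raise RuntimeError('simulated inference failure')
except RuntimeError:
    pass
print(fake.deleted)
```

With a real predictor, the body of the `with` block would hold the `predictor.predict(images)` call.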