{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Optuna example with PyTorch and MNIST on Amazon SageMaker" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup\n", "After you create an AWS environment by the [CloudFormation template](https://github.com/aws-samples/amazon-sagemaker-optuna-hpo-blog/blob/master/template/optuna-template.yaml), install Optuna and MySQL connector to the notebook kernel, obtain parameters from the CloudFormation Outputs, and get DB secrets from AWS Secrets Manager. Please modify the `''` to your CloudFormation stack name, which you can find at [AWS Management Console](https://us-east-1.console.aws.amazon.com/cloudformation/home?region=us-east-1#/stacks). " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install optuna==1.4.0\n", "!pip install PyMySQL==0.9.3" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import boto3 # AWS Python SDK\n", "import numpy as np\n", "import optuna" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# obtain parameters from CloudFormation Outputs\n", "stack_name = 'optuna-blog'\n", "\n", "client = boto3.client('cloudformation')\n", "outputs = client.describe_stacks(StackName=stack_name)['Stacks'][0]['Outputs']\n", "\n", "host = [out['OutputValue'] for out in outputs if out['OutputKey'] == 'ClusterEndpoint'][0].split(':')[0]\n", "db_name = [out['OutputValue'] for out in outputs if out['OutputKey'] == 'DatabaseName'][0]\n", "secret_name = [out['OutputValue'] for out in outputs if out['OutputKey'] == 'DBSecretArn'][0].split(':')[-1].split('-')[0]\n", "\n", "subnets = [out['OutputValue'] for out in outputs if out['OutputKey'] == 'PrivateSubnets'][0].split(',')\n", "security_group_ids = [out['OutputValue'] for out in outputs if out['OutputKey'] == 'SageMakerSecurityGroup'][0].split(',')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Call AWS Secrets Manager\n", "from src.secrets import get_secret\n", "region_name = boto3.session.Session().region_name\n", "secret = get_secret(secret_name, region_name)\n", "\n", "# PyMySQL https://docs.sqlalchemy.org/en/13/dialects/mysql.html#module-sqlalchemy.dialects.mysql.pymysql \n", "db = 'mysql+pymysql://{}:{}@{}/{}'.format(secret['username'], secret['password'], host, db_name)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Setup\n", "from sagemaker import get_execution_role\n", "import sagemaker\n", "\n", "sagemaker_session = sagemaker.Session()\n", "\n", "# This role retrieves the SageMaker-compatible role used by this notebook instance.\n", "role = get_execution_role()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train\n", "We demonstrate an Optuna example [`pytorch_simple.py`](https://github.com/optuna/optuna/blob/master/examples/pytorch_simple.py) migrated to Amazon SageMaker. First, put the data to Amazon S3. Then, create a [PyTorch estimator](https://sagemaker.readthedocs.io/en/stable/sagemaker.pytorch.html#pytorch-estimator). The training will be invoked by the `fit` method (in parallel here). 
" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# create study in RDS/Aurora\n", "study_name = 'pytorch-simple'\n", "optuna.study.create_study(storage=db, study_name=study_name, direction='maximize', load_if_exists=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# data preparation\n", "import os\n", "from torchvision import datasets\n", "from torchvision import transforms\n", "\n", "dataset = datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# upload the MNIST data to the default SageMaker S3 bucket\n", "input_data = sagemaker_session.upload_data(path='data', key_prefix='example/pytorch_mnist')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# setup SageMaker PyTorch estimator\n", "from sagemaker.pytorch.estimator import PyTorch\n", "\n", "pytorch_estimator = PyTorch(\n", "    entry_point='pytorch_simple.py',\n", "    source_dir='src',\n", "    framework_version='1.5.0',\n", "    py_version='py3',\n", "    role=role,\n", "    subnets=subnets,\n", "    security_group_ids=security_group_ids,\n", "    instance_count=1,\n", "    instance_type='ml.m5.xlarge',\n", "    hyperparameters={\n", "        'host': host,\n", "        'db-name': db_name,\n", "        'db-secret': secret_name,\n", "        'study-name': study_name,\n", "        'region-name': region_name,\n", "        'n-trials': 25\n", "    })" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# HPO in parallel: the first jobs are launched asynchronously; the final call blocks until it finishes\n", "max_parallel_jobs = 4\n", "\n", "for j in range(max_parallel_jobs - 1):\n", "    pytorch_estimator.fit(input_data, wait=False)\n", "pytorch_estimator.fit(input_data)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# obtain results\n", "study = optuna.study.load_study(study_name=study_name, storage=db)\n", "\n", "df = study.trials_dataframe()\n", "\n", "# optuna.visualization.plot_intermediate_values(study)\n", "# optuna.visualization.plot_optimization_history(study)\n", "ax = df['value'].plot()\n", "ax.set_xlabel('Number of trials')\n", "ax.set_ylabel('Validation accuracy')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Deploy\n", "Create an API endpoint for inference using the best model found during the HPO. 
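
Serving reuses the same entry point script; with the SageMaker PyTorch container, the script typically provides a `model_fn` hook that rebuilds the model from the unpacked `model.tar.gz`. The snippet below is a minimal, hypothetical sketch of that hook, assuming the training job saved the whole module with `torch.save(model, 'model.pth')`; the actual script may differ.

```python
# Hypothetical serving hook expected in the entry point script (an assumption, not the actual code)
import os
import torch

def model_fn(model_dir):
    # SageMaker unpacks model.tar.gz into model_dir before calling this hook
    model = torch.load(os.path.join(model_dir, 'model.pth'), map_location='cpu')
    model.eval()
    return model
```

If the script instead saves only a state_dict, the hook would also need to rebuild the network architecture before loading the weights.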
" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sagemaker.pytorch import PyTorchModel\n", "\n", "best_model_data = os.path.join(pytorch_estimator.output_path, study.best_trial.user_attrs['job_name'], 'output/model.tar.gz')\n", "best_model = PyTorchModel(\n", "    model_data=best_model_data,\n", "    role=role,\n", "    entry_point='pytorch_simple.py',\n", "    source_dir='src',\n", "    framework_version='1.5.0',\n", "    py_version='py3'\n", ")\n", "\n", "predictor = best_model.deploy(instance_type='ml.m5.xlarge', initial_instance_count=1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "test_loader = torch.utils.data.DataLoader(\n", "    datasets.MNIST('data', train=False, transform=transforms.ToTensor()),\n", "    batch_size=5,\n", "    shuffle=True,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "for batch_idx, (data, target) in enumerate(test_loader):\n", "    data, target = data.view(-1, 28 * 28).to('cpu'), target.to('cpu')\n", "\n", "    prediction = predictor.predict(data)\n", "    predicted_label = prediction.argmax(axis=1)\n", "    print('Pred label: {}'.format(predicted_label))\n", "    print('True label: {}'.format(target.numpy()))\n", "    break" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Cleanup\n", "Delete the API endpoint. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "predictor.delete_endpoint()" ] } ], "metadata": { "kernelspec": { "display_name": "conda_pytorch_p36", "language": "python", "name": "conda_pytorch_p36" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.13" } }, "nbformat": 4, "nbformat_minor": 4 }