{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# PyTorch Training and using checkpointing on SageMaker Managed Spot Training\n", "The example here is almost the same as [PyTorch Cifar10 local training](https://github.com/aws/amazon-sagemaker-examples/blob/master/sagemaker-python-sdk/pytorch_cnn_cifar10/pytorch_local_mode_cifar10.ipynb).\n", "\n", "This notebook tackles the exact same problem with the same solution, but it has been modified to be able to run using SageMaker Managed Spot infrastructure. SageMaker Managed Spot uses [EC2 Spot Instances](https://aws.amazon.com/ec2/spot/) to run Training at a lower cost.\n", "\n", "Please read the original notebook and try it out to gain an understanding of the ML use-case and how it is being solved. We will not delve into that here in this notebook." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Overview\n", "\n", "The **SageMaker Python SDK** helps you deploy your models for training and hosting in optimized, productions ready containers in SageMaker. The SageMaker Python SDK is easy to use, modular, extensible and compatible with TensorFlow, MXNet, PyTorch and Chainer. This tutorial focuses on how to create a convolutional neural network model to train the [Cifar10 dataset](https://www.cs.toronto.edu/~kriz/cifar.html) using **PyTorch in local mode**.\n", "\n", "### Set up the environment\n", "\n", "This notebook was created and tested on a single ml.p2.xlarge notebook instance.\n", "\n", "Let's start by specifying:\n", "\n", "- The S3 bucket and prefix that you want to use for training and model data. This should be within the same region as the Notebook Instance, training, and hosting.\n", "- The IAM role arn used to give training and hosting access to your data. See the documentation for how to create these. Note, if more than one role is required for notebook instances, training, and/or hosting, please replace the sagemaker.get_execution_role() with appropriate full IAM role arn string(s)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sagemaker\n", "import uuid\n", "\n", "sagemaker_session = sagemaker.Session()\n", "print('SageMaker version: ' + sagemaker.__version__)\n", "\n", "bucket = sagemaker_session.default_bucket()\n", "prefix = 'sagemaker/DEMO-pytorch-cnn-cifar10'\n", "\n", "role = sagemaker.get_execution_role()\n", "checkpoint_suffix = str(uuid.uuid4())[:8]\n", "checkpoint_s3_path = 's3://{}/checkpoint-{}'.format(bucket, checkpoint_suffix)\n", "\n", "print('Checkpointing Path: {}'.format(checkpoint_s3_path))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "import subprocess\n", "\n", "instance_type = 'local'\n", "\n", "if subprocess.call('nvidia-smi') == 0:\n", " ## Set type to GPU if one is present\n", " instance_type = 'local_gpu'\n", " \n", "print(\"Instance type = \" + instance_type)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Download the Cifar10 dataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from utils_cifar import get_train_data_loader, get_test_data_loader, imshow, classes\n", "\n", "trainloader = get_train_data_loader()\n", "testloader = get_test_data_loader()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Data Preview" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import torchvision, torch\n", "\n", "# get some random training images\n", "dataiter = iter(trainloader)\n", "images, labels = dataiter.next()\n", "\n", "# show images\n", "imshow(torchvision.utils.make_grid(images))\n", "\n", "# print labels\n", "print(' '.join('%9s' % classes[labels[j]] for j in range(4)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Upload the data\n", "We use the ```sagemaker.Session.upload_data``` function to upload our datasets to an S3 location. The return value inputs identifies the location -- we will use this later when we start the training job." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "inputs = sagemaker_session.upload_data(path='data', bucket=bucket, key_prefix='data/cifar10')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Construct a script for training \n", "Here is the full code for the network model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pygmentize source_dir/cifar10.py" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Script Functions\n", "\n", "SageMaker invokes the main function defined within your training script for training. When deploying your trained model to an endpoint, the model_fn() is called to determine how to load your trained model. 
, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create a training job using the sagemaker.PyTorch estimator\n", "\n", "The `PyTorch` class allows us to run our training function on SageMaker. We need to configure it with our training script, an IAM role, the number of training instances, and the training instance type. For local training with a GPU, we could set `instance_type` to \"local_gpu\". In this case, `instance_type` was set above based on whether this notebook is running on a GPU instance.\n", "\n", "After we've constructed our `PyTorch` object, we fit it using the data we uploaded to S3. Even though we're in local mode, using S3 as our data source makes sense because it maintains consistency with how SageMaker's distributed, managed training ingests data.\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sagemaker.pytorch import PyTorch\n", "\n", "hyperparameters = {'epochs': 2}\n", "\n", "cifar10_estimator = PyTorch(entry_point='source_dir/cifar10.py',\n", "                            role=role,\n", "                            framework_version='1.7.1',\n", "                            py_version='py3',\n", "                            hyperparameters=hyperparameters,\n", "                            instance_count=1,\n", "                            instance_type=instance_type)\n", "\n", "cifar10_estimator.fit(inputs)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Run a baseline training job on SageMaker\n", "\n", "Now we run training jobs on SageMaker, starting with our baseline training job.\n", "\n", "Once again, we create a PyTorch estimator, with a couple of key modifications from last time:\n", "\n", "* `instance_type`: the instance type for training. We set this to `ml.p3.2xlarge` because we are training on SageMaker now. For a list of available instance types, see [the AWS documentation](https://aws.amazon.com/sagemaker/pricing/instance-types).\n", "* `base_job_name`: a name prefix for the training job, which makes the job easier to find in the SageMaker console." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sagemaker.pytorch import PyTorch\n", "\n", "hyperparameters = {'epochs': 10}\n", "\n", "cifar10_estimator = PyTorch(entry_point='source_dir/cifar10.py',\n", "                            role=role,\n", "                            framework_version='1.7.1',\n", "                            py_version='py3',\n", "                            hyperparameters=hyperparameters,\n", "                            instance_count=1,\n", "                            instance_type='ml.p3.2xlarge',\n", "                            base_job_name='cifar10-pytorch')\n", "\n", "cifar10_estimator.fit(inputs)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Managed Spot Training with a PyTorch Estimator\n", "\n", "For Managed Spot Training using a PyTorch Estimator we need to configure two things:\n", "1. Enable the `use_spot_instances` constructor arg - a simple self-explanatory boolean.\n", "2. Set the `max_wait` constructor arg - an int representing the amount of time (in seconds) you are willing to wait for Spot infrastructure to become available. Some instance types are harder to get at Spot prices and you may have to wait longer. You are not charged for time spent waiting for Spot infrastructure to become available; you're only charged for the compute time used once Spot instances have been successfully procured.\n", "\n", "Normally, a third requirement would also apply - modifying your code to save checkpoints regularly - however, the training script used in this example already does this, so no changes are necessary here (a sketch of a typical save routine is shown below). Checkpointing is highly recommended for Managed Spot Training jobs because Spot instances can be interrupted with short notice, and resuming from the last checkpoint ensures you don't lose any progress made before the interruption.\n", "\n", "Feel free to toggle the `use_spot_instances` variable to see the effect of running the same job using regular (a.k.a. \"On Demand\") infrastructure.\n", "\n", "Note that `max_wait` can be set if and only if `use_spot_instances` is enabled and **must** be greater than or equal to `max_run`." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "use_spot_instances = True\n", "max_run = 600\n", "max_wait = 1200 if use_spot_instances else None" ] }
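, { "cell_type": "markdown", "metadata": {}, "source": [ "For reference, a checkpoint save routine of the kind described above might look like the sketch below. The function and argument names are illustrative (the actual code lives in `source_dir/cifar10.py`); it writes the same keys that the `_load_checkpoint` function shown later reads, and it assumes `args.checkpoint_path` points at the local checkpoint directory (`/opt/ml/checkpoints` on SageMaker). SageMaker syncs whatever the script writes to that directory to the checkpoint S3 location, which is what allows an interrupted Spot job to resume.\n", "\n", "```python\n", "import os\n", "import torch\n", "\n", "def _save_checkpoint(model, optimizer, epoch, loss, args):\n", "    # Persist everything needed to resume training after a Spot interruption.\n", "    checkpoint_path = os.path.join(args.checkpoint_path, 'checkpoint.pth')\n", "    torch.save({\n", "        'epoch': epoch,\n", "        'model_state_dict': model.state_dict(),\n", "        'optimizer_state_dict': optimizer.state_dict(),\n", "        'loss': loss,\n", "    }, checkpoint_path)\n", "    print('Saved checkpoint for epoch {} to {}'.format(epoch, checkpoint_path))\n", "```" ] }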
, { "cell_type": "markdown", "metadata": {}, "source": [ "## Simulating Spot interruption after 5 epochs\n", "\n", "Our full training job should run for 10 epochs.\n", "\n", "However, to simulate a Spot interruption after 5 epochs, we first run a job that trains for only 5 epochs.\n", "\n", "The goal is for the checkpoint data to be copied to S3, so that when Spot capacity is available again, the training job can resume from the 6th epoch.\n", "\n", "Note the `checkpoint_s3_uri` parameter, which specifies the S3 URI where any checkpoints the algorithm writes during training are persisted.\n", "\n", "The `debugger_hook_config` parameter must be set to `False` so that checkpoints are copied to S3 successfully." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "hyperparameters = {'epochs': 5}\n", "\n", "spot_estimator = PyTorch(entry_point='source_dir/cifar10.py',\n", "                         role=role,\n", "                         framework_version='1.7.1',\n", "                         py_version='py3',\n", "                         instance_count=1,\n", "                         instance_type='ml.p3.2xlarge',\n", "                         base_job_name='cifar10-pytorch-spot-1',\n", "                         hyperparameters=hyperparameters,\n", "                         checkpoint_s3_uri=checkpoint_s3_path,\n", "                         debugger_hook_config=False,\n", "                         use_spot_instances=use_spot_instances,\n", "                         max_run=max_run,\n", "                         max_wait=max_wait)\n", "\n", "spot_estimator.fit(inputs)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Savings\n", "Towards the end of the job you should see two lines of output printed:\n", "\n", "- `Training seconds: X`: the actual compute time your training job spent.\n", "- `Billable seconds: Y`: the time you will be billed for after the Spot discount is applied.\n", "\n", "If you enabled `use_spot_instances`, you should see a notable difference between `X` and `Y`, signifying the cost savings from Managed Spot Training. This should be reflected in an additional line:\n", "- `Managed Spot Training savings: (1-Y/X)*100 %`" ] }
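, { "cell_type": "markdown", "metadata": {}, "source": [ "If you prefer to compute the savings programmatically rather than reading them from the job log, you can pull the same numbers from the `DescribeTrainingJob` API. A minimal sketch, assuming the `spot_estimator` and `sagemaker_session` objects defined above:\n", "\n", "```python\n", "# Look up the training and billable time reported for the latest Spot training job.\n", "job_name = spot_estimator.latest_training_job.name\n", "desc = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=job_name)\n", "\n", "training_secs = desc['TrainingTimeInSeconds']\n", "billable_secs = desc['BillableTimeInSeconds']\n", "print('Training seconds: {}'.format(training_secs))\n", "print('Billable seconds: {}'.format(billable_secs))\n", "print('Managed Spot Training savings: {:.1f}%'.format(100 * (1 - billable_secs / training_secs)))\n", "```" ] }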
, { "cell_type": "markdown", "metadata": {}, "source": [ "### View the job training Checkpoint configuration\n", "We can now view the Checkpoint configuration from the training job directly in the SageMaker console.\n", "\n", "Log into the [SageMaker console](https://console.aws.amazon.com/sagemaker/home), choose the latest training job, and scroll down to the Checkpoint configuration section.\n", "\n", "Choose the S3 output path link and you'll be directed to the S3 bucket where the checkpoint data is saved.\n", "\n", "You can see there is one file there:\n", "\n", "```\n", "checkpoint.pth\n", "```\n", "\n", "This is the checkpoint file that contains the epoch, model state dict, optimizer state dict, and loss." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Continue training after Spot capacity is resumed\n", "\n", "Now we simulate a situation where Spot capacity is resumed.\n", "\n", "We will start a training job again, this time with 10 epochs.\n", "\n", "We expect the training job to resume from the 6th epoch.\n", "\n", "When the training job starts, it checks the checkpoint S3 location for checkpoint data. If any is found, it is copied to `/opt/ml/checkpoints` in the training container.\n", "\n", "In the training script you can see the function that loads the checkpoint data:\n", "\n", "```python\n", "def _load_checkpoint(model, optimizer, args):\n", "    print(\"--------------------------------------------\")\n", "    print(\"Checkpoint file found!\")\n", "    print(\"Loading Checkpoint From: {}\".format(args.checkpoint_path + '/checkpoint.pth'))\n", "    checkpoint = torch.load(args.checkpoint_path + '/checkpoint.pth')\n", "    model.load_state_dict(checkpoint['model_state_dict'])\n", "    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n", "    epoch_number = checkpoint['epoch']\n", "    loss = checkpoint['loss']\n", "    print(\"Checkpoint File Loaded - epoch_number: {} - loss: {}\".format(epoch_number, loss))\n", "    print('Resuming training from epoch: {}'.format(epoch_number+1))\n", "    print(\"--------------------------------------------\")\n", "    return model, optimizer, epoch_number\n", "```\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "hyperparameters = {'epochs': 10}\n", "\n", "spot_estimator = PyTorch(entry_point='source_dir/cifar10.py',\n", "                         role=role,\n", "                         framework_version='1.7.1',\n", "                         py_version='py3',\n", "                         instance_count=1,\n", "                         instance_type='ml.p3.2xlarge',\n", "                         base_job_name='cifar10-pytorch-spot-2',\n", "                         hyperparameters=hyperparameters,\n", "                         checkpoint_s3_uri=checkpoint_s3_path,\n", "                         debugger_hook_config=False,\n", "                         use_spot_instances=use_spot_instances,\n", "                         max_run=max_run,\n", "                         max_wait=max_wait)\n", "\n", "spot_estimator.fit(inputs)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Analyze training job logs\n", "\n", "Analyzing the training job logs, we can see that the training job now starts from the 6th epoch.\n", "\n", "We can see the output of the `_load_checkpoint` function:\n", "\n", "```\n", "--------------------------------------------\n", "Checkpoint file found!\n", "Loading Checkpoint From: /opt/ml/checkpoints/checkpoint.pth\n", "Checkpoint File Loaded - epoch_number: 5 - loss: 0.8455273509025574\n", "Resuming training from epoch: 6\n", "--------------------------------------------\n", "```" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### View the job training Checkpoint configuration after the job completed 10 epochs\n", "\n", "We can now view the Checkpoint configuration from the training job directly in the SageMaker console.\n", "\n", "Log into the [SageMaker console](https://console.aws.amazon.com/sagemaker/home), choose the latest training job, and scroll down to the Checkpoint configuration section.\n", "\n", "Choose the S3 output path link and you'll be directed to the S3 bucket where the checkpoint data is saved.\n", "\n", "You can see there is still just the one file there:\n", "\n", "```\n", "checkpoint.pth\n", "```\n", "\n", "You'll see that the timestamp of the checkpoint file was updated to the time of the second Spot training job." ] }
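, { "cell_type": "markdown", "metadata": {}, "source": [ "You can also verify this without leaving the notebook by listing the checkpoint prefix in S3. Below is a minimal sketch using the AWS CLI (which is available on SageMaker notebook instances) and the `checkpoint_s3_path` variable defined at the top of the notebook:\n", "\n", "```python\n", "# List the checkpoint objects and their last-modified timestamps.\n", "!aws s3 ls {checkpoint_s3_path}/\n", "```" ] }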
"source": [ "### View the job training Checkpoint configuration after job completed 10 epochs\n", "\n", "We can now view the Checkpoint configuration from the training job directly in the SageMaker console. \n", "\n", "Log into the [SageMaker console](https://console.aws.amazon.com/sagemaker/home), choose the latest training job, and scroll down to the Checkpoint configuration section. \n", "\n", "Choose the S3 output path link and you'll be directed to the S3 bucket were checkpointing data is saved.\n", "\n", "You can see there is still that one file there:\n", "\n", "```python\n", "checkpoint.pth\n", "```\n", "\n", "You'll be able to see that the date of the checkpoint file was updated to the time of the 2nd Spot training job." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Deploy the trained model to prepare for predictions\n", "\n", "The deploy() method creates an endpoint which serves prediction requests in real-time." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sagemaker.pytorch import PyTorchModel\n", "\n", "predictor = spot_estimator.deploy(initial_instance_count=1, instance_type='ml.m4.xlarge')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Invoking the endpoint" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# get some test images\n", "dataiter = iter(testloader)\n", "images, labels = dataiter.next()\n", "\n", "# print images\n", "imshow(torchvision.utils.make_grid(images))\n", "print('GroundTruth: ', ' '.join('%4s' % classes[labels[j]] for j in range(4)))\n", "\n", "outputs = predictor.predict(images.numpy())\n", "\n", "_, predicted = torch.max(torch.from_numpy(np.array(outputs)), 1)\n", "\n", "print('Predicted: ', ' '.join('%4s' % classes[predicted[j]]\n", " for j in range(4)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Clean-up\n", "\n", "To avoid incurring extra charges to your AWS account, let's delete the endpoint we created:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "predictor.delete_endpoint()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "conda_pytorch_p36", "language": "python", "name": "conda_pytorch_p36" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.13" }, "notice": "Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License." }, "nbformat": 4, "nbformat_minor": 2 }