{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Apache MXNet training with checkpointing on SageMaker Managed Spot Training\n", "\n", "The example here is almost the same as [Training and hosting SageMaker Models using the Apache MXNet Module API](https://github.com/awslabs/amazon-sagemaker-examples/blob/master/sagemaker-python-sdk/mxnet_mnist/mxnet_mnist.ipynb).\n", "\n", "This notebook tackles the exact same problem with the same solution, but it has been modified to run on SageMaker Managed Spot infrastructure. SageMaker Managed Spot uses [EC2 Spot Instances](https://aws.amazon.com/ec2/spot/) to run training at a lower cost.\n", "\n", "Please read the original notebook and try it out to understand the ML use case and how it is solved. We will not delve into that here.\n", "\n", "## First, set up variables and define functions\n", "\n", "Again, we won't explain the code below in detail; it has been lifted verbatim from [Training and hosting SageMaker Models using the Apache MXNet Module API](https://github.com/awslabs/amazon-sagemaker-examples/blob/master/sagemaker-python-sdk/mxnet_mnist/mxnet_mnist.ipynb)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install sagemaker -U" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "isConfigCell": true }, "outputs": [], "source": [ "import sagemaker\n", "import boto3\n", "from sagemaker import get_execution_role\n", "from sagemaker.session import Session\n", "import uuid\n", "\n", "# S3 bucket for saving code and model artifacts.\n", "# Feel free to specify a different bucket here if you wish.\n", "bucket = Session().default_bucket()\n", "\n", "# Location to save your custom code in tar.gz format.\n", "custom_code_upload_location = 's3://{}/customcode/mxnet'.format(bucket)\n", "\n", "# Location where results of model training are saved.\n", "model_artifacts_location 
= 's3://{}/artifacts'.format(bucket)\n", "\n", "# IAM execution role that gives SageMaker access to resources in your AWS account.\n", "# We can use the SageMaker Python SDK to get the role from our notebook environment. \n", "role = get_execution_role()\n", "\n", "region = boto3.Session().region_name\n", "train_data_location = 's3://sagemaker-sample-data-{}/mxnet/mnist/train'.format(region)\n", "test_data_location = 's3://sagemaker-sample-data-{}/mxnet/mnist/test'.format(region)\n", "\n", "checkpoint_suffix = str(uuid.uuid4())[:8]\n", "checkpoint_s3_uri = 's3://{}/artifacts/mxnet-checkpoint-{}/'.format(bucket, checkpoint_suffix)\n", "\n", "print('SageMaker version: ' + sagemaker.__version__)\n", "print('Checkpointing Path: {}'.format(checkpoint_s3_uri))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Managed Spot Training with MXNet\n", "\n", "For Managed Spot Training using MXNet, we need to configure three things:\n", "1. Enable the `use_spot_instances` constructor arg, a simple self-explanatory boolean.\n", "2. Set the `max_wait` constructor arg. This is an int (in seconds) representing the total time you are willing to wait for the Spot training job to complete, including any time spent waiting for Spot infrastructure to become available. Some instance types are harder to get at Spot prices, so you may have to wait longer. You are not charged for time spent waiting for Spot infrastructure to become available; you are only charged for the actual compute time once Spot instances have been successfully procured.\n", "3. Set up a `checkpoint_s3_uri` constructor arg. This arg tells SageMaker the S3 location in which to save checkpoints (assuming your algorithm has been modified to save checkpoints periodically). 
While not strictly necessary, checkpointing is highly recommended for Managed Spot Training jobs: Spot instances can be interrupted at short notice, and resuming from the last checkpoint ensures you don't lose the progress made before the interruption.\n", "\n", "Feel free to toggle the `use_spot_instances` variable to see the effect of running the same job on regular (a.k.a. \"On Demand\") infrastructure.\n", "\n", "Note that `max_wait` can be set if and only if `use_spot_instances` is enabled, and it **must** be greater than or equal to `max_run`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "use_spot_instances = True\n", "max_run = 600  # maximum training time, in seconds\n", "max_wait = 1200 if use_spot_instances else None  # in seconds; must be >= max_run" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Simulating Spot interruption after 5 epochs\n", "\n", "Our training job should run for 10 epochs.\n", "\n", "However, we will simulate a Spot interruption occurring after 5 epochs by running the first training job with `epochs` set to 5.\n", "\n", "The goal is for the checkpoint data to be copied to S3 so that, when Spot capacity becomes available again, the training job can resume from the 6th epoch.\n", "\n", "Note the `checkpoint_s3_uri` variable, which stores the S3 URI where the checkpoints written by the algorithm (if any) during training are persisted." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sagemaker.mxnet import MXNet\n", "\n", "mnist_estimator = MXNet(entry_point='source_dir/mnist.py',\n", " role=role,\n", " output_path=model_artifacts_location,\n", " code_location=custom_code_upload_location,\n", " instance_count=1,\n", " instance_type='ml.m4.xlarge',\n", " framework_version='1.6.0',\n", " py_version='py3',\n", " distribution={'parameter_server': {'enabled': True}},\n", " hyperparameters={'learning-rate': 0.1, 'epochs': 5},\n", " use_spot_instances=use_spot_instances,\n", " max_run=max_run,\n", " max_wait=max_wait,\n", " checkpoint_s3_uri=checkpoint_s3_uri)\n", "mnist_estimator.fit({'train': train_data_location, 'test': test_data_location})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Savings\n", "Towards the end of the job you should see two lines of output printed:\n", "\n", "- `Training seconds: X`: the actual compute time your training job spent.\n", "- `Billable seconds: Y`: the time you will be billed for after the Spot discount is applied.\n", "\n", "If you enabled `use_spot_instances`, you should see a notable difference between `X` and `Y`, reflecting the cost savings you get from choosing Managed Spot Training. This is reported in an additional line:\n", "\n", "- `Managed Spot Training savings: (1-Y/X)*100 %`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Analyze training job logs\n", "\n", "Analyzing the training job logs, we can see that the training job starts from the 1st epoch (shown as epoch 0 in the MXNet logs):\n", "\n", "```\n", "INFO:root:Starting training from epoch: 0\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### View the training job's checkpoint configuration\n", "\n", "We can now view the checkpoint configuration of the training job directly in the SageMaker console. 
\n", "\n", "Log in to the [SageMaker console](https://console.aws.amazon.com/sagemaker/home), choose the latest training job, and scroll down to the Checkpoint configuration section.\n", "\n", "Choose the S3 output path link and you'll be directed to the S3 bucket where the checkpoint data is saved.\n", "\n", "You should see 11 files there:\n", "\n", "```\n", "mnist-symbol.json\n", "mnist-0005.states\n", "mnist-0005.params\n", "mnist-0004.states\n", "mnist-0004.params\n", "mnist-0003.states\n", "mnist-0003.params\n", "mnist-0002.states\n", "mnist-0002.params\n", "mnist-0001.states\n", "mnist-0001.params\n", "```\n", "\n", "These files store the optimizer states (`.states`), the model parameters (`.params`), and the model architecture (`-symbol.json`)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Continue training after Spot capacity becomes available again\n", "\n", "Now we simulate a situation where Spot capacity becomes available again.\n", "\n", "We will start a training job again, this time with 10 epochs.\n", "\n", "We expect the training job to start from the 6th epoch (shown as epoch 5 in the MXNet logs).\n", "\n", "This works as follows: when the training job starts, SageMaker checks the checkpoint S3 location for checkpoint data. If any exists, it is copied to `/opt/ml/checkpoints` on the training container.\n", "\n", "In the training script (`source_dir/mnist.py`), the `load_model_from_checkpoints` function loads the checkpoint data:\n", "\n", "```python\n", "def load_model_from_checkpoints(checkpoint_path):\n", " checkpoint_files = [file for file in os.listdir(checkpoint_path) if file.endswith('.' 
+ 'params')]\n", " logging.info('------------------------------------------------------')\n", " logging.info(\"Available checkpoint files: {}\".format(checkpoint_files))\n", " # Parse the full epoch number from names such as 'mnist-0005.params' and compare as ints\n", " epoch_numbers = [int(re.search(r'-(\\d+)\\.params$', file).group(1)) for file in checkpoint_files]\n", " \n", " max_epoch_number = max(epoch_numbers)\n", " max_epoch_index = epoch_numbers.index(max_epoch_number)\n", " max_epoch_filename = checkpoint_files[max_epoch_index]\n", "\n", " logging.info('Latest epoch checkpoint file name: {}'.format(max_epoch_filename))\n", " logging.info('Resuming training from epoch: {}'.format(max_epoch_number))\n", " logging.info('------------------------------------------------------')\n", " \n", " sym, arg_params, aux_params = mx.model.load_checkpoint(checkpoint_path + \"/mnist\", max_epoch_number)\n", " mlp_model = mx.mod.Module(symbol=sym)\n", " return mlp_model, max_epoch_number\n", "```\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "mnist_estimator = MXNet(entry_point='source_dir/mnist.py',\n", " role=role,\n", " output_path=model_artifacts_location,\n", " code_location=custom_code_upload_location,\n", " instance_count=1,\n", " instance_type='ml.m4.xlarge',\n", " framework_version='1.6.0',\n", " py_version='py3',\n", " distribution={'parameter_server': {'enabled': True}},\n", " hyperparameters={'learning-rate': 0.1, 'epochs': 10},\n", " use_spot_instances=use_spot_instances,\n", " max_run=max_run,\n", " max_wait=max_wait,\n", " checkpoint_s3_uri=checkpoint_s3_uri)\n", "mnist_estimator.fit({'train': train_data_location, 'test': test_data_location})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Analyze training job logs\n", "\n", "Analyzing the training job logs, we can see that the training job now starts from the 6th epoch.\n", "\n", "We can see the output of the `load_model_from_checkpoints` function:\n", "\n", "```\n", 
"INFO:root:------------------------------------------------------\n", "INFO:root:Available checkpoint files: ['mnist-0005.params', 'mnist-0001.params', 'mnist-0003.params', 'mnist-0004.params', 'mnist-0002.params']\n", "INFO:root:Latest epoch checkpoint file name: mnist-0005.params\n", "INFO:root:Resuming training from epoch: 5\n", "INFO:root:------------------------------------------------------\n", "```\n", "\n", "Further down in the logs, the following line indicates that the training job starts from the 6th epoch (shown as epoch 5 in the MXNet logs):\n", "```\n", "INFO:root:Starting training from epoch: 5\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### View the training job's checkpoint configuration after 10 epochs\n", "\n", "We can now view the checkpoint configuration of the training job directly in the SageMaker console.\n", "\n", "Log in to the [SageMaker console](https://console.aws.amazon.com/sagemaker/home), choose the latest training job, and scroll down to the Checkpoint configuration section. 
\n", "\n", "Choose the S3 output path link and you'll be directed to the S3 bucket where the checkpoint data is saved.\n", "\n", "You should see 21 files there:\n", "\n", "```\n", "mnist-symbol.json\n", "mnist-0010.states\n", "mnist-0010.params\n", "mnist-0009.states\n", "mnist-0009.params\n", "mnist-0008.states\n", "mnist-0008.params\n", "mnist-0007.states\n", "mnist-0007.params\n", "mnist-0006.states\n", "mnist-0006.params\n", "mnist-0005.states\n", "mnist-0005.params\n", "mnist-0004.states\n", "mnist-0004.params\n", "mnist-0003.states\n", "mnist-0003.params\n", "mnist-0002.states\n", "mnist-0002.params\n", "mnist-0001.states\n", "mnist-0001.params\n", "```\n", "\n", "Looking at the file modification dates, you should see that the first ten checkpoint files (epochs 1-5) are grouped together, while the second ten (epochs 6-10) and `mnist-symbol.json` form a later group. This reflects the two separate times the training job was run." ] } ], "metadata": { "kernelspec": { "display_name": "conda_mxnet_p36", "language": "python", "name": "conda_mxnet_p36" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.10" }, "notice": "Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License." }, "nbformat": 4, "nbformat_minor": 4 }