{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# MNIST distributed training and batch transform\n", "\n", "The SageMaker Python SDK helps you deploy your models for training and hosting in optimized, production-ready containers in SageMaker. The SageMaker Python SDK is easy to use, modular, extensible and compatible with TensorFlow and MXNet. This tutorial focuses on how to create a convolutional neural network model to train the [MNIST dataset](http://yann.lecun.com/exdb/mnist/) using TensorFlow distributed training." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Set up the environment\n", "\n", "First, we'll just set up a few things needed for this example" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sagemaker\n", "from sagemaker import get_execution_role\n", "from sagemaker.session import Session\n", "\n", "sagemaker_session = sagemaker.Session()\n", "region = sagemaker_session.boto_session.region_name\n", "\n", "role = get_execution_role()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Download the MNIST dataset\n", "\n", "We'll now need to download the MNIST dataset, and upload it to a location in S3 after preparing for training." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import utils\n", "from tensorflow.contrib.learn.python.learn.datasets import mnist\n", "import tensorflow as tf\n", "\n", "data_sets = mnist.read_data_sets('data', dtype=tf.uint8, reshape=False, validation_size=5000)\n", "\n", "utils.convert_to(data_sets.train, 'train', 'data')\n", "utils.convert_to(data_sets.validation, 'validation', 'data')\n", "utils.convert_to(data_sets.test, 'test', 'data')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Upload the data\n", "We use the ```sagemaker.Session.upload_data``` function to upload our datasets to an S3 location. 
The return value `inputs` identifies the S3 location; we will use it later when we start the training job." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "inputs = sagemaker_session.upload_data(path='data', key_prefix='data/DEMO-mnist')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Construct a script for distributed training\n", "\n", "Here is the full code for the network model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "!cat 'mnist.py'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create a training job" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "from sagemaker.tensorflow import TensorFlow\n", "\n", "mnist_estimator = TensorFlow(entry_point='mnist.py',\n", " role=role,\n", " framework_version='1.11.0',\n", " training_steps=1000,\n", " evaluation_steps=100,\n", " train_instance_count=2,\n", " train_instance_type='ml.c5.4xlarge')\n", "\n", "mnist_estimator.fit(inputs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `fit()` method will create a training job on two ml.c5.4xlarge instances. The logs will show the instances performing training and evaluation and incrementing the number of training steps.\n", "\n", "At the end of training, the job generates a saved model for TensorFlow Serving." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = mnist_estimator.create_model()\n", "response = sagemaker_session.create_model(\n", " name=mnist_estimator.latest_training_job.name,\n", " role=mnist_estimator.role,\n", " container_defs=model.prepare_container_def('ml.c4.xlarge')\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Now, let's run inference from Lambda\n", "\n", "Return to the workshop directions." 
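] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## (Optional) Run a batch transform job\n", "\n", "The notebook title also mentions batch transform. As a minimal sketch (assuming the SageMaker Python SDK v1 `Estimator.transformer()` helper, and using a hypothetical S3 prefix that you would replace with the location of your own serialized inference inputs), a batch transform job could be started from the trained estimator like this:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sketch only: create a Transformer from the trained estimator.\n", "transformer = mnist_estimator.transformer(instance_count=1, instance_type='ml.m4.xlarge')\n", "\n", "# 'batch_input' is a hypothetical S3 prefix holding your request data.\n", "# batch_input = 's3://your-bucket/path/to/batch/data'\n", "# transformer.transform(batch_input)\n", "# transformer.wait()\n", "# print(transformer.output_path)  # S3 location of the transform results"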
] } ], "metadata": { "kernelspec": { "display_name": "conda_tensorflow_p36", "language": "python", "name": "conda_tensorflow_p36" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" }, "notice": "Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License." }, "nbformat": 4, "nbformat_minor": 2 }