{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# MNIST distributed training\n", "\n", "This tutorial shows how to train a convolutional neural network on the [MNIST dataset](http://yann.lecun.com/exdb/mnist/) using **TensorFlow distributed training**.\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Set up the environment" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sagemaker as sagemakerSDK\n", "from sagemaker import get_execution_role\n", "import json\n", "import boto3\n", "import time\n", "import os\n", "import tarfile\n", "from botocore.exceptions import ClientError\n", "cf = boto3.client('cloudformation')\n", "s3 = boto3.client('s3')\n", "sns = boto3.client('sns')\n", "step = boto3.client('stepfunctions')\n", "sagemaker = boto3.client('sagemaker-runtime')\n", "ssm=boto3.client('ssm')\n", "\n", "#Read the stack name from the shared config file\n", "with open('../config.json') as json_file:\n", "    config = json.load(json_file)\n", "StackName=config[\"StackName\"]\n", "\n", "#Collect the stack outputs (bucket names, topic ARNs, ...) into a dict\n", "result=cf.describe_stacks(\n", "    StackName=StackName\n", ")\n", "outputs={}\n", "for output in result['Stacks'][0]['Outputs']:\n", "    outputs[output['OutputKey']]=output['OutputValue']\n", "\n", "#Package the current directory and upload it to the code bucket\n", "with tarfile.open(\"script.tar.gz\", \"w:gz\") as tar:\n", "    tar.add(os.getcwd(),arcname=\"\")\n", "\n", "s3.upload_file(\"script.tar.gz\",outputs[\"CodeBucket\"],\"script.tar.gz\")\n", "#IAM execution role that gives SageMaker access to resources in your AWS account.\n", "#We can use the SageMaker Python SDK to get the role from our notebook environment.\n", "role = get_execution_role()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We need to make sure the SageBuild template is configured correctly for TensorFlow. 
The following code sets the stack configuration." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "params=result[\"Stacks\"][0][\"Parameters\"]\n", "# Switch the framework parameter to TensorFlow, leaving the other parameters unchanged\n", "for param in params:\n", "    if param[\"ParameterKey\"]==\"ConfigFramework\":\n", "        param[\"ParameterValue\"]=\"TENSORFLOW\"\n", "\n", "try:\n", "    cf.update_stack(\n", "        StackName=StackName,\n", "        UsePreviousTemplate=True,\n", "        Parameters=params,\n", "        Capabilities=[\n", "            'CAPABILITY_NAMED_IAM',\n", "        ]\n", "    )\n", "    waiter = cf.get_waiter('stack_update_complete')\n", "    print(\"Waiting for stack update\")\n", "    waiter.wait(\n", "        StackName=StackName,\n", "        WaiterConfig={\n", "            'Delay':10,\n", "            'MaxAttempts':600\n", "        }\n", "    )\n", "\n", "except ClientError as e:\n", "    # CloudFormation raises an error when the stack already has this configuration\n", "    if e.response[\"Error\"][\"Message\"]==\"No updates are to be performed.\":\n", "        pass\n", "    else:\n", "        raise\n", "print(\"stack ready!\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Download the MNIST dataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "import utils\n", "from tensorflow.contrib.learn.python.learn.datasets import mnist\n", "import tensorflow as tf\n", "\n", "data_sets = mnist.read_data_sets('data', dtype=tf.uint8, reshape=False, validation_size=5000)\n", "\n", "utils.convert_to(data_sets.train, 'train', 'data')\n", "utils.convert_to(data_sets.validation, 'validation', 'data')\n", "utils.convert_to(data_sets.test, 'test', 'data')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Construct a script for distributed training \n", "Here is the full code for the network model:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The script here is an adaptation of the [TensorFlow MNIST example](https://github.com/tensorflow/models/tree/master/official/mnist). It provides a ```model_fn(features, labels, mode)```, which is used for training, evaluation and inference. 
\n", "\n", "## A regular ```model_fn```\n", "\n", "A regular **```model_fn```** follows the pattern:\n", "1. [defines a neural network](https://github.com/tensorflow/models/blob/master/official/mnist/mnist.py#L96)\n", "2. [applies the ```features``` in the neural network](https://github.com/tensorflow/models/blob/master/official/mnist/mnist.py#L178)\n", "3. [if the ```mode``` is ```PREDICT```, returns the output from the neural network](https://github.com/tensorflow/models/blob/master/official/mnist/mnist.py#L186)\n", "4. [calculates the loss function comparing the output with the ```labels```](https://github.com/tensorflow/models/blob/master/official/mnist/mnist.py#L188)\n", "5. [creates an optimizer and minimizes the loss function to improve the neural network](https://github.com/tensorflow/models/blob/master/official/mnist/mnist.py#L193)\n", "6. [returns the output, optimizer and loss function](https://github.com/tensorflow/models/blob/master/official/mnist/mnist.py#L205)\n", "\n", "## Writing a ```model_fn``` for distributed training\n", "In distributed training, the same neural network is sent to multiple training instances. Each instance predicts on a batch of the dataset, calculates the loss, and runs the optimizer to minimize it. One full pass through this process is called a **training step**.\n", "\n", "### Synchronizing training steps\n", "A [global step](https://www.tensorflow.org/api_docs/python/tf/train/global_step) is a global variable shared between the instances. Distributed training needs it so that the optimizer can keep track of the number of **training steps** between runs: \n", "\n", "```python\n", "train_op = optimizer.minimize(loss, tf.train.get_or_create_global_step())\n", "```\n", "\n", "That is the only required change for distributed training!" 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "!cat 'mnist.py'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Set Parameters for SageBuild" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "store=outputs[\"ParameterStore\"]\n", "result=ssm.get_parameter(Name=store)\n", "\n", "params=json.loads(result[\"Parameter\"][\"Value\"])\n", "params[\"hyperparameters\"]={\n", "    'training_steps':1000,\n", "    'evaluation_steps':100,\n", "    'sagemaker_requirements':\"\"\n", "}\n", "params[\"trainentrypoint\"]=\"mnist.py\"\n", "params[\"traininstancecount\"]=2\n", "params[\"traininstancetype\"]=\"ml.c4.xlarge\"\n", "params[\"trainsourcefile\"]=\"s3://{}/script.tar.gz\".format(outputs[\"CodeBucket\"])\n", "params[\"hostentrypoint\"]=\"mnist.py\"\n", "params[\"hostsourcefile\"]=\"s3://{}/script.tar.gz\".format(outputs[\"CodeBucket\"])\n", "params[\"pyversion\"]=\"py2\"\n", "params[\"channels\"]={\n", "    \"training\":{\n", "        \"path\":\"train/mnist-dist\",\n", "        \"dist\":False\n", "    }\n", "}\n", "\n", "ssm.put_parameter(\n", "    Name=store,\n", "    Type=\"String\",\n", "    Overwrite=True,\n", "    Value=json.dumps(params)\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Upload the data\n", "We use the AWS CLI (```aws s3 cp```) to copy the converted datasets from the local ```data``` directory to the ```train/mnist-dist``` prefix of the data bucket. This is the same prefix as the ```path``` of the ```training``` channel set above." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dataBucket=outputs[\"DataBucket\"]\n", "!aws s3 cp ./data s3://$dataBucket/train/mnist-dist --recursive" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": false }, "outputs": [], "source": [ "%%time\n", "import sys\n", "\n", "result=sns.publish(\n", "    TopicArn=outputs['LaunchTopic'],\n", "    Message=\"{}\" #message is not important, just publishing to topic starts build\n", ")\n", "print(result)\n", "time.sleep(5)\n", "#list all executions for our StateMachine to get our current running one\n", "result=step.list_executions(\n", "    stateMachineArn=outputs['StateMachine'],\n", "    statusFilter=\"RUNNING\"\n", ")['executions']\n", "\n", "if len(result) > 0:\n", "    response = step.describe_execution(\n", "        executionArn=result[0]['executionArn']\n", "    )\n", "    status=response['status']\n", "    print(\"{} {}\".format(status, response['name']))\n", "    #poll status till execution finishes\n", "    while status == \"RUNNING\":\n", "        #write a progress dot without a newline (works on Python 2 and 3)\n", "        sys.stdout.write('.')\n", "        sys.stdout.flush()\n", "        time.sleep(5)\n", "        status=step.describe_execution(executionArn=result[0]['executionArn'])['status']\n", "    print(\"\")\n", "    print(status)\n", "else:\n", "    print(\"no running tasks\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Invoking the endpoint" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import google.protobuf.json_format as json_format\n", "from tensorflow.examples.tutorials.mnist import input_data\n", "from sagemaker.predictor import json_serializer, csv_serializer\n", "from sagemaker.tensorflow.predictor import tf_json_serializer, tf_json_deserializer\n", "from sagemaker.predictor import RealTimePredictor\n", "\n", "predict=RealTimePredictor(outputs[\"SageMakerEndpoint\"], False, tf_json_serializer,tf_json_deserializer)\n", "mnist = input_data.read_data_sets(\"/tmp/data/\", one_hot=True)\n", "\n", "for i in range(10):\n", " data = 
mnist.test.images[i].tolist()\n", " tensor_proto = tf.make_tensor_proto(\n", " values=np.asarray(data), \n", " shape=[1, len(data)], \n", " dtype=tf.float32)\n", " \n", " \n", " predict_response=predict.predict(tensor_proto)\n", " print(\"========================================\")\n", " label = np.argmax(mnist.test.labels[i])\n", " print(\"label is {}\".format(label))\n", " prediction = predict_response['outputs']['classes']['int64Val'][0]\n", " print(\"prediction is {}\".format(prediction))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "conda_tensorflow_p27", "language": "python", "name": "conda_tensorflow_p27" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.15" }, "notice": "Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License." }, "nbformat": 4, "nbformat_minor": 2 }