{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Training and hosting SageMaker Models using the Apache MXNet Module API\n", "\n", "In this example, we train a simple neural network using the Apache MXNet [Module API](https://mxnet.incubator.apache.org/api/python/module.html) and the MNIST dataset. The MNIST dataset is widely used for handwritten digit classification and consists of 70,000 labeled 28x28 pixel grayscale images of handwritten digits. The dataset is split into 60,000 training images and 10,000 test images. There are 10 classes (one for each of the 10 digits). The task at hand is to train a model using the 60,000 training images and subsequently test its classification accuracy on the 10,000 test images.\n", "\n", "### Setup\n", "\n", "First, we create the AWS clients, read the stack outputs, and upload the code archive that will be needed later in the example." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "isConfigCell": true }, "outputs": [], "source": [ "from sagemaker import get_execution_role\n", "import json\n", "import boto3\n", "import time\n", "import os\n", "import tarfile\n", "from botocore.exceptions import ClientError\n", "cf = boto3.client('cloudformation')\n", "s3 = boto3.client('s3')\n", "sns = boto3.client('sns')\n", "step = boto3.client('stepfunctions')\n", "sagemaker = boto3.client('sagemaker-runtime')\n", "ssm=boto3.client('ssm')\n", "\n", "with open('../config.json') as json_file: \n", " config = json.load(json_file)\n", "StackName=config[\"StackName\"]\n", "\n", "result=cf.describe_stacks(\n", " StackName=StackName\n", ")\n", "outputs={}\n", "for output in result['Stacks'][0]['Outputs']:\n", " outputs[output['OutputKey']]=output['OutputValue']\n", "\n", "with tarfile.open(\"script.tar.gz\", \"w:gz\") as tar:\n", " tar.add(os.getcwd(),arcname=\"\")\n", "\n", "s3.upload_file(\"script.tar.gz\",outputs[\"CodeBucket\"],\"script.tar.gz\")\n", "#IAM execution role that gives 
SageMaker access to resources in your AWS account.\n", "#We can use the SageMaker Python SDK to get the role from our notebook environment. \n", "role = get_execution_role()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We need to make sure the SageBuild template is configured correctly for MXNet. The following code sets the stack's `ConfigFramework` parameter to `MXNET` and waits for the stack update to complete." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "params=result[\"Stacks\"][0][\"Parameters\"]\n", "for n,i in enumerate(params):\n", " if(i[\"ParameterKey\"]==\"ConfigFramework\"):\n", " i[\"ParameterValue\"]=\"MXNET\" \n", "\n", "try:\n", " cf.update_stack(\n", " StackName=StackName,\n", " UsePreviousTemplate=True,\n", " Parameters=params,\n", " Capabilities=[\n", " 'CAPABILITY_NAMED_IAM',\n", " ]\n", " )\n", " waiter = cf.get_waiter('stack_update_complete')\n", " print(\"Waiting for stack update\")\n", " waiter.wait(\n", " StackName=StackName,\n", " WaiterConfig={\n", " 'Delay':10,\n", " 'MaxAttempts':600\n", " }\n", " )\n", "\n", "except ClientError as e:\n", " if(e.response[\"Error\"][\"Message\"]==\"No updates are to be performed.\"):\n", " pass\n", " else:\n", " raise e\n", "print(\"stack ready!\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Download Data to bucket" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dataBucket=outputs[\"DataBucket\"]\n", "!aws s3 cp s3://sagemaker-sample-data-us-east-1/mxnet/mnist/train s3://$dataBucket/train/mnist --recursive\n", "!aws s3 cp s3://sagemaker-sample-data-us-east-1/mxnet/mnist/test s3://$dataBucket/test/mnist --recursive" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Update SageBuild Parameters" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "store=outputs[\"ParameterStore\"]\n", "result=ssm.get_parameter(Name=store)\n", "\n", 
"params=json.loads(result[\"Parameter\"][\"Value\"])\n", "params[\"hyperparameters\"]={'learning_rate': 0.1}\n", "params[\"trainentrypoint\"]=\"mnist.py\"\n", "params[\"trainsourcefile\"]=\"s3://{}/script.tar.gz\".format(outputs[\"CodeBucket\"])\n", "params[\"hostentrypoint\"]=\"mnist.py\"\n", "params[\"hostsourcefile\"]=\"s3://{}/script.tar.gz\".format(outputs[\"CodeBucket\"])\n", "params[\"pyversion\"]=\"py2\"\n", "params[\"channels\"]={\n", " \"train\":{\n", " \"path\":\"train/mnist\"\n", " },\n", " \"test\":{\n", " \"path\":\"test/mnist\"\n", " }\n", "}\n", "\n", "ssm.put_parameter(\n", " Name=store,\n", " Type=\"String\",\n", " Overwrite=True,\n", " Value=json.dumps(params)\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The training script\n", "\n", "The ``mnist.py`` script provides all the code we need for training and hosting a SageMaker model. The script we will use is adapted from the Apache MXNet [MNIST tutorial](https://mxnet.incubator.apache.org/tutorials/python/mnist.html)." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Start Train/Deploy pipeline" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "result=sns.publish(\n", " TopicArn=outputs['LaunchTopic'],\n", " Message=\"{}\" #message is not important, just publishing to topic starts build\n", ")\n", "print(result)\n", "time.sleep(5)\n", "#list all executions for our StateMachine to get our current running one\n", "result=step.list_executions(\n", " stateMachineArn=outputs['StateMachine'],\n", " statusFilter=\"RUNNING\"\n", ")['executions']\n", "\n", "if len(result) > 0:\n", " response = step.describe_execution(\n", " executionArn=result[0]['executionArn']\n", " )\n", " status=response['status']\n", " print(status,response['name'])\n", " #poll status till execution finishes\n", " while status == \"RUNNING\":\n", " print('.',end=\"\")\n", " time.sleep(5)\n", " status=step.describe_execution(executionArn=result[0]['executionArn'])['status']\n", " print()\n", " print(status)\n", "else:\n", " print(\"no running tasks\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Making an inference request\n", "\n", "The request handling behavior of the Endpoint is determined by the ``mnist.py`` script. In this case, the script doesn't include any request handling functions, so the Endpoint will use the default handlers provided by SageMaker. These default handlers allow us to perform inference on input data encoded as a multi-dimensional JSON array.\n", "\n", "To see inference in action, draw a digit in the image box below. The pixel data from your drawing will be loaded into a ``data`` variable in this notebook. 
\n", "\n", "*Note: after drawing the image, you'll need to move to the next notebook cell.*" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from IPython.display import HTML\n", "HTML(open(\"input.html\").read())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we can classify the handwritten digit:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "result=sagemaker.invoke_endpoint(\n", " EndpointName=outputs[\"SageMakerEndpoint\"],\n", " Body=json.dumps(data), \n", " ContentType=\"application/json\",\n", " Accept=\"application/json\"\n", ")\n", "\n", "response =json.loads(result['Body'].read().decode('utf-8'))\n", "labeled_predictions = list(zip(range(10), response[0]))\n", "print('Labeled predictions: ')\n", "print(labeled_predictions)\n", "\n", "labeled_predictions.sort(key=lambda label_and_prob: 1.0 - label_and_prob[1])\n", "print('Most likely answer: {}'.format(labeled_predictions[0]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "conda_mxnet_p36", "language": "python", "name": "conda_mxnet_p36" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" }, "notice": "Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License." }, "nbformat": 4, "nbformat_minor": 2 }