{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# TensorFlow script mode training and serving\n", "\n", "Script mode is a training script format for TensorFlow that lets you execute any TensorFlow training script in SageMaker with minimal modification. The [SageMaker Python SDK](https://github.com/aws/sagemaker-python-sdk) handles transferring your script to a SageMaker training instance. On the training instance, SageMaker's native TensorFlow support sets up training-related environment variables and executes your training script. In this tutorial, we use the SageMaker Python SDK to launch a training job and deploy the trained model.\n", "\n", "Script mode supports training with a Python script, a Python module, or a shell script. In this example, we use a Python script to train a classification model on the [MNIST dataset](http://yann.lecun.com/exdb/mnist/). In addition, this notebook demonstrates how to perform real-time inference with the [SageMaker TensorFlow Serving container](https://github.com/aws/sagemaker-tensorflow-serving-container). The TensorFlow Serving container is the default inference method for script mode. 
For full documentation on the TensorFlow Serving container, see the [SageMaker Python SDK documentation on deploying to TensorFlow Serving](https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/tensorflow/deploying_tensorflow_serving.rst).\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Set up the environment\n", "\n", "Let's start by setting up the environment:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os\n", "import sagemaker\n", "from sagemaker import get_execution_role\n", "import time\n", "\n", "sagemaker_session = sagemaker.Session()\n", "\n", "role = get_execution_role()\n", "region = sagemaker_session.boto_session.region_name" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training Data\n", "\n", "The MNIST dataset has been uploaded to the public S3 buckets ``sagemaker-sample-data-`` under the prefix ``tensorflow/mnist``. There are four ``.npy`` files under this prefix:\n", "* ``train_data.npy``\n", "* ``eval_data.npy``\n", "* ``train_labels.npy``\n", "* ``eval_labels.npy``" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "training_data_uri = 's3://sagemaker-sample-data-{}/tensorflow/mnist'.format(region)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Construct a script for distributed training\n", "\n", "This tutorial's training script was adapted from TensorFlow's official [CNN MNIST example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/tutorials/layers/cnn_mnist.py). We modified it to handle the ``model_dir`` parameter passed in by SageMaker. This is an S3 path that can be used to share data during distributed training, to save checkpoints, and/or to persist the model. 
We also added an argument-parsing function to handle training-related variables.\n", "\n", "At the end of the script, we added a step to export the trained model to the path stored in the environment variable ``SM_MODEL_DIR``, which always points to ``/opt/ml/model``. This is critical because SageMaker uploads all the model artifacts in this folder to S3 at the end of training.\n", "\n", "For more documentation on SageMaker script mode, see https://sagemaker.readthedocs.io/en/stable/using_tf.html\n", "\n", "Here is the entire script:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[37m# Copyright 2018-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.\u001b[39;49;00m\r\n", "\u001b[37m#\u001b[39;49;00m\r\n", "\u001b[37m# Licensed under the Apache License, Version 2.0 (the \"License\"). You\u001b[39;49;00m\r\n", "\u001b[37m# may not use this file except in compliance with the License. A copy of\u001b[39;49;00m\r\n", "\u001b[37m# the License is located at\u001b[39;49;00m\r\n", "\u001b[37m#\u001b[39;49;00m\r\n", "\u001b[37m# http://aws.amazon.com/apache2.0/\u001b[39;49;00m\r\n", "\u001b[37m#\u001b[39;49;00m\r\n", "\u001b[37m# or in the \"license\" file accompanying this file. This file is\u001b[39;49;00m\r\n", "\u001b[37m# distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF\u001b[39;49;00m\r\n", "\u001b[37m# ANY KIND, either express or implied. 
See the License for the specific\u001b[39;49;00m\r\n", "\u001b[37m# language governing permissions and limitations under the License.\u001b[39;49;00m\r\n", "\u001b[33m\"\"\"Convolutional Neural Network Estimator for MNIST, built with tf.layers.\"\"\"\u001b[39;49;00m\r\n", "\r\n", "\u001b[34mfrom\u001b[39;49;00m \u001b[04m\u001b[36m__future__\u001b[39;49;00m \u001b[34mimport\u001b[39;49;00m absolute_import\r\n", "\u001b[34mfrom\u001b[39;49;00m \u001b[04m\u001b[36m__future__\u001b[39;49;00m \u001b[34mimport\u001b[39;49;00m division\r\n", "\u001b[34mfrom\u001b[39;49;00m \u001b[04m\u001b[36m__future__\u001b[39;49;00m \u001b[34mimport\u001b[39;49;00m print_function\r\n", "\r\n", "\u001b[34mimport\u001b[39;49;00m \u001b[04m\u001b[36mnumpy\u001b[39;49;00m \u001b[34mas\u001b[39;49;00m \u001b[04m\u001b[36mnp\u001b[39;49;00m\r\n", "\u001b[34mimport\u001b[39;49;00m \u001b[04m\u001b[36mtensorflow\u001b[39;49;00m \u001b[34mas\u001b[39;49;00m \u001b[04m\u001b[36mtf\u001b[39;49;00m\r\n", "\u001b[34mimport\u001b[39;49;00m \u001b[04m\u001b[36mos\u001b[39;49;00m\r\n", "\u001b[34mimport\u001b[39;49;00m \u001b[04m\u001b[36mjson\u001b[39;49;00m\r\n", "\u001b[34mimport\u001b[39;49;00m \u001b[04m\u001b[36margparse\u001b[39;49;00m\r\n", "\u001b[34mfrom\u001b[39;49;00m \u001b[04m\u001b[36mtensorflow.python.platform\u001b[39;49;00m \u001b[34mimport\u001b[39;49;00m tf_logging\r\n", "\u001b[34mimport\u001b[39;49;00m \u001b[04m\u001b[36mlogging\u001b[39;49;00m \u001b[34mas\u001b[39;49;00m \u001b[04m\u001b[36m_logging\u001b[39;49;00m\r\n", "\u001b[34mimport\u001b[39;49;00m \u001b[04m\u001b[36msys\u001b[39;49;00m \u001b[34mas\u001b[39;49;00m \u001b[04m\u001b[36m_sys\u001b[39;49;00m\r\n", "\r\n", "\r\n", "\u001b[34mdef\u001b[39;49;00m \u001b[32mcnn_model_fn\u001b[39;49;00m(features, labels, mode):\r\n", " \u001b[33m\"\"\"Model function for CNN.\"\"\"\u001b[39;49;00m\r\n", " \u001b[37m# Input Layer\u001b[39;49;00m\r\n", " \u001b[37m# Reshape X to 4-D tensor: [batch_size, width, height, 
channels]\u001b[39;49;00m\r\n", " \u001b[37m# MNIST images are 28x28 pixels, and have one color channel\u001b[39;49;00m\r\n", " input_layer = tf.reshape(features[\u001b[33m\"\u001b[39;49;00m\u001b[33mx\u001b[39;49;00m\u001b[33m\"\u001b[39;49;00m], [-\u001b[34m1\u001b[39;49;00m, \u001b[34m28\u001b[39;49;00m, \u001b[34m28\u001b[39;49;00m, \u001b[34m1\u001b[39;49;00m])\r\n", "\r\n", " \u001b[37m# Convolutional Layer #1\u001b[39;49;00m\r\n", " \u001b[37m# Computes 32 features using a 5x5 filter with ReLU activation.\u001b[39;49;00m\r\n", " \u001b[37m# Padding is added to preserve width and height.\u001b[39;49;00m\r\n", " \u001b[37m# Input Tensor Shape: [batch_size, 28, 28, 1]\u001b[39;49;00m\r\n", " \u001b[37m# Output Tensor Shape: [batch_size, 28, 28, 32]\u001b[39;49;00m\r\n", " conv1 = tf.layers.conv2d(\r\n", " inputs=input_layer,\r\n", " filters=\u001b[34m32\u001b[39;49;00m,\r\n", " kernel_size=[\u001b[34m5\u001b[39;49;00m, \u001b[34m5\u001b[39;49;00m],\r\n", " padding=\u001b[33m\"\u001b[39;49;00m\u001b[33msame\u001b[39;49;00m\u001b[33m\"\u001b[39;49;00m,\r\n", " activation=tf.nn.relu)\r\n", "\r\n", " \u001b[37m# Pooling Layer #1\u001b[39;49;00m\r\n", " \u001b[37m# First max pooling layer with a 2x2 filter and stride of 2\u001b[39;49;00m\r\n", " \u001b[37m# Input Tensor Shape: [batch_size, 28, 28, 32]\u001b[39;49;00m\r\n", " \u001b[37m# Output Tensor Shape: [batch_size, 14, 14, 32]\u001b[39;49;00m\r\n", " pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[\u001b[34m2\u001b[39;49;00m, \u001b[34m2\u001b[39;49;00m], strides=\u001b[34m2\u001b[39;49;00m)\r\n", "\r\n", " \u001b[37m# Convolutional Layer #2\u001b[39;49;00m\r\n", " \u001b[37m# Computes 64 features using a 5x5 filter.\u001b[39;49;00m\r\n", " \u001b[37m# Padding is added to preserve width and height.\u001b[39;49;00m\r\n", " \u001b[37m# Input Tensor Shape: [batch_size, 14, 14, 32]\u001b[39;49;00m\r\n", " \u001b[37m# Output Tensor Shape: [batch_size, 14, 14, 64]\u001b[39;49;00m\r\n", " conv2 = 
tf.layers.conv2d(\r\n", " inputs=pool1,\r\n", " filters=\u001b[34m64\u001b[39;49;00m,\r\n", " kernel_size=[\u001b[34m5\u001b[39;49;00m, \u001b[34m5\u001b[39;49;00m],\r\n", " padding=\u001b[33m\"\u001b[39;49;00m\u001b[33msame\u001b[39;49;00m\u001b[33m\"\u001b[39;49;00m,\r\n", " activation=tf.nn.relu)\r\n", "\r\n", " \u001b[37m# Pooling Layer #2\u001b[39;49;00m\r\n", " \u001b[37m# Second max pooling layer with a 2x2 filter and stride of 2\u001b[39;49;00m\r\n", " \u001b[37m# Input Tensor Shape: [batch_size, 14, 14, 64]\u001b[39;49;00m\r\n", " \u001b[37m# Output Tensor Shape: [batch_size, 7, 7, 64]\u001b[39;49;00m\r\n", " pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[\u001b[34m2\u001b[39;49;00m, \u001b[34m2\u001b[39;49;00m], strides=\u001b[34m2\u001b[39;49;00m)\r\n", "\r\n", " \u001b[37m# Flatten tensor into a batch of vectors\u001b[39;49;00m\r\n", " \u001b[37m# Input Tensor Shape: [batch_size, 7, 7, 64]\u001b[39;49;00m\r\n", " \u001b[37m# Output Tensor Shape: [batch_size, 7 * 7 * 64]\u001b[39;49;00m\r\n", " pool2_flat = tf.reshape(pool2, [-\u001b[34m1\u001b[39;49;00m, \u001b[34m7\u001b[39;49;00m * \u001b[34m7\u001b[39;49;00m * \u001b[34m64\u001b[39;49;00m])\r\n", "\r\n", " \u001b[37m# Dense Layer\u001b[39;49;00m\r\n", " \u001b[37m# Densely connected layer with 1024 neurons\u001b[39;49;00m\r\n", " \u001b[37m# Input Tensor Shape: [batch_size, 7 * 7 * 64]\u001b[39;49;00m\r\n", " \u001b[37m# Output Tensor Shape: [batch_size, 1024]\u001b[39;49;00m\r\n", " dense = tf.layers.dense(inputs=pool2_flat, units=\u001b[34m1024\u001b[39;49;00m, activation=tf.nn.relu)\r\n", "\r\n", " \u001b[37m# Add dropout operation; 0.6 probability that element will be kept\u001b[39;49;00m\r\n", " dropout = tf.layers.dropout(\r\n", " inputs=dense, rate=\u001b[34m0.4\u001b[39;49;00m, training=mode == tf.estimator.ModeKeys.TRAIN)\r\n", "\r\n", " \u001b[37m# Logits layer\u001b[39;49;00m\r\n", " \u001b[37m# Input Tensor Shape: [batch_size, 1024]\u001b[39;49;00m\r\n", " \u001b[37m# Output 
Tensor Shape: [batch_size, 10]\u001b[39;49;00m\r\n", " logits = tf.layers.dense(inputs=dropout, units=\u001b[34m10\u001b[39;49;00m)\r\n", "\r\n", " predictions = {\r\n", " \u001b[37m# Generate predictions (for PREDICT and EVAL mode)\u001b[39;49;00m\r\n", " \u001b[33m\"\u001b[39;49;00m\u001b[33mclasses\u001b[39;49;00m\u001b[33m\"\u001b[39;49;00m: tf.argmax(\u001b[36minput\u001b[39;49;00m=logits, axis=\u001b[34m1\u001b[39;49;00m),\r\n", " \u001b[37m# Add `softmax_tensor` to the graph. It is used for PREDICT and by the\u001b[39;49;00m\r\n", " \u001b[37m# `logging_hook`.\u001b[39;49;00m\r\n", " \u001b[33m\"\u001b[39;49;00m\u001b[33mprobabilities\u001b[39;49;00m\u001b[33m\"\u001b[39;49;00m: tf.nn.softmax(logits, name=\u001b[33m\"\u001b[39;49;00m\u001b[33msoftmax_tensor\u001b[39;49;00m\u001b[33m\"\u001b[39;49;00m)\r\n", " }\r\n", " \u001b[34mif\u001b[39;49;00m mode == tf.estimator.ModeKeys.PREDICT:\r\n", " \u001b[34mreturn\u001b[39;49;00m tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)\r\n", "\r\n", " \u001b[37m# Calculate Loss (for both TRAIN and EVAL modes)\u001b[39;49;00m\r\n", " loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)\r\n", "\r\n", " \u001b[37m# Configure the Training Op (for TRAIN mode)\u001b[39;49;00m\r\n", " \u001b[34mif\u001b[39;49;00m mode == tf.estimator.ModeKeys.TRAIN:\r\n", " optimizer = tf.train.GradientDescentOptimizer(learning_rate=\u001b[34m0.001\u001b[39;49;00m)\r\n", " train_op = optimizer.minimize(\r\n", " loss=loss,\r\n", " global_step=tf.train.get_global_step())\r\n", " \u001b[34mreturn\u001b[39;49;00m tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)\r\n", "\r\n", " \u001b[37m# Add evaluation metrics (for EVAL mode)\u001b[39;49;00m\r\n", " eval_metric_ops = {\r\n", " \u001b[33m\"\u001b[39;49;00m\u001b[33maccuracy\u001b[39;49;00m\u001b[33m\"\u001b[39;49;00m: tf.metrics.accuracy(\r\n", " labels=labels, 
predictions=predictions[\u001b[33m\"\u001b[39;49;00m\u001b[33mclasses\u001b[39;49;00m\u001b[33m\"\u001b[39;49;00m])}\r\n", " \u001b[34mreturn\u001b[39;49;00m tf.estimator.EstimatorSpec(\r\n", " mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)\r\n", "\r\n", "\u001b[34mdef\u001b[39;49;00m \u001b[32m_load_training_data\u001b[39;49;00m(base_dir):\r\n", " x_train = np.load(os.path.join(base_dir, \u001b[33m'\u001b[39;49;00m\u001b[33mtrain_data.npy\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m))\r\n", " y_train = np.load(os.path.join(base_dir, \u001b[33m'\u001b[39;49;00m\u001b[33mtrain_labels.npy\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m))\r\n", " \u001b[34mreturn\u001b[39;49;00m x_train, y_train\r\n", "\r\n", "\u001b[34mdef\u001b[39;49;00m \u001b[32m_load_testing_data\u001b[39;49;00m(base_dir):\r\n", " x_test = np.load(os.path.join(base_dir, \u001b[33m'\u001b[39;49;00m\u001b[33meval_data.npy\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m))\r\n", " y_test = np.load(os.path.join(base_dir, \u001b[33m'\u001b[39;49;00m\u001b[33meval_labels.npy\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m))\r\n", " \u001b[34mreturn\u001b[39;49;00m x_test, y_test\r\n", "\r\n", "\u001b[34mdef\u001b[39;49;00m \u001b[32m_parse_args\u001b[39;49;00m():\r\n", "\r\n", " parser = argparse.ArgumentParser()\r\n", "\r\n", " \u001b[37m# Data, model, and output directories\u001b[39;49;00m\r\n", " \u001b[37m# model_dir is always passed in from SageMaker. 
By default this is a S3 path under the default bucket.\u001b[39;49;00m\r\n", " parser.add_argument(\u001b[33m'\u001b[39;49;00m\u001b[33m--model_dir\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m, \u001b[36mtype\u001b[39;49;00m=\u001b[36mstr\u001b[39;49;00m)\r\n", " parser.add_argument(\u001b[33m'\u001b[39;49;00m\u001b[33m--sm-model-dir\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m, \u001b[36mtype\u001b[39;49;00m=\u001b[36mstr\u001b[39;49;00m, default=os.environ.get(\u001b[33m'\u001b[39;49;00m\u001b[33mSM_MODEL_DIR\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m))\r\n", " parser.add_argument(\u001b[33m'\u001b[39;49;00m\u001b[33m--train\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m, \u001b[36mtype\u001b[39;49;00m=\u001b[36mstr\u001b[39;49;00m, default=os.environ.get(\u001b[33m'\u001b[39;49;00m\u001b[33mSM_CHANNEL_TRAINING\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m))\r\n", " parser.add_argument(\u001b[33m'\u001b[39;49;00m\u001b[33m--hosts\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m, \u001b[36mtype\u001b[39;49;00m=\u001b[36mlist\u001b[39;49;00m, default=json.loads(os.environ.get(\u001b[33m'\u001b[39;49;00m\u001b[33mSM_HOSTS\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m)))\r\n", " parser.add_argument(\u001b[33m'\u001b[39;49;00m\u001b[33m--current-host\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m, \u001b[36mtype\u001b[39;49;00m=\u001b[36mstr\u001b[39;49;00m, default=os.environ.get(\u001b[33m'\u001b[39;49;00m\u001b[33mSM_CURRENT_HOST\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m))\r\n", "\r\n", " \u001b[34mreturn\u001b[39;49;00m parser.parse_known_args()\r\n", "\r\n", "\u001b[34mdef\u001b[39;49;00m \u001b[32mserving_input_fn\u001b[39;49;00m():\r\n", " inputs = {\u001b[33m'\u001b[39;49;00m\u001b[33mx\u001b[39;49;00m\u001b[33m'\u001b[39;49;00m: tf.placeholder(tf.float32, [\u001b[36mNone\u001b[39;49;00m, \u001b[34m784\u001b[39;49;00m])}\r\n", " \u001b[34mreturn\u001b[39;49;00m tf.estimator.export.ServingInputReceiver(inputs, inputs)\r\n", "\r\n", "\u001b[34mif\u001b[39;49;00m 
\u001b[31m__name__\u001b[39;49;00m == \u001b[33m\"\u001b[39;49;00m\u001b[33m__main__\u001b[39;49;00m\u001b[33m\"\u001b[39;49;00m:\r\n", " args, unknown = _parse_args()\r\n", "\r\n", " train_data, train_labels = _load_training_data(args.train)\r\n", " eval_data, eval_labels = _load_testing_data(args.train)\r\n", "\r\n", " \u001b[37m# Create the Estimator\u001b[39;49;00m\r\n", " mnist_classifier = tf.estimator.Estimator(\r\n", " model_fn=cnn_model_fn, model_dir=args.model_dir)\r\n", "\r\n", " \u001b[37m# Set up logging for predictions\u001b[39;49;00m\r\n", " \u001b[37m# Log the values in the \"Softmax\" tensor with label \"probabilities\"\u001b[39;49;00m\r\n", " tensors_to_log = {\u001b[33m\"\u001b[39;49;00m\u001b[33mprobabilities\u001b[39;49;00m\u001b[33m\"\u001b[39;49;00m: \u001b[33m\"\u001b[39;49;00m\u001b[33msoftmax_tensor\u001b[39;49;00m\u001b[33m\"\u001b[39;49;00m}\r\n", " logging_hook = tf.train.LoggingTensorHook(\r\n", " tensors=tensors_to_log, every_n_iter=\u001b[34m50\u001b[39;49;00m)\r\n", "\r\n", " \u001b[37m# Train the model\u001b[39;49;00m\r\n", " train_input_fn = tf.estimator.inputs.numpy_input_fn(\r\n", " x={\u001b[33m\"\u001b[39;49;00m\u001b[33mx\u001b[39;49;00m\u001b[33m\"\u001b[39;49;00m: train_data},\r\n", " y=train_labels,\r\n", " batch_size=\u001b[34m100\u001b[39;49;00m,\r\n", " num_epochs=\u001b[36mNone\u001b[39;49;00m,\r\n", " shuffle=\u001b[36mTrue\u001b[39;49;00m)\r\n", "\r\n", " \u001b[37m# Evaluate the model and print results\u001b[39;49;00m\r\n", " eval_input_fn = tf.estimator.inputs.numpy_input_fn(\r\n", " x={\u001b[33m\"\u001b[39;49;00m\u001b[33mx\u001b[39;49;00m\u001b[33m\"\u001b[39;49;00m: eval_data},\r\n", " y=eval_labels,\r\n", " num_epochs=\u001b[34m1\u001b[39;49;00m,\r\n", " shuffle=\u001b[36mFalse\u001b[39;49;00m)\r\n", "\r\n", " train_spec = tf.estimator.TrainSpec(train_input_fn, max_steps=\u001b[34m20000\u001b[39;49;00m)\r\n", " eval_spec = tf.estimator.EvalSpec(eval_input_fn)\r\n", " 
tf.estimator.train_and_evaluate(mnist_classifier, train_spec, eval_spec)\r\n", "\r\n", " \u001b[34mif\u001b[39;49;00m args.current_host == args.hosts[\u001b[34m0\u001b[39;49;00m]:\r\n", " mnist_classifier.export_savedmodel(args.sm_model_dir, serving_input_fn)\r\n" ] } ], "source": [ "!pygmentize 'mnist.py'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Create a training job using the `TensorFlow` estimator\n", "\n", "The `sagemaker.tensorflow.TensorFlow` estimator handles locating the script mode container, uploading your script to an S3 location, and creating a SageMaker training job. Two parameters are worth calling out:\n", "\n", "* `py_version` is set to `'py3'` to indicate that we are using script mode, because legacy mode supports only Python 2. Although Python 2 will soon be deprecated, you can still use script mode with Python 2 by setting `py_version` to `'py2'` and `script_mode` to `True`.\n", "\n", "* `distributions` configures the distributed training setup. It is required only if you are doing distributed training, either across a cluster of instances or across multiple GPUs. Here we use parameter servers as the distributed training scheme. Because SageMaker training jobs run on homogeneous clusters, a parameter server is launched on every instance in the cluster to make the parameter server setup more performant, so there is no need to specify the number of parameter servers to launch. Script mode also supports distributed training with [Horovod](https://github.com/horovod/horovod). For full documentation on configuring `distributions`, see the [distributed training documentation](https://github.com/aws/sagemaker-python-sdk/tree/master/src/sagemaker/tensorflow#distributed-training). 
\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "scrolled": false }, "outputs": [], "source": [ "from sagemaker.tensorflow import TensorFlow\n", "\n", "train_instance_count = 2\n", "# Other instance types you could use instead: 'ml.p2.xlarge', 'ml.c5.9xlarge', 'ml.m5.2xlarge',\n", "# 'ml.m5.4xlarge', 'ml.c4.2xlarge', 'ml.c5.2xlarge'\n", "train_instance_type = 'ml.m4.4xlarge'\n", "\n", "mnist_estimator = TensorFlow(entry_point='mnist.py',\n", " role=role,\n", " train_instance_count=train_instance_count,\n", " train_instance_type=train_instance_type,\n", " framework_version='1.12',\n", " py_version='py3',\n", " distributions={'parameter_server': {'enabled': True}})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Calling ``fit``\n", "\n", "To start a training job, we call `estimator.fit(training_data_uri)`.\n", "\n", "An S3 location is used here as the input. `fit` creates a default channel named `'training'`, which points to this S3 location. In the training script we can then access the training data from the location stored in `SM_CHANNEL_TRAINING`. `fit` also accepts other types of input; see the [`fit` API documentation](https://sagemaker.readthedocs.io/en/stable/estimators.html#sagemaker.estimator.EstimatorBase.fit) for details.\n", "\n", "When training starts, the TensorFlow container executes `mnist.py`, passing `hyperparameters` and `model_dir` from the estimator as script arguments. Because we didn't define either in this example, no hyperparameters are passed, and `model_dir` defaults to `s3:///`, so the script is executed as follows:\n", "```bash\n", "python mnist.py --model_dir s3:///\n", "```\n", "\n", "When training is complete, SageMaker uploads the exported SavedModel to S3 so that it can be served with TensorFlow Serving."
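, "\n", "\n", "The next cell is a minimal local sketch, not part of the original workflow, of how the argument parsing in `mnist.py` resolves its inputs; it does not start a training job. The argument definitions mirror `_parse_args` in the script above, while the `simulated_env` dictionary (a hypothetical stand-in for the environment variables the training container sets) and the `s3://example-bucket/...` path are illustrative placeholders." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# Hypothetical local sketch of how mnist.py resolves its inputs inside the training container.\n",
"# simulated_env stands in for os.environ; every value in it is an illustrative placeholder.\n",
"import argparse\n",
"import json\n",
"\n",
"simulated_env = {\n",
"    'SM_MODEL_DIR': '/opt/ml/model',\n",
"    'SM_CHANNEL_TRAINING': '/opt/ml/input/data/training',\n",
"    'SM_HOSTS': json.dumps(['algo-1', 'algo-2']),\n",
"    'SM_CURRENT_HOST': 'algo-1',\n",
"}\n",
"\n",
"# Same argument definitions as _parse_args in mnist.py, with defaults read from simulated_env.\n",
"parser = argparse.ArgumentParser()\n",
"parser.add_argument('--model_dir', type=str)\n",
"parser.add_argument('--sm-model-dir', type=str, default=simulated_env.get('SM_MODEL_DIR'))\n",
"parser.add_argument('--train', type=str, default=simulated_env.get('SM_CHANNEL_TRAINING'))\n",
"parser.add_argument('--hosts', type=list, default=json.loads(simulated_env.get('SM_HOSTS')))\n",
"parser.add_argument('--current-host', type=str, default=simulated_env.get('SM_CURRENT_HOST'))\n",
"\n",
"# Equivalent to: python mnist.py --model_dir s3://example-bucket/example-job/model\n",
"args, unknown = parser.parse_known_args(['--model_dir', 's3://example-bucket/example-job/model'])\n",
"\n",
"print('model_dir:', args.model_dir)        # S3 path passed as a script argument\n",
"print('sm_model_dir:', args.sm_model_dir)  # local export directory that SageMaker uploads to S3\n",
"print('train:', args.train)                # local directory holding the training channel data\n",
"print('exports model:', args.current_host == args.hosts[0])  # only the first host exports\n",
"print('unknown args:', unknown)"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Because only `--model_dir` appears on the simulated command line, the remaining values come from their environment-variable defaults. `parse_known_args` also collects any arguments the script does not define in `unknown` instead of raising an error, so arguments the script does not use do not cause it to fail."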
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "print('train_instance_type:', train_instance_type)\n", "mnist_estimator.fit(training_data_uri)" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "# Deploy the trained model to an endpoint\n", "\n", "The `deploy()` method creates a SageMaker model, which is then deployed to an endpoint to serve prediction requests in real time. We use the TensorFlow Serving container for the endpoint because we trained with script mode. This serving container runs a web server that implements the SageMaker hosting protocol. The [Using your own inference code]() document explains how SageMaker runs inference containers." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "---------------------------------------------------------------------------!CPU times: user 383 ms, sys: 26.3 ms, total: 410 ms\n", "Wall time: 6min 20s\n" ] } ], "source": [ "%%time\n", "# 'ml.c5.9xlarge' is another hosting option.\n", "host_instance_type = 'ml.m4.xlarge'\n", "predictor = mnist_estimator.deploy(initial_instance_count=1, instance_type=host_instance_type)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Invoke the endpoint\n", "\n", "Let's download the training data and use it as input for inference."
] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "download: s3://sagemaker-sample-data-us-east-1/tensorflow/mnist/train_data.npy to ./train_data.npy\n", "download: s3://sagemaker-sample-data-us-east-1/tensorflow/mnist/train_labels.npy to ./train_labels.npy\n" ] } ], "source": [ "import numpy as np\n", "\n", "!aws --region {region} s3 cp s3://sagemaker-sample-data-{region}/tensorflow/mnist/train_data.npy train_data.npy\n", "!aws --region {region} s3 cp s3://sagemaker-sample-data-{region}/tensorflow/mnist/train_labels.npy train_labels.npy\n", "\n", "train_data = np.load('train_data.npy')\n", "train_labels = np.load('train_labels.npy')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The formats of the input and output data correspond directly to the request and response formats of the `Predict` method in the [TensorFlow Serving REST API](https://www.tensorflow.org/serving/api_rest). SageMaker's TensorFlow Serving endpoints also accept input formats that are not part of the TensorFlow Serving REST API, including the simplified JSON format, line-delimited JSON objects (\"jsons\" or \"jsonlines\"), and CSV data.\n", "\n", "In this example we use a `numpy` array as input, which is serialized into the simplified JSON format. TensorFlow Serving can also process multiple items in a single request, as the prediction code below shows. For complete documentation on making predictions against a TensorFlow Serving SageMaker endpoint, see [Making predictions against a SageMaker endpoint](https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/tensorflow/deploying_tensorflow_serving.rst#making-predictions-against-a-sagemaker-endpoint)."
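, "\n", "\n", "The next cell is a hypothetical, offline sketch of the request and response bodies used by the `Predict` method of the TensorFlow Serving REST API in its row (`instances`) format; it does not call the endpoint. The all-zero batch and the response values are placeholders, and the sketch assumes that the exported signature's outputs are named `classes` and `probabilities`, after the `predictions` dict in `cnn_model_fn`." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# Hypothetical offline sketch of TensorFlow Serving REST Predict payloads (row format).\n",
"# No endpoint is called; all values are placeholders.\n",
"import json\n",
"\n",
"import numpy as np\n",
"\n",
"# Two flattened 28x28 images, matching the [None, 784] input of serving_input_fn.\n",
"batch = np.zeros((2, 784), dtype=np.float32)\n",
"\n",
"# Request body: one entry per instance under the 'instances' key.\n",
"request_body = json.dumps({'instances': batch.tolist()})\n",
"\n",
"# Illustrative response body: one dict per instance, keyed by output name (values are made up).\n",
"response_body = json.dumps({'predictions': [\n",
"    {'classes': 7, 'probabilities': [0.1] * 10},\n",
"    {'classes': 3, 'probabilities': [0.1] * 10},\n",
"]})\n",
"\n",
"response = json.loads(response_body)\n",
"predicted_classes = [p['classes'] for p in response['predictions']]\n",
"print(len(json.loads(request_body)['instances']), predicted_classes)"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Because the serving signature has a single input, `x`, each instance can be given as a bare list of 784 values; a model with several named inputs would instead need one object of name/value pairs per instance."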
] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "prediction is 7, label is 7, matched: True\n", "prediction is 3, label is 3, matched: True\n", "prediction is 4, label is 4, matched: True\n", "prediction is 6, label is 6, matched: True\n", "prediction is 1, label is 1, matched: True\n", "prediction is 8, label is 8, matched: True\n", "prediction is 1, label is 1, matched: True\n", "prediction is 0, label is 0, matched: True\n", "prediction is 9, label is 9, matched: True\n", "prediction is 8, label is 8, matched: True\n", "prediction is 0, label is 0, matched: True\n", "prediction is 3, label is 3, matched: True\n", "prediction is 1, label is 1, matched: True\n", "prediction is 3, label is 2, matched: False\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAADmJJREFUeJzt3X+MFfW5x/HPA1IjQiJau66WCFZzAxoDzSpN6o/eWAlsqtjEKPyha2q6mNR4G40p8ZpcorkGbyw3/ccm24Bdml7pjT8JYlvES21jQ1h1i7i0lVsXgSysCrE2aBD3uX/s7O0W93zPcmbOmbP7vF/JZs+Z58zM4+BnZ+bMOfM1dxeAeKaU3QCAchB+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBndbIlZkZHycE6szdbTyvy7XnN7MlZvYnM9trZqvyLAtAY1mtn+03s6mS/izpekkHJO2UtMLd+xLzsOcH6qwRe/4rJe1197+4+3FJGyUty7E8AA2UJ/wXSNo/6vmBbNo/MLNOM+sxs54c6wJQsLq/4efuXZK6JA77gWaSZ89/UNLsUc+/nE0DMAHkCf9OSZeY2Vwz+4Kk5ZI2FdMWgHqr+bDf3U+Y2d2SfiVpqqT17v5WYZ0BqKuaL/XVtDLO+YG6a8iHfABMXIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8EVfMQ3ZJkZv2SPpL0maQT7t5WRFPRnHfeecn60qVLk/V58+bVVJOk9vb2ZH3t2rXJ+pYtW5L1PXv2VKx9/PHHyXk//PDDZH3q1KnJ+u23316xNn369OS8XV1dyfqnn36arE8EucKf+Wd3f7+A5QBoIA77gaDyht8l/drMXjOzziIaAtAYeQ/7r3L3g2b2JUlbzeyP7v7K6BdkfxT4wwA0mVx7fnc/mP0elPSspCvHeE2Xu7fxZiDQXGoOv5mdaWYzRx5LWixpd1GNAaivPIf9LZKeNbOR5fyXu/+ykK4A1J25e+NWZta4lTWRjo6OZH39+vXJeiP/jU6W/XGvKE9v77
zzTrJ+2223JevXXHNNsv7II4+cck8jFixYkKzv3t28B7nunv5Hy3CpDwiK8ANBEX4gKMIPBEX4gaAIPxAUl/oKcP755yfru3btStZnzZqVrNfz32jfvn3J+pw5c5L1iXoZ8oMPPkjOu2jRomS9v78/WS8Tl/oAJBF+ICjCDwRF+IGgCD8QFOEHgiL8QFBF3L03vOuuuy5ZP+uss3It/7nnnkvWV69eXfOyq13vPvfcc5P1av9tTzzxRMXahRdemJw3r6NHj1as3Xrrrcl5m/k6flHY8wNBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUFznL8Bdd92Va/5jx44l6729vcn6kiVLal73zp07k/UrrrgiWb/hhhuS9Xpfy0/Zv39/xdr27dsb10iTYs8PBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0FVvW+/ma2X9C1Jg+5+WTbtbEm/kDRHUr+kW9y98pen/76sSXnf/jVr1iTr9913X7J+2mnpj1sMDQ2dck9FmTIlvX/I09snn3ySrD/66KPJ+v3335+sT58+vWKt2mcjtm7dmqw3syLv2/9TSSdvqVWStrn7JZK2Zc8BTCBVw+/ur0g6ctLkZZK6s8fdkm4quC8AdVbrOX+Luw9kjw9JaimoHwANkvuz/e7uqXN5M+uU1Jl3PQCKVeue/7CZtUpS9nuw0gvdvcvd29y9rcZ1AaiDWsO/SVJH9rhD0vPFtAOgUaqG38yelPR7Sf9kZgfM7E5JayRdb2ZvS/pm9hzABFL1nN/dV1QopW9WH8iqVekrndXufX/HHXck69U+i1FP1a7jv/vuu8n6G2+8UbH22GOPJed99dVXk/UZM2Yk6/fee2/F2mS+zj9efMIPCIrwA0ERfiAowg8ERfiBoAg/EBS37m6Ahx9+OFnfuHFjgzop3u7du5P1gYGBZL0s8+fPL7uF0rHnB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGguM7fAP39/bnqGNu8efNqnrevr6/ATiYm9vxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EFTVIboLXdkkHaIb9dHWlh7kaceOHcn64GDFgaR09dVXJ+fdu3dvst7MihyiG8AkRPiBoAg/EBThB4Ii/EBQhB8IivADQVX9Pr+ZrZf0LUmD7n5ZNm21pO9Kei972QPuvqVeTWJyOuOMM5L1F154IVmfMiW97zp06FDF2kS+jl+U8ez5fypprMHM/9PdF2Q/BB+YYKqG391fkXSkAb0AaKA85/x3m9kuM1tvZrMK6whAQ9Qa/h9L+oqkBZIGJP2w0gvNrNPMesysp8Z1AaiDmsLv7ofd/TN3H5L0E0lXJl7b5e5t7p7+lgaAhqop/GbWOurptyWlh2oF0HTGc6nvSUnfkPRFMzsg6d8kfcPMFkhySf2SVtaxRwB1UDX87r5ijMnr6tALJqGZM2dWrHV3dyfnPeecc5L1oaGhZH3z5s3JenR8wg8IivADQRF+ICjCDwRF+IGgCD8QFEN0o66WL19esXbjjTfmWnZvb2+y/vjjj+da/mTHnh8IivADQRF+ICjCDwRF+IGgCD8QFOEHguI6P5Kq3V573br0t7vb29trXndfX1+y/uCDDybrAwMDNa87Avb8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxCUuXvjVmbWuJU1UEdHR7K+dOnSZP3aa69N1uv5b9TTkx5FbdGiRcl6tdtr53HxxRcn6/39/XVb90Tm7jae17HnB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgql7nN7PZkjZIapHkkrrc/UdmdrakX0iaI6lf0i3ufrTKsibsdf7FixdXrL344ou5lj1lSvpvcLWhqOspb2/Hjx+vWFu5cmVy3g0bNiTrGFuR1/lPSLrP3edL+pqk75nZfEmrJG1z90skbcueA5ggqobf3Qfc/fXs8UeS9ki6QNIySd3Zy7ol3VSvJgEU75TO+c1sjqSFknZIanH3kfskHdLwaQGACWLc9/
AzsxmSnpb0fXf/q9nfTyvc3Sudz5tZp6TOvI0CKNa49vxmNk3Dwf+5uz+TTT5sZq1ZvVXS4FjzunuXu7e5e1sRDQMoRtXw2/Aufp2kPe6+dlRpk6SRr7N1SHq++PYA1Mt4Dvu/Luk2SW+a2ciYyA9IWiPpv83sTkn7JN1Snxabw5IlSyrW8n7lttrlskZ+7fpkeXtLfWV4+/bttbSEglQNv7v/TlKl64bXFdsOgEbhE35AUIQfCIrwA0ERfiAowg8ERfiBoLh1d2bmzJnJ+ssvv1yxtnDhwlzrHv1R6bGUeZ2/nr1VG0L78ssvT9aPHk1+gzwsbt0NIInwA0ERfiAowg8ERfiBoAg/EBThB4Ia9228JruLLrooWc97Lb8smzdvTtZfeumlZL3adf577rknWZ87d27FWmtra3LeadOmJevIhz0/EBThB4Ii/EBQhB8IivADQRF+ICjCDwTFdf7Me++9l6w/9dRTFWs333xzrnUfO3YsWX/ooYeS9e7u7oq1I0eOJOc9ceJEsl5NtWG0Tz/99Iq1Sy+9NDlvte2CfNjzA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQVe/bb2azJW2Q1CLJJXW5+4/MbLWk70oauUD+gLtvqbKspr1vPzBZjPe+/eMJf6ukVnd/3cxmSnpN0k2SbpH0N3d/bLxNEX6g/sYb/qqf8HP3AUkD2eOPzGyPpAvytQegbKd0zm9mcyQtlLQjm3S3me0ys/VmNqvCPJ1m1mNmPbk6BVCocY/VZ2YzJP1G0r+7+zNm1iLpfQ2/D/Cwhk8NvlNlGRz2A3VW2Dm/JJnZNEmbJf3K3deOUZ8jabO7X1ZlOYQfqLPCBuq04du3rpO0Z3TwszcCR3xb0u5TbRJAecbzbv9Vkn4r6U1JQ9nkByStkLRAw4f9/ZJWZm8OppbFnh+os0IP+4tC+IH6K+ywH8DkRPiBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiq0UN0vy9p36jnX8ymNaNm7a1Z+5LorVZF9nbheF/Y0O/zf27lZj3u3lZaAwnN2luz9iXRW63K6o3DfiAowg8EVXb4u0pef0qz9tasfUn0VqtSeiv1nB9Aecre8wMoSSnhN7MlZvYnM9trZqvK6KESM+s3szfNrLfsIcayYdAGzWz3qGlnm9lWM3s7+z3mMGkl9bbazA5m267XzNpL6m22mf2PmfWZ2Vtm9i/Z9FK3XaKvUrZbww/7zWyqpD9Lul7SAUk7Ja1w976GNlKBmfVLanP30q8Jm9k1kv4macPIaEhm9h+Sjrj7muwP5yx3/0GT9LZapzhyc516qzSy9B0qcdsVOeJ1EcrY818paa+7/8Xdj0vaKGlZCX00PXd/RdKRkyYvk9SdPe7W8P88DVeht6bg7gPu/nr2+CNJIyNLl7rtEn2VoozwXyBp/6jnB9RcQ367pF+b2Wtm1ll2M2NoGTUy0iFJLWU2M4aqIzc30kkjSzfNtqtlxOui8Ybf513l7l+VtFTS97LD26bkw+dszXS55seSvqLhYdwGJP2wzGaykaWflvR9d//r6FqZ226MvkrZbmWE/6Ck2aOefzmb1hTc/WD2e1DSsxo+TWkmh0cGSc1+D5bcz/9z98Pu/pm7D0n6iUrcdtnI0k9L+rm7P5NNLn3bjdVXWdutjPDvlHSJmc01sy9IWi5pUwl9fI6ZnZm9ESMzO1PSYjXf6MObJHVkjzskPV9iL/+gWUZurjSytEredk034rW7N/xHUruG3/H/X0n/WkYPFfq6SNIfsp+3yu5N0pMaPgz8VMPvjdwp6RxJ2yS9LeklSWc3UW8/0/Bozrs0HLTWknq7SsOH9Lsk9WY/7WVvu0RfpWw3PuEHBMUbfkBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgvo/kd65TwXQV2UAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "prediction is 7, label is 7, matched: True\n", "prediction is 0, label is 0, matched: True\n", "prediction is 2, label is 2, matched: True\n", "prediction is 9, label is 9, matched: True\n", "prediction is 6, label is 6, matched: True\n", "prediction is 0, label is 0, matched: True\n", "prediction is 1, label is 1, matched: True\n", "prediction is 6, label is 6, matched: True\n", "prediction is 7, label is 7, matched: True\n", "prediction is 1, label is 1, matched: True\n", "prediction is 9, label is 9, matched: True\n", "prediction is 7, label is 7, matched: True\n", "prediction is 6, label is 6, matched: True\n", "prediction is 5, label is 5, matched: True\n", "prediction is 5, label is 5, matched: True\n", "prediction is 8, label is 8, matched: True\n", "prediction is 8, label is 8, matched: True\n", "prediction is 3, label is 3, matched: True\n", "prediction is 4, label is 4, matched: True\n", "prediction is 4, label is 4, matched: True\n", "prediction is 8, label is 8, matched: True\n", "prediction is 7, label is 7, matched: True\n", "prediction is 3, label is 3, matched: True\n", "prediction is 6, label is 6, matched: True\n", "prediction is 4, label is 4, matched: True\n", "prediction is 6, label is 6, matched: True\n", "prediction is 6, label is 6, matched: True\n", "prediction is 3, label is 3, matched: True\n", "prediction is 1, label is 8, matched: False\n" ] }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAADUFJREFUeJzt3W+IXXV+x/HPp+lGMNkHscEQjZhtiMoSMGlGKVTLltYlhmDcJ0HRGlE6q2ywC0X8UzCBUpRitvhAFrJsSLZs3W00Yly0u9tQ1i2WaCJpYtRdp3HiZohJo6trfBKj3z6Yk3bUOb87uf/Onfm+XzDMved7zzlfLvnknHt/58zPESEA+fxe0w0AaAbhB5Ii/EBShB9IivADSRF+ICnCDyRF+IGkCD+Q1O/3c2e2uZwQ6LGI8FRe19GR3/Yq27+yPWL7/k62BaC/3O61/bZnSfq1pOskHZX0sqSbI+K1wjoc+YEe68eR/2pJIxFxOCJOS/qRpLUdbA9AH3US/osl/WbC86PVss+wPWx7r+29HewLQJf1/Au/iNgiaYvEaT8wSDo58o9JumTC80XVMgDTQCfhf1nSUttfsT1b0k2SdnWnLQC91vZpf0Scsb1B0k8lzZK0NSIOda0zAD3V9lBfWzvjMz/Qc325yAfA9EX4gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBShB9IivADSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFJ9naIb/bd58+Zi/bbbbivWDxw40NH6Y2PM4zKoOPIDSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFIdzdJre1TSh5I+kXQmIoZavJ5Zenvglltuqa1t27atuO6sWbM62ner7Q8PD9fWzpw509G+MbmpztLbjYt8/iwiTnZhOwD6iNN+IKlOwx+SfmZ7n+368zsAA6fT0/5rImLM9oWSfm77jYh4YeILqv8U+I8BGDAdHfkjYqz6fULS05KunuQ1WyJiqNWXgQD6q+3w255j+8tnH0v6uqRXu9UYgN7q5LR/gaSnbZ/dzj9HxL92pSsAPdfROP8574xx/p546623amuXXnppHzv5ojvuuKO21uoaAbRnquP8DPUBSRF+ICnCDyRF+IGkCD+QFOEHkmKobwZYuXJlbe3JJ58srnv69OlifenSpW31dNZll11WWxsZGelo25gcQ30Aigg/kBThB5Ii/EBShB9IivADSRF+ICmm6J4B9u3bV1u78sor2153KlpdR3DkyJGOto/e4cgPJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0kxzj8NLF++vFg/dOhQbe3RRx8trrtkyZJivZqXodauXbuK9Y8//rhYR3M48gNJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUi3H+W1vlbRG0omIWFYtu0DSjyUtljQqaV1E/LZ3beZ2++23F+urV6+urc2fP7+jfT///PPF+o4dOzraPpozlSP/NkmrPrfsfkm7I2KppN3VcwDTSMvwR8QLkt773OK1krZXj7dLurHLfQHosXY/8y+IiGPV43ckLehSPwD6pONr+yMiSnPw2R6WNNzpfgB0V7tH/uO2F0pS9ftE3QsjYktEDEXEUJv7AtAD7YZ/l6T11eP1kp7pTjsA+qVl+G0/Iek/JV1u+6jtOyU9Iuk6229K+ovqOYBpxBG1H9e7v7PCdwOoN2fOnGL9ueeeq61de+21xXU/+uijYn3FihXF+sjISLGO/ouI8h9hqHCFH5AU4QeSIvxAUoQfSIrwA0kRfiAp/nT3NNBqOG7//v21tVZDfeedd16xvmjRomKdob7piyM/kBThB5Ii/EBShB9IivADSRF+ICnCDyTFLb0zwIEDB2pry5Yt62jbY2NjxXqr6whGR0c72j/OHbf0Aigi/EBShB9IivADSRF+ICnCDyRF+IGkGOefAd59993a2rx584rrtpr++6abbirWW/2
tgXvvvbe2xjUAvcE4P4Aiwg8kRfiBpAg/kBThB5Ii/EBShB9IquU4v+2tktZIOhERy6plmyT9laT/qV72YETUzxP9/9tinL8HTp48WVs7//zzi+teddVVxXqrv8u/cePGYn3dunW1tTVr1hTXfeONN4p1TK6b4/zbJK2aZPk/RsTy6qdl8AEMlpbhj4gXJL3Xh14A9FEnn/k32D5ge6vt8jWkAAZOu+H/rqQlkpZLOiZpc90LbQ/b3mt7b5v7AtADbYU/Io5HxCcR8amk70m6uvDaLRExFBFD7TYJoPvaCr/thROefkPSq91pB0C/tJyi2/YTkr4mab7to5I2Svqa7eWSQtKopG/2sEcAPcD9/DNAaZy/dK+/JF1++eXdbuczHnvssdraDTfcUFz3+uuvL9a5DmBy3M8PoIjwA0kRfiApwg8kRfiBpAg/kBRDfTPAIA/1zZ07t7a2c+fO4rqHDx8u1u+5555i/fTp08X6TMVQH4Aiwg8kRfiBpAg/kBThB5Ii/EBShB9IquX9/EAnTp06VVt78cUXi+s+9NBDxfoHH3xQrN93333FenYc+YGkCD+QFOEHkiL8QFKEH0iK8ANJEX4gKcb5Z4DSePkVV1zRx066yy7fln7RRRf1qZOZiSM/kBThB5Ii/EBShB9IivADSRF+ICnCDyTVcpzf9iWSfiBpgaSQtCUiHrN9gaQfS1osaVTSuoj4be9aRZ0NGzbU1vbs2VNcd+XKlcX6vn372urprMWLF9fWbr311uK6/ZxTIqOpHPnPSPqbiPiqpD+W9C3bX5V0v6TdEbFU0u7qOYBpomX4I+JYRLxSPf5Q0uuSLpa0VtL26mXbJd3YqyYBdN85fea3vVjSCkl7JC2IiGNV6R2NfywAME1M+dp+23MlPSXp2xHxu4nXXUdE1M3DZ3tY0nCnjQLorikd+W1/SePB/2FEnJ1d8bjthVV9oaQTk60bEVsiYigihrrRMIDuaBl+jx/ivy/p9Yj4zoTSLknrq8frJT3T/fYA9MpUTvv/RNJfSjpoe3+17EFJj0j6F9t3SjoiaV1vWkQrb7/9dm3t2WefLa67adOmYr3VFN+jo6PF+l133VVbu/DCC4vrtnLw4MGO1s+uZfgj4j8k1d1Y/efdbQdAv3CFH5AU4QeSIvxAUoQfSIrwA0kRfiAp9/O2ybpLgNGcBx54oFjfuHFjsT579uxutvMZL730UrG+atWqYv3999/vZjvTRkSU/+Z5hSM/kBThB5Ii/EBShB9IivADSRF+ICnCDyTFFN3JPfzww8X6mTNnivW77767WJ87d25tbceOHcV1H3/88WI96zh+t3DkB5Ii/EBShB9IivADSRF+ICnCDyRF+IGkuJ8fmGG4nx9AEeEHkiL8QFKEH0iK8ANJEX4gKcIPJNUy/LYvsf3vtl+zfcj2X1fLN9kes72/+lnd+3YBdEvLi3xsL5S0MCJesf1lSfsk3ShpnaRTEfHolHfGRT5Az031Ip+Wf8knIo5JOlY9/tD265Iu7qw9AE07p8/8thdLWiFpT7Vog+0DtrfanlezzrDtvbb3dtQpgK6a8rX9tudK+oWkv4+InbYXSDopKST9ncY/GtzRYhuc9gM9NtXT/imF3/aXJP1E0k8j4juT1BdL+klELGuxHcIP9FjXbuyxbUnfl/T6xOBXXwSe9Q1Jr55rkwCaM5Vv+6+R9EtJByV9Wi1+UNLNkpZr/LR/VNI3qy8HS9viyA/0WFdP+7uF8AO9x/38AIoIP5AU4QeSIvxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBShB9IivADSbX8A55ddlLSkQnP51fLBtGg9jaofUn01q5u9nbpVF/Y1/v5v7Bze29EDDXWQMGg9jaofUn01q6meuO0H0iK8ANJNR3+LQ3vv2RQexvUviR6a1cjvTX6mR9Ac5o+8gNoSCPht73K9q9sj9i+v4ke6tgetX2wmnm40SnGqmnQTth+dcKyC2z
/3Pab1e9Jp0lrqLeBmLm5MLN0o+/doM143ffTftuzJP1a0nWSjkp6WdLNEfFaXxupYXtU0lBEND4mbPtPJZ2S9IOzsyHZ/gdJ70XEI9V/nPMi4r4B6W2TznHm5h71Vjez9O1q8L3r5ozX3dDEkf9qSSMRcTgiTkv6kaS1DfQx8CLiBUnvfW7xWknbq8fbNf6Pp+9qehsIEXEsIl6pHn8o6ezM0o2+d4W+GtFE+C+W9JsJz49qsKb8Dkk/s73P9nDTzUxiwYSZkd6RtKDJZibRcubmfvrczNID8961M+N1t/GF3xddExF/JOl6Sd+qTm8HUox/Zhuk4ZrvSlqi8Wncjkna3GQz1czST0n6dkT8bmKtyfdukr4aed+aCP+YpEsmPF9ULRsIETFW/T4h6WmNf0wZJMfPTpJa/T7RcD//JyKOR8QnEfGppO+pwfeumln6KUk/jIid1eLG37vJ+mrqfWsi/C9LWmr7K7ZnS7pJ0q4G+vgC23OqL2Jke46kr2vwZh/eJWl99Xi9pGca7OUzBmXm5rqZpdXwezdwM15HRN9/JK3W+Df+/y3pb5vooaavP5T0X9XPoaZ7k/SExk8DP9b4dyN3SvoDSbslvSnp3yRdMEC9/ZPGZ3M+oPGgLWyot2s0fkp/QNL+6md10+9doa9G3jeu8AOS4gs/ICnCDyRF+IGkCD+QFOEHkiL8QFKEH0iK8ANJ/S8qhG2k004odAAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "prediction is 8, label is 8, matched: True\n", "prediction is 9, label is 9, matched: True\n", "prediction is 9, label is 9, matched: True\n", "prediction is 4, label is 4, matched: True\n", "prediction is 4, label is 4, matched: True\n", "prediction is 0, label is 0, matched: True\n", "prediction is 7, label is 7, matched: True\n", "prediction is 8, label is 8, matched: True\n", "prediction is 1, label is 1, matched: True\n", "prediction is 5, label is 0, matched: False\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAADlBJREFUeJzt3X+MFHWax/HPI8tOEFBRcmR0ddjboBHRwDExp2cI5wniuBFJjIKJzkXibAwaNzmTI56JPy4Xzcli/EM3ghDYc8/dU9lINsZdjvgzXDaMhkPE24XbDAEyMho3LkbjnvDcH1NzN+LUt5ru6q4envcrmUx3P13VDx0+U9X9raqvubsAxHNa1Q0AqAbhB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgCD8Q1Lda+WJmxuGEQJO5u9XyvIa2/Ga2xMx+a2b7zWx1I+sC0FpW77H9ZjZB0u8kLZJ0SNJOSSvcfW9iGbb8QJO1Yst/uaT97v57d/+TpJ9JWtrA+gC0UCPhP0/SwVH3D2WPfY2Z9ZlZv5n1N/BaAErW9C/83H2dpHUSu/1AO2lky39Y0vmj7n8newzAONBI+HdKmmVm3zWzb0taLmlrOW0BaLa6d/vd/Sszu1vSryRNkLTR3d8vrTMATVX3UF9dL8ZnfqDpWnKQD4Dxi/ADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGg6p6iW5LMbEDSUUnHJH3l7t1lNAWg+RoKf+av3f3jEtYDoIXY7QeCajT8LunXZvaOmfWV0RCA1mh0t/8qdz9sZn8maZuZ/Ze7vzn6CdkfBf4wAG3G3L2cFZk9JOkzd1+TeE45LwYgl7tbLc+re7ffzCab2dSR25IWS9pT7/oAtFYju/0zJP3CzEbW86/u/mopXQFoutJ2+2t6MXb7W+7CCy9M1idNmpSsDw4OJutDQ0Mn3ROaq+m7/QDGN8IPBEX4gaAIPxAU4QeCIvxAUGWc1YcmmzNnTrJ+xRVX5NYef/zx5LJTpkxJ1nfu3Fn3a6O9seUHgiL8QFCEHwiK8ANBEX4gKMIPBEX4gaAY528D8+fPT9afeOKJZP3KK68ss52v6erqatq6US22/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFJfuboHu7vTM5a+99lqyXnR57TVrcidJ0tq1a5PL7t69O1l/+umnk/Wi8/0vueSS3NrmzZuTy3700UfJOsbGpbsBJBF+ICjCDwRF+IGgCD8QFOEHgiL8QFCF4/xmtl
HS9yUNufuc7LGzJf1c0kxJA5Judvc/FL7YKTrOX3Q+/uuvv56sF107/+23307Wr7nmmtzatGnTkstu2LAhWe/p6UnWjx8/nqw30+eff56s33DDDbm1HTt2JJf98ssv6+qpHZQ5zr9J0pITHlstabu7z5K0PbsPYBwpDL+7vynpkxMeXipp5PCszZJuLLkvAE1W72f+Ge4+mN3+UNKMkvoB0CINX8PP3T31Wd7M+iT1Nfo6AMpV75b/iJl1SlL2eyjvie6+zt273T19dguAlqo3/Fsl9Wa3eyW9XE47AFqlMPxm9ryk/5B0kZkdMrOVkh6TtMjM9km6JrsPYBzhfP4aTZ06NbdWdF56arxZko4dO5as33rrrcn69OnTc2tPPfVUctkiZukh41b+/zlRI71t3Lgxuey9996brH/xxRfJepU4nx9AEuEHgiL8QFCEHwiK8ANBEX4gKKbortHy5ctza0VDeUUWL16crJ955pnJeiPDeW+88Uay3t/fn6wvW7YsWV+1alVubfbs2clllyw58WTSr+vo6EjWFyxYkFu74447kssWueeee5L18XBKMFt+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiKU3ozRZe4Tl1+OzUNtSTt3bs3Wb/sssuS9aJThq+++urc2u23355c9q233krWi06b7erqStb379+frDdi4sSJyXpqnH/Tpk3JZTs7O5P1RYsWJetF0643E6f0Akgi/EBQhB8IivADQRF+ICjCDwRF+IGgGOfPrFy5Mll/5pln6l73ihUrkvUXXnghWU9dS0BKT/H97LPPJpeNau7cucn6q6++mqyffvrpyfoZZ5xx0j2VhXF+AEmEHwiK8ANBEX4gKMIPBEX4gaAIPxBU4Ti/mW2U9H1JQ+4+J3vsIUl3Svooe9r97v5K4YtVOM6fmmJbknbs2JGsp64x/9xzzyWX7e3tTdbRfnbt2pWsX3rppcn6hAkTymznpJQ5zr9J0lizJzzh7nOzn8LgA2gvheF39zclfdKCXgC0UCOf+e82s91mttHM0tfAAtB26g3/jyV9T9JcSYOSfpT3RDPrM7N+M0tP+gagpeoKv7sfcfdj7n5c0npJlyeeu87du929u94mAZSvrvCb2ehLmy6TtKecdgC0SuEU3Wb2vKSFkqab2SFJD0paaGZzJbmkAUk/aGKPAJqgMPzuPtbJ6Bua0EtTTZ48OVm/+OKLk/XU8RCvvMJI53hTdNxHR0dHst7K62A0C0f4AUERfiAowg8ERfiBoAg/EBThB4IqHOo7Vdx0000NLb9ly5bcGkN948+aNWuS9VmzZrWok+qw5QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoMKM81977bUNLb9+/frc2tGjRxtaN5pj2bJlubVbbrmloXWvXr26oeXbAVt+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwgqzDh/T09Psn4qXIo5mgcffDBZv++++3JrkyZNSi774osvJutr165N1scDtvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EJQVjW+b2fmSfiJphiSXtM7dnzSzsyX9XNJMSQOSbnb3PxSsq7LB9KJ/565du5L1efPmldkOJC1YsCBZf/jhh5P1hQsXJuvHjx/PrT3wwAPJZR999NFkvZ25u9XyvFq2/F9J+jt3ny3pLyWtMrPZklZL2u7usyRtz+4DGCcKw+/ug+7+bnb7qKQPJJ0naamkzdnTNku6sVlNAijfSX3mN7OZkuZJ+o2kGe4+mJU+1PDHAgDjRM3H9pvZFEkvSfqhu//R7P8/Vri7532eN7M+SX2NNgqgXDVt+c1sooaD/1N3H5mx8oiZdWb1TklDYy3r7uvcvdvdu8toGEA5CsNvw5v4DZI+cPfRpzJtldSb3e6V9HL57QFollp2+/9K0m2S3jOzkfGw+yU9JunfzGylpAOSbm5Oi+U4cuRIst7R0ZGsT506Nbd2Kl+6+9xzz03WZ8+enazfdtttub
Xrr78+uexZZ52VrO/bty9Zf+SRR3Jr27dvTy4bQWH43f1tSXnjhn9TbjsAWoUj/ICgCD8QFOEHgiL8QFCEHwiK8ANBFZ7SW+qLVXhK76pVq5L1J598Mlm/7rrrcmvbtm2rq6eynHPOObm1+fPnJ5ctqt95553J+gUXXJCspwwODibre/bsSdbvuuuuZH1gYOBkWzollHlKL4BTEOEHgiL8QFCEHwiK8ANBEX4gKMIPBBVmnH/ixInJ+kUXXZSsHzhwILdW9fn8qWsNdHV1NbTu005Lbx9Sl8cu8umnnybrBw8erHvdkTHODyCJ8ANBEX4gKMIPBEX4gaAIPxAU4QeCCjPOD0TBOD+AJMIPBEX4gaAIPxAU4QeCIvxAUIQfCKow/GZ2vpm9ZmZ7zex9M7s3e/whMztsZruyn57mtwugLIUH+ZhZp6ROd3/XzKZKekfSjZJulvSZu6+p+cU4yAdouloP8vlWDSsalDSY3T5qZh9IOq+x9gBU7aQ+85vZTEnzJP0me+huM9ttZhvNbFrOMn1m1m9m/Q11CqBUNR/bb2ZTJL0h6Z/cfYuZzZD0sSSX9I8a/mhwR8E62O0HmqzW3f6awm9mEyX9UtKv3H3tGPWZkn7p7nMK1kP4gSYr7cQeMzNJGyR9MDr42ReBI5ZJSk+pCqCt1PJt/1WS3pL0nqSR6zTfL2mFpLka3u0fkPSD7MvB1LrY8gNNVupuf1kIP9B8nM8PIInwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QVOEFPEv2saQDo+5Pzx5rR+3aW7v2JdFbvcrsravWJ7b0fP5vvLhZv7t3V9ZAQrv21q59SfRWr6p6Y7cfCIrwA0FVHf51Fb9+Srv21q59SfRWr0p6q/QzP4DqVL3lB1CRSsJvZkvM7Ldmtt/MVlfRQx4zGzCz97KZhyudYiybBm3IzPaMeuxsM9tmZvuy32NOk1ZRb20xc3NiZulK37t2m/G65bv9ZjZB0u8kLZJ0SNJOSSvcfW9LG8lhZgOSut298jFhM1sg6TNJPxmZDcnM/lnSJ+7+WPaHc5q7/32b9PaQTnLm5ib1ljez9N+qwveuzBmvy1DFlv9ySfvd/ffu/idJP5O0tII+2p67vynpkxMeXippc3Z7s4b/87RcTm9twd0H3f3d7PZRSSMzS1f63iX6qkQV4T9P0sFR9w+pvab8dkm/NrN3zKyv6mbGMGPUzEgfSppRZTNjKJy5uZVOmFm6bd67ema8Lhtf+H3TVe7+F5Kuk7Qq271tSz78ma2dhmt+LOl7Gp7GbVDSj6psJptZ+iVJP3T3P46uVfnejdFXJe9bFeE/LOn8Ufe/kz3WFtz9cPZ7SNIvNPwxpZ0cGZkkNfs9VHE//8fdj7j7MXc/Lmm9KnzvspmlX5L0U3ffkj1c+Xs3Vl9VvW9VhH+npFlm9l0z+7ak5ZK2VtDHN5jZ5OyLGJnZZEmL1X6zD2+V1Jvd7pX0coW9fE27zNycN7O0Kn7v2m7Ga3dv+Y+kHg1/4//fkv6hih5y+vpzSf+Z/bxfdW+SntfwbuD/aPi7kZWSzpG0XdI+Sf8u6ew26u1fNDyb824NB62zot6u0vAu/W5Ju7Kfnqrfu0RflbxvHOEHBMUXfkBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgvpfkSSOOSvR22EAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "prediction is 0, label is 0, matched: True\n", "prediction is 1, label is 1, matched: True\n", "prediction is 8, label is 8, matched: True\n", "prediction is 5, label is 5, matched: True\n", "prediction is 7, label is 7, matched: True\n", "prediction is 1, label is 1, matched: True\n", "prediction is 7, label is 7, matched: True\n", "prediction is 5, label is 5, matched: True\n", "prediction is 5, label is 5, matched: True\n", "prediction is 9, label is 9, matched: True\n", "prediction is 9, label is 9, matched: True\n", "prediction is 4, label is 4, matched: True\n", "prediction is 2, label is 2, matched: True\n", "prediction is 5, label is 5, matched: True\n", "prediction is 3, label is 3, matched: True\n", "prediction is 0, label is 7, matched: False\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAADOpJREFUeJzt3W+oHXV+x/HPx5gguCFoF0PIpmaNUVj/plykQihb2yxpXI15EiIIt3Tp3QcrNNAH/ntQiRSWupvSB7KQxZBsSd1VYkxca3e3odQtyGoMqfmjiTZk2YSYq0ZcRWRN8u2DO2nv6j2/uTlnzplz832/4HLPme+ZmS/D/dyZOXPm/BwRApDPJW03AKAdhB9IivADSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFKXDnJltvk4IdBnEeHpvK6nPb/tlbYP237b9oO9LAvAYLnbz/bbniXpiKQVko5LelXSvRFxqDAPe36gzwax579N0tsRcTQififpx5JW97A8AAPUS/gXSvrNpOfHq2m/x/aY7T229/SwLgAN6/sbfhGxSdImicN+YJj0suc/IWnRpOdfqaYBmAF6Cf+rkpba/qrtOZLWSdrVTFsA+q3rw/6IOGP7fkk/kzRL0uaIONhYZwD6qutLfV2tjHN+oO8G8iEfADMX4QeSIvxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBShB9IivADSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0l1PUS3JNk+JukjSWclnYmIkSaaAtB/PYW/8qcR8V4DywEwQBz2A0n1Gv6Q9HPbr9kea6IhAIPR62H/8og4YfsqSb+w/WZEvDT5BdU/Bf4xAEPGEdHMguxHJX0cEd8rvKaZlQHoKCI8ndd1fdhv+3Lbc88/lvQNSQe6XR6AwerlsH++pB22zy/nXyLi3xrpCkDfNXbYP62VcdgP9F3fD/sBzGyEH0iK8ANJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiApwg
8kRfiBpAg/kBThB5Ii/EBShB9IivADSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFKEH0iK8ANJEX4gKcIPJEX4gaQurXuB7c2SvilpPCJurKZdKeknkhZLOiZpbUR80L82Z7Y5c+YU69dee21Py7/ppps61pYuXVqc96677irWZ8+e3VVPTThy5EixfuDAgWL9xIkTHWv79+8vzrtv375i/cyZM8X6TDCdPf8WSSs/N+1BSbsjYqmk3dVzADNIbfgj4iVJpz83ebWkrdXjrZLuabgvAH3W7Tn//Ig4WT1+R9L8hvoBMCC15/x1IiJsR6e67TFJY72uB0Czut3zn7K9QJKq3+OdXhgRmyJiJCJGulwXgD7oNvy7JI1Wj0cl7WymHQCDUht+209JelnS9baP2/6WpO9KWmH7LUl/Xj0HMIM4ouPpevMrK7w3MOxGRjqftaxbt6447+rVq4v1JUuWFOu2i/XS9e5XXnmlOO/Ro0eL9eeee65YP3v2bLFectlllxXrd955Z9fLlqR58+Z1rN19993Fed99991i/emnny7Wn3jiiWL93LlzxXovIqL8B1PhE35AUoQfSIrwA0kRfiApwg8kRfiBpLjUV1mxYkWx/sgjj3SsLVu2rDjvhg0bivU333yzWD948GCxXrp19bPPPivOm1XdZcabb765WN++fXux/vLLLxfra9euLdZ7waU+AEWEH0iK8ANJEX4gKcIPJEX4gaQIP5BUz1/jdbF47LHHivWdOzt/X8kNN9xQnPfDDz8s1l944YViHc379NNPi/W6W6HXr19frD/zzDMX3NOgsecHkiL8QFKEH0iK8ANJEX4gKcIPJEX4gaS4zl95/vnni/VFixZ1rI2NlUcje/zxx4v1999/v1jfsWNHsY7Be+CBB4r1mfDZDfb8QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5BU7ff2294s6ZuSxiPixmrao5L+WtL5cYwfjoh/rV3ZEH9v/3XXXVesHzp0qGNtdHS0OO/ChQuL9TvuuKNYX7lyZbGO5tUNwb18+fJi/ZprrinW675PoBdNfm//FklT/fX9Y0TcWv3UBh/AcKkNf0S8JOn0AHoBMEC9nPPfb/t125ttX9FYRwAGotvw/0DSEkm3Sjop6fudXmh7zPYe23u6XBeAPugq/BFxKiLORsQ5ST+UdFvhtZsiYiQiRrptEkDzugq/7QWTnq6RdKCZdgAMSu0tvbafkvR1SV+2fVzS30n6uu1bJYWkY5K+3cceAfRBbfgj4t4pJj/Zh15adeTIkWJ91apVHWt1927v3bu3WB8fHy/W0Z1LLul8YLtx48bivGvWrCnWb7nllmK9n9fxm8In/ICkCD+QFOEHkiL8QFKEH0iK8ANJ1d7S2+jKhviW3l5cffXVxXrd7Z/btm1rsp00rrrqqmJ9y5YtHWvXX399cd7SpV1JOnz4cLHepiZv6QVwESL8QFKEH0iK8ANJEX4gKcIPJEX4gaS4zo+hVXcd/8UXXyzW586d27F2++23F+etGzZ9mHGdH0AR4QeSIvxAUoQfSIrwA0kRfiApwg8kVfvV3UC/9Hod/4MPPijWS/fkz+Tr+E1hzw8kRfiBpAg/kBThB5Ii/EBShB9IivADSdXez297kaQfSZovKSRtioh/sn2lpJ9IWizpmKS1EVG88Mr9/PnMmjWrY61uaPNPPvmkWL/vvvt6mv9i1eT9/Gck/W1EfE3SH0v6ju2vSXpQ0u6IWCppd/UcwAxRG/6IOBkRe6vHH0l6Q9JCSaslba1etlXSPf1qEkDzLuic3/ZiScsk/UrS/Ig4WZXe0cRpAYAZYtqf7bf9JUnbJa2PiN/a/39aERHR6Xze9piksV4bBdCsae35bc/WRPC3RcSz1eRTthdU9QWSxqeaNyI2RcRIRIw00TCAZtSG3xO7+CclvRERGyeVdkkarR6PStrZfHsA+mU6l/qWS/qlpP2SzlWTH9bEef/Tkv5Q0q81canvdM2yuN
SXzEMPPdSxNm/evOK8GzZsKNazXsqrM91LfbXn/BHxX5I6LezPLqQpAMODT/gBSRF+ICnCDyRF+IGkCD+QFOEHkmKIbuAiwxDdAIoIP5AU4QeSIvxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBShB9IivADSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFKEH0iK8ANJEX4gqdrw215k+z9sH7J90PbfVNMftX3C9r7qZ1X/2wXQlNpBO2wvkLQgIvbanivpNUn3SFor6eOI+N60V8agHUDfTXfQjkunsaCTkk5Wjz+y/Yakhb21B6BtF3TOb3uxpGWSflVNut/267Y3276iwzxjtvfY3tNTpwAaNe2x+mx/SdJ/Svr7iHjW9nxJ70kKSY9p4tTgr2qWwWE/0GfTPeyfVvhtz5b0U0k/i4iNU9QXS/ppRNxYsxzCD/RZYwN12rakJyW9MTn41RuB562RdOBCmwTQnum8279c0i8l7Zd0rpr8sKR7Jd2qicP+Y5K+Xb05WFoWe36gzxo97G8K4Qf6r7HDfgAXJ8IPJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBStV/g2bD3JP160vMvV9OG0bD2Nqx9SfTWrSZ7u3q6Lxzo/fxfWLm9JyJGWmugYFh7G9a+JHrrVlu9cdgPJEX4gaTaDv+mltdfMqy9DWtfEr11q5XeWj3nB9Cetvf8AFrSSvhtr7R92Pbbth9so4dObB+zvb8aebjVIcaqYdDGbR+YNO1K27+w/Vb1e8ph0lrqbShGbi6MLN3qthu2Ea8Hfthve5akI5JWSDou6VVJ90bEoYE20oHtY5JGIqL1a8K2/0TSx5J+dH40JNv/IOl0RHy3+sd5RUQ8MCS9PaoLHLm5T711Gln6L9XitmtyxOsmtLHnv03S2xFxNCJ+J+nHkla30MfQi4iXJJ3+3OTVkrZWj7dq4o9n4Dr0NhQi4mRE7K0efyTp/MjSrW67Ql+taCP8CyX9ZtLz4xquIb9D0s9tv2Z7rO1mpjB/0shI70ia32YzU6gduXmQPjey9NBsu25GvG4ab/h90fKI+CNJfyHpO9Xh7VCKiXO2Ybpc8wNJSzQxjNtJSd9vs5lqZOntktZHxG8n19rcdlP01cp2ayP8JyQtmvT8K9W0oRARJ6rf45J2aOI0ZZicOj9IavV7vOV+/k9EnIqIsxFxTtIP1eK2q0aW3i5pW0Q8W01ufdtN1Vdb262N8L8qaantr9qeI2mdpF0t9PEFti+v3oiR7cslfUPDN/rwLkmj1eNRSTtb7OX3DMvIzZ1GllbL227oRryOiIH/SFqliXf8/0fSI2300KGvayT9d/VzsO3eJD2licPAzzTx3si3JP2BpN2S3pL075KuHKLe/lkTozm/romgLWipt+WaOKR/XdK+6mdV29uu0Fcr241P+AFJ8YYfkBThB5Ii/EBShB9IivADSRF+ICnCDyRF+IGk/hdh8D6jhPHYCwAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "prediction is 4, label is 4, matched: True\n", "prediction is 6, label is 6, matched: True\n", "prediction is 6, label is 6, matched: True\n", "prediction is 0, label is 0, matched: True\n", "prediction is 1, label is 1, matched: True\n", "prediction is 0, label is 0, matched: True\n", "prediction is 1, label is 1, matched: True\n", "prediction is 2, label is 2, matched: True\n", "prediction is 4, label is 4, matched: True\n", "prediction is 8, label is 8, matched: True\n", "prediction is 5, label is 5, matched: True\n", "prediction is 3, label is 3, matched: True\n", "prediction is 5, label is 5, matched: True\n", "prediction is 0, label is 0, matched: True\n", "prediction is 0, label is 0, matched: True\n", "prediction is 6, label is 6, matched: True\n", "prediction is 4, label is 4, matched: True\n", "prediction is 3, label is 3, matched: True\n", "prediction is 8, label is 8, matched: True\n", "prediction is 3, label is 3, matched: True\n", "prediction is 7, label is 7, matched: True\n", "prediction is 1, label is 1, matched: True\n", "prediction is 4, label is 4, matched: True\n", "prediction is 3, label is 3, matched: True\n", "prediction is 9, label is 9, matched: True\n", "prediction is 2, label is 2, matched: True\n", "prediction is 2, label is 2, matched: True\n", "prediction is 0, label is 0, matched: True\n", "prediction is 3, label is 3, matched: True\n", "prediction is 6, label is 6, matched: True\n", "prediction is 6, label is 6, matched: True\n" ] } ], "source": [ "import matplotlib.pyplot as plt\n", "# Send the first 100 training images to the endpoint in a single request\n", "predictions = predictor.predict(train_data[:100])\n", "for i in range(100):\n", "    prediction = predictions['predictions'][i]['classes']\n", "    label = train_labels[i]\n", "    print('prediction is {}, label is {}, matched: {}'.format(prediction, label, prediction == label))\n", "    if prediction != label:\n", "        # Display each misclassified image as a 28x28 picture\n", "        plot_data = train_data[i].reshape(28, 28)\n",
"        plt.gray()  # use a grayscale colormap; remove this line to see the default colors\n", "        plt.imshow(plot_data)\n", "        plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Delete the endpoint\n", "\n", "Finally, delete the endpoint we just created so that it stops incurring charges." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "%%time\n", "sagemaker.Session().delete_endpoint(predictor.endpoint)" ] } ], "metadata": { "kernelspec": { "display_name": "conda_tensorflow_p36", "language": "python", "name": "conda_tensorflow_p36" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" }, "notice": "Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License." }, "nbformat": 4, "nbformat_minor": 2 }