{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Word Pronunciation Example Using SageMaker (AWS-SDK) Seq2Seq\n", "\n", "1. [Introduction](#Introduction)\n", "2. [Setup](#Setup)\n", "3. [Download dataset and preprocess](#Download-dataset-and-preprocess)\n", "4. [Training the Word Pronunciation model](#Training-the-Word-Pronunciation-model)\n", "5. [Inference](#Inference)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Introduction\n", "\n", "Welcome to our Word Pronunciation end-to-end example! In this demo, we will train an English word-pronunciation model and test its predictions on a few examples.\n", "\n", "The SageMaker Seq2Seq algorithm is built on top of [Sockeye](https://github.com/awslabs/sockeye), a sequence-to-sequence framework for Neural Machine Translation based on MXNet. SageMaker Seq2Seq implements state-of-the-art encoder-decoder architectures, which can also be used for tasks such as abstractive summarization.\n", "\n", "SageMaker notebook instances already include a sample Seq2Seq notebook that builds an English-German machine translation model using data provided by [the Machine Translation Group at UEDIN](http://data.statmt.org/wmt17/translation-task/preprocessed/) (see sample-notebooks/introduction_to_amazon_algorithms/seq2seq_translation_en-de). In this example, we instead use the word-pronunciation dataset provided by [CMUSphinx](https://cmusphinx.github.io/).\n", "\n", "To get started, we need to set up the environment with a few prerequisite steps for permissions, configuration, and so on." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup\n", "\n", "Let's start by specifying:\n", "- The S3 bucket and prefix that you want to use for training and model data. **This should be within the same region as the Notebook Instance, training, and hosting.**\n", "- The IAM role ARN used to give training and hosting access to your data. 
See the documentation for how to create these. Note: if more than one role is required for notebook instances, training, and/or hosting, replace the `get_execution_role()` call below with the appropriate full IAM role ARN string(s)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true, "isConfigCell": true }, "outputs": [], "source": [ "# S3 bucket and prefix\n", "bucket = ''  # fill in the name of your S3 bucket\n", "prefix = 'seq2seq/word-pronunciation'\n", "# Data and model artifacts are stored under s3://<bucket>/seq2seq/word-pronunciation/ (no leading slash in the prefix)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import boto3\n", "import re\n", "from sagemaker import get_execution_role\n", "\n", "role = get_execution_role()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we'll import the Python libraries we'll need for the remainder of the exercise." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from time import gmtime, strftime\n", "import time\n", "import numpy as np\n", "import os\n", "import json\n", "import random\n", "\n", "# For plotting attention matrix later on\n", "import matplotlib\n", "%matplotlib inline\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Download dataset and preprocess\n", "\n", "Most of the preprocessing code is borrowed from \n", "https://github.com/sunilmallya/dl-twitch-series/blob/master/E2_word_pronounciations.ipynb" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this notebook, we will train a word-pronunciation model on a dataset from the\n", "[CMUdict -- Major Version: 0.07](http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b)." 
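] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Each CMUdict entry is a single line containing an upper-case word, two spaces, and the word's phonemes separated by single spaces. Vowel phonemes carry a stress digit (for example `AH0` or `OW1`). The illustrative sketch below parses two hand-written sample lines in that layout. The sample entries are typed in here for demonstration and are not read from the downloaded file." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Illustrative sample lines in the CMUdict 0.7b layout: WORD, two spaces, then PH1 PH2 ...\n", "sample_lines = ['HELLO  HH AH0 L OW1\\n', 'WORLD  W ER1 L D\\n']\n", "\n", "for line in sample_lines:\n", "    # Split the word from its pronunciation on the two-space separator\n", "    word, phone_str = line.strip('\\n').split('  ')\n", "    # Individual phonemes are separated by single spaces\n", "    print(word, '->', phone_str.split(' '))"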
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import urllib.request\n", "\n", "def download_data(url, force_download=True):\n", "    # Download url into the current directory (skipped if the file exists and force_download is False)\n", "    fname = url.split(\"/\")[-1]\n", "    if force_download or not os.path.exists(fname):\n", "        urllib.request.urlretrieve(url, fname)\n", "    return fname\n", "\n", "url_ds1 = \"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b\"\n", "fname = download_data(url_ds1)\n", "print(fname)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#!wget http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# Load data (read with ISO-8859-1 encoding)\n", "with open(fname, mode='rt', encoding=\"ISO-8859-1\") as fh:\n", "    data = fh.readlines()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# Generate words list and phones list\n", "words = []\n", "phones = []\n", "\n", "def f_char(word):\n", "    # True if the word contains any character we want to exclude\n", "    for c in [\"(\", \".\", \"'\", \")\", \"-\", \"_\", \"\\xc0\", \"\\xc9\", ';']:\n", "        if c in word:\n", "            return True\n", "    return False\n", "\n", "for d in data:\n", "    # Each entry is 'WORD  PH1 PH2 ...': the word and its phonemes are separated by two spaces\n", "    parts = d.strip('\\n').split('  ')\n", "    # Keep entries that start with an upper-case letter and contain no excluded characters\n", "    if re.match('^[A-Z]', parts[0]) and not f_char(parts[0]):\n", "        words.append(parts[0])\n", "        phones.append(parts[1])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's take a look at a word-phoneme pair. 
" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "idx = 648\n", "print(words[idx])\n", "print(phones[idx])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "len(words), len(phones)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here is the set of all symbols in the dataset: the characters used in the words plus the phonemes." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "all_chars = set()\n", "for word, phone in zip(words, phones):\n", "    for c in word:\n", "        all_chars.add(c)\n", "    for p in phone.split(\" \"):\n", "        all_chars.add(p)\n", "\n", "print(all_chars)\n", "print(len(all_chars))" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "### Let's define some helper functions to convert words and phonemes to symbol indices and vice versa" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sorted list of symbols; a symbol's position in the list is its index\n", "symbol_set = sorted(list(all_chars))\n", "\n", "# word to list of symbol indices\n", "def word_to_symbol_index(word):\n", "    return [symbol_set.index(char) for char in word]\n", "\n", "# list of symbol indices to list of characters\n", "def symbol_index_to_word(indices):\n", "    return [symbol_set[idx] for idx in indices]\n", "\n", "# phone string (space-separated phonemes) to list of symbol indices\n", "def phone_to_symbol_index(phone):\n", "    return [symbol_set.index(p) for p in phone.split(\" \")]\n", "\n", "# list of symbol indices to list of phonemes\n", "def psymbol_index_to_word(indices):\n", "    return [symbol_set[idx] for idx in indices]\n", "\n", "print(symbol_set)\n", "print(len(symbol_set))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Tokenize words" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# sample word\n", "idx = 648\n", "indices_word = word_to_symbol_index(words[idx])\n", "print(indices_word, symbol_index_to_word(indices_word))" ] }, { "cell_type": "markdown", "metadata": 
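{}, "source": [ "As a quick check of the helper functions above, encoding a word or a pronunciation into symbol indices and then decoding it should return the original symbols. The minimal sketch below assumes `words`, `phones` and the helper functions from the previous cells. To keep it fast, it checks only the first 1,000 entries." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Round trips: word -> indices -> characters, and phone string -> indices -> phonemes\n", "n_check = min(len(words), 1000)\n", "for w, p in zip(words[:n_check], phones[:n_check]):\n", "    assert symbol_index_to_word(word_to_symbol_index(w)) == list(w), w\n", "    assert psymbol_index_to_word(phone_to_symbol_index(p)) == p.split(' '), p\n", "print('Round trip OK for', n_check, 'entries')" ] }, { "cell_type": "markdown", "metadata": 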
{}, "source": [ "Tokenize phonemes" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# sample phone\n", "indices_phone = phone_to_symbol_index(phones[idx])\n", "print(indices_phone, symbol_index_to_word(indices_phone))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For any RNN task, it is important to keep track of the maximum lengths of the input and output sequences. Here they are passed to the training job as the `max_seq_len_source` and `max_seq_len_target` hyperparameters." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# max_length\n", "source_sequence_length = max([len(w) for w in words])\n", "target_sequence_length = max([len(p.split(' ')) for p in phones])\n", "\n", "max_length = max(source_sequence_length, target_sequence_length)\n", "print(source_sequence_length, target_sequence_length, max_length)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's put together the source data." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "### Source: Words\n", "dataX = []\n", "for word in words:\n", "    dataX.append(np.array(word_to_symbol_index(word)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "idx = 648\n", "dataX[idx], symbol_index_to_word(dataX[idx])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's put together the target data as well. 
" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "### Target: Phonemes\n", "dataY = []\n", "for p in phones:\n", "    dataY.append(np.array(phone_to_symbol_index(p)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "idx = 648\n", "dataY[idx], symbol_index_to_word(dataY[idx])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "len(dataY), len(dataX)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "idx = 648\n", "\n", "print(\"SRC: \", symbol_index_to_word(dataX[idx]))\n", "print(\"TRG: \", symbol_index_to_word(dataY[idx]))\n", "print(\"SRC: \", dataX[idx])\n", "print(\"TRG: \", dataY[idx])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "### Train/validation split ###\n", "\n", "def shuffle_together(a, b):\n", "    # Shuffle two arrays with the same random permutation\n", "    assert len(a) == len(b)\n", "    p = np.random.permutation(len(a))\n", "    return a[p], b[p]\n", "\n", "# The sequences have different lengths, so store them in 1-D object arrays\n", "dataX, dataY = np.array(dataX, dtype=object), np.array(dataY, dtype=object)\n", "dataX, dataY = shuffle_together(dataX, dataY)\n", "\n", "N = int(len(dataX) * 0.9)  # 90% train, 10% validation\n", "\n", "### The first 4 indices are reserved for special symbols, so shift every index by 4 ###\n", "trainX = dataX[:N] + 4\n", "trainY = dataY[:N] + 4\n", "valX = dataX[N:] + 4\n", "valY = dataY[N:] + 4\n", "\n", "print(dataX[:3])\n", "print(trainX[:3])\n", "print(trainY[:3])\n", "\n", "print(type(trainX), type(trainX[0].tolist()), type(trainX[0].tolist()[0]))\n", "print(type(trainY), type(trainY[0].tolist()), type(trainY[0].tolist()[0]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Generate vocabulary JSON files\n", "\n", "Amazon SageMaker Seq2Seq requires two JSON \"vocabulary\" 
files, one for the source and one for the target." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "### The first 4 indices are reserved for special symbols ###\n", "vocab_dict = {c: i + 4 for i, c in enumerate(symbol_set)}\n", "vocab_dict" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Add the 4 special symbols: padding, unknown, beginning-of-sequence and end-of-sequence." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "PAD_SYMBOL = \"<pad>\"  # 0\n", "UNK_SYMBOL = \"<unk>\"  # 1\n", "BOS_SYMBOL = \"<s>\"  # 2\n", "EOS_SYMBOL = \"</s>\"  # 3\n", "\n", "VOCAB_SYMBOLS = [PAD_SYMBOL, UNK_SYMBOL, BOS_SYMBOL, EOS_SYMBOL]\n", "vocab_dict[PAD_SYMBOL] = 0\n", "vocab_dict[UNK_SYMBOL] = 1\n", "vocab_dict[BOS_SYMBOL] = 2\n", "vocab_dict[EOS_SYMBOL] = 3\n", "vocab_dict" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this example, the source and target data share the same vocabulary." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import json\n", "with open('vocab.src.json', 'w') as fp:\n", "    json.dump(vocab_dict, fp, indent=4, ensure_ascii=False)\n", "\n", "with open('vocab.trg.json', 'w') as fp:\n", "    json.dump(vocab_dict, fp, indent=4, ensure_ascii=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Generate recordio-protobuf files\n", "\n", "Amazon SageMaker Seq2Seq expects its training data in the recordio-protobuf format (here, `train.rec` and `val.rec`). The function ``write_to_file`` generates a recordio-protobuf file from the source and target sequences. It uses several helper functions from ``create_vocab_proto.py`` and ``record_pb2.py``, so both files must be importable from this notebook (for example, by placing them in the same directory). 
" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import multiprocessing\n", "import logging\n", "\n", "from typing import List\n", "from record_pb2 import Record  ### record_pb2.py\n", "from create_vocab_proto import write_worker, write_recordio, list_to_record_bytes, read_worker\n", "import struct\n", "import io\n", "\n", "logging.basicConfig(level=logging.INFO)\n", "logger = logging.getLogger(__name__)\n", "\n", "\n", "def write_to_file(np_dataX, np_dataY, file_type, output_file):\n", "    num_read_workers = max(multiprocessing.cpu_count() - 1, 1)\n", "    logger.info('Spawning %s encoding worker(s) for encoding %s datasets!', str(num_read_workers), file_type)\n", "\n", "    # One input queue per encoding worker; all workers feed a single output queue\n", "    q_in = [multiprocessing.Queue() for i in range(num_read_workers)]\n", "    q_out = multiprocessing.Queue()\n", "\n", "    read_process = [multiprocessing.Process(target=read_worker,\n", "                                            args=(q_in[i], q_out)) for i in range(num_read_workers)]\n", "    for p in read_process:\n", "        p.start()\n", "\n", "    # A single writer process consumes q_out and writes the records to output_file\n", "    write_process = multiprocessing.Process(target=write_worker, args=(q_out, output_file))\n", "    write_process.start()\n", "\n", "    lines_ignored = 0  # No ignored lines in this example.\n", "    lines_processed = 0\n", "\n", "    for i, int_source in enumerate(np_dataX):\n", "        int_source = int_source.tolist()\n", "        int_target = np_dataY[i].tolist()\n", "        item = (int_source, int_target)\n", "\n", "        if random.random() < 0.0001:\n", "            ### Print some SRC-TRG pairs.\n", "            print('=== === === === ===')\n", "            print('SRC:', int_source)\n", "            print(len(int_source), type(int_source), type(int_source[0]))\n", "            print('--- --- --- --- ---')\n", "            print('TRG:', int_target)\n", "            print(len(int_target), type(int_target), type(int_target[0]))\n", "\n", "        # Distribute the pairs round-robin over the encoding workers\n", "        q_in[lines_processed % len(q_in)].put(item)\n", "\n", "        lines_processed += 1\n", "\n", "    logger.info(\"\"\"Processed %s lines for encoding to protobuf. %s lines were ignored as they didn't have\n", "    any content in either the source or the target file!\"\"\", lines_processed, lines_ignored)\n", "\n", "    logger.info('Completed writing the encoding queue!')\n", "\n", "    # Signal the end of the input (None marks the end of each queue) and wait for the workers\n", "    for q in q_in:\n", "        q.put(None)\n", "    for p in read_process:\n", "        p.join()\n", "    logger.info('Encoding finished! Writing records to \"%s\"', output_file)\n", "    q_out.put(None)\n", "    write_process.join()\n", "    logger.info('Processed input and saved to \"%s\"', output_file)\n", "    print('+++---+++---+++---+++---+++')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Training Data" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "file_type = 'train'\n", "output_file = \"train.rec\"\n", "write_to_file(trainX, trainY, file_type, output_file)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Validation Data" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "file_type = 'validation'\n", "output_file = \"val.rec\"\n", "write_to_file(valX, valY, file_type, output_file)" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "### Upload the files to S3" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So far we have the following 4 files. 
\n", "- `train.rec`: source and target sequences for training, in recordio-protobuf format\n", "- `val.rec`: source and target sequences for validation, in recordio-protobuf format\n", "- `vocab.src.json`: vocabulary mapping (string to int) for the source\n", "- `vocab.trg.json`: vocabulary mapping (string to int) for the target\n", "\n", "Let's upload the preprocessed dataset and vocabularies to S3." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def upload_to_s3(bucket, prefix, channel, file):\n", "    # Upload a local file to s3://<bucket>/<prefix>/<channel>/<file>\n", "    s3 = boto3.resource('s3')\n", "    key = prefix + \"/\" + channel + '/' + file\n", "    with open(file, \"rb\") as data:\n", "        s3.Bucket(bucket).put_object(Key=key, Body=data)\n", "\n", "upload_to_s3(bucket, prefix, 'train', 'train.rec')\n", "# -> s3://<bucket>/seq2seq/word-pronunciation/train/train.rec\n", "upload_to_s3(bucket, prefix, 'validation', 'val.rec')\n", "# -> s3://<bucket>/seq2seq/word-pronunciation/validation/val.rec\n", "upload_to_s3(bucket, prefix, 'vocab', 'vocab.src.json')\n", "# -> s3://<bucket>/seq2seq/word-pronunciation/vocab/vocab.src.json\n", "upload_to_s3(bucket, prefix, 'vocab', 'vocab.trg.json')\n", "# -> s3://<bucket>/seq2seq/word-pronunciation/vocab/vocab.trg.json" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The files are now in S3. Next, we look up the region of the current session." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "region_name = boto3.Session().region_name" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Container\n", "\n", "The SageMaker Seq2Seq algorithm runs from a Docker image hosted in Amazon ECR. The cell below selects the image URI for the current region. 
" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# ECR image URIs of the SageMaker Seq2Seq container. Only these four regions are listed here,\n", "# so running the notebook in any other region raises a KeyError.\n", "containers = {'us-west-2': '433757028032.dkr.ecr.us-west-2.amazonaws.com/seq2seq:latest',\n", "              'us-east-1': '811284229777.dkr.ecr.us-east-1.amazonaws.com/seq2seq:latest',\n", "              'us-east-2': '825641698319.dkr.ecr.us-east-2.amazonaws.com/seq2seq:latest',\n", "              'eu-west-1': '685385470294.dkr.ecr.eu-west-1.amazonaws.com/seq2seq:latest'}\n", "container = containers[region_name]\n", "print('Using SageMaker Seq2Seq container: {} ({})'.format(container, region_name))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training the Word Pronunciation model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "job_name = 'seq2seq-wrd-phn-p2-xlarge-' + strftime(\"%Y-%m-%d-%H-%M\", gmtime())\n", "print(\"Training job\", job_name)\n", "\n", "create_training_params = \\\n", "{\n", "    \"AlgorithmSpecification\": {\n", "        \"TrainingImage\": container,\n", "        \"TrainingInputMode\": \"File\"\n", "    },\n", "    \"RoleArn\": role,\n", "    \"OutputDataConfig\": {\n", "        \"S3OutputPath\": \"s3://{}/{}/\".format(bucket, prefix)\n", "    },\n", "    \"ResourceConfig\": {\n", "        # Seq2Seq does not support multiple machines. Currently, it supports only a single machine, which may have multiple GPUs\n", "        \"InstanceCount\": 1,\n", "        \"InstanceType\": \"ml.p2.xlarge\",  # We suggest one of [\"ml.p2.16xlarge\", \"ml.p2.8xlarge\", \"ml.p2.xlarge\"]\n", "        \"VolumeSizeInGB\": 50\n", "    },\n", "    \"TrainingJobName\": job_name,\n", "    \"HyperParameters\": {\n", "        # Please refer to the documentation for the complete list of parameters\n", "        \"max_seq_len_source\": str(source_sequence_length),\n", "        \"max_seq_len_target\": str(target_sequence_length),\n", "        \"optimized_metric\": \"bleu\",\n", "        \"batch_size\": \"64\",  # Please use a larger batch size (256 or 512) if using ml.p2.8xlarge or ml.p2.16xlarge\n", "        \"checkpoint_frequency_num_batches\": \"1000\",\n", "        \"rnn_num_hidden\": \"512\",\n", "        \"num_layers_encoder\": \"1\",\n", "        \"num_layers_decoder\": \"1\",\n", "        \"num_embed_source\": \"512\",\n", "        \"num_embed_target\": \"512\",\n", "        \"checkpoint_threshold\": \"3\",\n", "        # \"max_num_batches\": \"2100\"\n", "        # For a quick demo run, uncomment the line above to stop training after 2100 batches.\n", "        # Leave it commented out if you want a better model.\n", "    },\n", "    \"StoppingCondition\": {\n", "        \"MaxRuntimeInSeconds\": 48 * 3600\n", "    },\n", "    \"InputDataConfig\": [\n", "        {\n", "            \"ChannelName\": \"train\",\n", "            \"DataSource\": {\n", "                \"S3DataSource\": {\n", "                    \"S3DataType\": \"S3Prefix\",\n", "                    \"S3Uri\": \"s3://{}/{}/train/\".format(bucket, prefix),\n", "                    \"S3DataDistributionType\": \"FullyReplicated\"\n", "                }\n", "            },\n", "        },\n", "        {\n", "            \"ChannelName\": \"vocab\",\n", "            \"DataSource\": {\n", "                \"S3DataSource\": {\n", "                    \"S3DataType\": \"S3Prefix\",\n", "                    \"S3Uri\": \"s3://{}/{}/vocab/\".format(bucket, prefix),\n", "                    \"S3DataDistributionType\": \"FullyReplicated\"\n", "                }\n", "            },\n", "        },\n", "        {\n", "            \"ChannelName\": \"validation\",\n", "            \"DataSource\": {\n", "                \"S3DataSource\": {\n", "                    \"S3DataType\": \"S3Prefix\",\n", "                    \"S3Uri\": \"s3://{}/{}/validation/\".format(bucket, prefix),\n", "                    \"S3DataDistributionType\": \"FullyReplicated\"\n", "                }\n", "            },\n", "        }\n", "    ]\n", "}\n", "\n", "sagemaker_client = boto3.Session().client(service_name='sagemaker')\n", "sagemaker_client.create_training_job(**create_training_params)\n", "\n", "status = sagemaker_client.describe_training_job(TrainingJobName=job_name)['TrainingJobStatus']\n", "print(status)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "### Re-run this cell until the status says \"Completed\". 
###\n", "\n", "status = sagemaker_client.describe_training_job(TrainingJobName=job_name)['TrainingJobStatus']\n", "print(status)\n", "# if the job failed, determine why\n", "if status == 'Failed':\n", "    message = sagemaker_client.describe_training_job(TrainingJobName=job_name)['FailureReason']\n", "    print('Training failed with the following error: {}'.format(message))\n", "    raise Exception('Training job failed')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> Now wait for the training job to **complete**, and proceed to the next step once you see the model artifacts in your S3 bucket.\n", "> If the cell above returns **InProgress**, you still have to wait." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Inference\n", "\n", "A trained model does nothing on its own. We now want to use the model to perform inference. For this example, that means predicting the pronunciation (phoneme sequence) of words.\n", "\n", "This section involves several steps:\n", "- Create model - Create a model using the artifact (model.tar.gz) produced by training\n", "- Create Endpoint Configuration - Create a configuration defining an endpoint, using the above model\n", "- Create Endpoint - Use the configuration to create an inference endpoint.\n", "- Perform Inference - Perform inference on some input data using the endpoint.\n", "\n", "### Create model\n", "We now create a SageMaker Model from the training output. Using the model, we can then create an Endpoint Configuration." 
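] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The model can only be created once training has finished. Instead of re-running the status cell by hand, you can use the built-in boto3 SageMaker waiter `training_job_completed_or_stopped`, which polls `DescribeTrainingJob` until the job reaches a terminal state. The minimal sketch below assumes the `sagemaker_client` and `job_name` variables defined earlier. The polling limits passed through `WaiterConfig` are a choice made for this sketch, sized to match the 48-hour `MaxRuntimeInSeconds` above." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Block until the training job is Completed or Stopped; the waiter raises a botocore\n", "# WaiterError if the job fails or the maximum number of polling attempts is reached.\n", "# Poll every 60 seconds for up to 2880 attempts (48 hours).\n", "waiter = sagemaker_client.get_waiter('training_job_completed_or_stopped')\n", "waiter.wait(TrainingJobName=job_name, WaiterConfig={'Delay': 60, 'MaxAttempts': 2880})\n", "\n", "status = sagemaker_client.describe_training_job(TrainingJobName=job_name)['TrainingJobStatus']\n", "print('Training job status:', status)"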
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "\n", "sage = boto3.client('sagemaker')\n", "\n", "info = sage.describe_training_job(TrainingJobName=job_name)\n", "model_name = job_name\n", "model_data = info['ModelArtifacts']['S3ModelArtifacts']\n", "\n", "print(model_name)\n", "print(model_data)\n", "\n", "primary_container = {\n", "    'Image': container,\n", "    'ModelDataUrl': model_data\n", "}\n", "\n", "create_model_response = sage.create_model(\n", "    ModelName=model_name,\n", "    ExecutionRoleArn=role,\n", "    PrimaryContainer=primary_container)\n", "\n", "print(create_model_response['ModelArn'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create endpoint configuration\n", "Use the model to create an endpoint configuration. The endpoint configuration also contains information about the type and number of EC2 instances to use when hosting the model.\n", "\n", "Because SageMaker Seq2Seq is based on neural networks, we could host it on a GPU instance such as ml.p2.xlarge. For this example, however, we use a free-tier-eligible ml.m4.xlarge (CPU) instance." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from time import gmtime, strftime\n", "\n", "endpoint_config_name = 'Seq2SeqEndpointConfig-' + strftime(\"%Y-%m-%d-%H-%M-%S\", gmtime())\n", "print(endpoint_config_name)\n", "create_endpoint_config_response = sage.create_endpoint_config(\n", "    EndpointConfigName=endpoint_config_name,\n", "    ProductionVariants=[{\n", "        'InstanceType': 'ml.m4.xlarge',  # CPU instance; a GPU instance such as ml.p2.xlarge can be used instead\n", "        'InitialInstanceCount': 1,\n", "        'ModelName': model_name,\n", "        'VariantName': 'AllTraffic'}])\n", "\n", "print(\"Endpoint Config Arn: \" + create_endpoint_config_response['EndpointConfigArn'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create endpoint\n", "Lastly, we create the endpoint that serves the model, specifying the name and configuration defined above. 
The end result is an endpoint that can be validated and incorporated into production applications. Endpoint creation takes about 10-15 minutes." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "import time\n", "\n", "endpoint_name = 'Seq2SeqEndpoint-' + strftime(\"%Y-%m-%d-%H-%M-%S\", gmtime())\n", "print(endpoint_name)\n", "create_endpoint_response = sage.create_endpoint(\n", "    EndpointName=endpoint_name,\n", "    EndpointConfigName=endpoint_config_name)\n", "print(create_endpoint_response['EndpointArn'])\n", "\n", "resp = sage.describe_endpoint(EndpointName=endpoint_name)\n", "status = resp['EndpointStatus']\n", "print(\"Status: \" + status)\n", "\n", "# Wait until the endpoint is InService (the waiter raises an error if creation fails or times out)\n", "sage.get_waiter('endpoint_in_service').wait(EndpointName=endpoint_name)\n", "\n", "# print the status of the endpoint\n", "endpoint_response = sage.describe_endpoint(EndpointName=endpoint_name)\n", "status = endpoint_response['EndpointStatus']\n", "print('Endpoint creation ended with EndpointStatus = {}'.format(status))\n", "\n", "if status != 'InService':\n", "    raise Exception('Endpoint creation failed.')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you see the message\n", "> Endpoint creation ended with EndpointStatus = InService\n", "\n", "then congratulations! You now have a functioning inference endpoint. You can confirm the endpoint configuration and status in the \"Endpoints\" tab of the Amazon SageMaker console.\n", "\n", "Finally, we create a runtime client that we can use to invoke the endpoint." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "runtime = boto3.client(service_name='runtime.sagemaker')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Perform Inference" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Using JSON format for inference (suggested for a single or a small number of data instances)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that in JSON mode you don't have to convert the input strings to integers with the vocabulary mapping. Each instance is sent as the word's upper-cased characters, separated by spaces." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Build an input: the upper-cased characters of the word, separated by spaces\n", "word_infr = 'abcdefg'\n", "print(\" \".join(list(word_infr.upper())))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "words_infr = [\"car\",\n", "              \"cat\",\n", "              \"tapeworm\",\n", "              \"tapdance\",\n", "              \"supercalifragilistic\",\n", "              \"expialidocious\"]\n", "\n", "payload = {\"instances\": []}\n", "for word_infr in words_infr:\n", "    payload[\"instances\"].append({\"data\": \" \".join(list(word_infr.upper()))})\n", "\n", "response = runtime.invoke_endpoint(EndpointName=endpoint_name,\n", "                                   ContentType='application/json',\n", "                                   Body=json.dumps(payload))\n", "\n", "response = response[\"Body\"].read().decode(\"utf-8\")\n", "response = json.loads(response)\n", "print(response)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Retrieving the Attention Matrix" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Passing `\"attention_matrix\":\"true\"` in the `configuration` of a data instance returns the attention matrix." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "word_infr = 'height'\n", "\n", "payload = {\"instances\": [{\n", "    \"data\": \" \".join(list(word_infr.upper())),\n", "    \"configuration\": {\"attention_matrix\": \"true\"}\n", "}]}\n", "\n", "response = runtime.invoke_endpoint(EndpointName=endpoint_name,\n", "                                   ContentType='application/json',\n", "                                   Body=json.dumps(payload))\n", "\n", "response = response[\"Body\"].read().decode(\"utf-8\")\n", "response = json.loads(response)['predictions'][0]\n", "\n", "source = \" \".join(list(word_infr.upper()))\n", "target = response[\"target\"]\n", "attention_matrix = np.array(response[\"matrix\"])\n", "\n", "print(\"Source: %s \\nTarget: %s\" % (source, target))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# Define a function for plotting the attention matrix\n", "def plot_matrix(attention_matrix, target, source):\n", "    source_tokens = source.split()\n", "    target_tokens = target.split()\n", "    # The matrix has one row per target token\n", "    assert attention_matrix.shape[0] == len(target_tokens)\n", "    plt.imshow(attention_matrix.transpose(), interpolation=\"nearest\", cmap=\"Greys\")\n", "    plt.xlabel(\"target\")\n", "    plt.ylabel(\"source\")\n", "    plt.gca().set_xticks(list(range(len(target_tokens))))\n", "    plt.gca().set_yticks(list(range(len(source_tokens))))\n", "    plt.gca().set_xticklabels(target_tokens)\n", "    plt.gca().set_yticklabels(source_tokens)\n", "    plt.tight_layout()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plot_matrix(attention_matrix, target, source)" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "### Using Protobuf format for inference (suggested for efficient bulk inference)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Read the vocabulary mappings, since this mode of inference accepts a list of integers and returns a 
list of integers." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import io\n", "import tempfile\n", "from record_pb2 import Record\n", "from create_vocab_proto import vocab_from_json, reverse_vocab, write_recordio, list_to_record_bytes, read_next\n", "\n", "source = vocab_from_json(\"vocab.src.json\")\n", "target = vocab_from_json(\"vocab.trg.json\")\n", "\n", "source_rev = reverse_vocab(source)\n", "target_rev = reverse_vocab(target)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "words_infr = [\"car\",\n", " \"cat\",\n", " \"tapeworm\",\n", " \"tapdance\",\n", " \"%\",\n", " \"345\",\n", " \"supercalifragilistic\",\n", " \"expialidocious\",\n", " \"Otorhinolaryngologist\"]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Converting the string to integers, followed by protobuf encoding:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Convert strings to integers using source vocab mapping. 
Out-of-vocabulary characters are mapped to 1, the index of <unk>\n", "words_infr = [[source.get(token, 1) for token in word_infr.upper()] for word_infr in words_infr]\n", "print(words_infr)\n", "\n", "# Encode each word as one recordio-protobuf record, with an empty target list\n", "f = io.BytesIO()\n", "for word_infr in words_infr:\n", "    record = list_to_record_bytes(word_infr, [])\n", "    write_recordio(f, record)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "response = runtime.invoke_endpoint(EndpointName=endpoint_name,\n", "                                   ContentType='application/x-recordio-protobuf',\n", "                                   Body=f.getvalue())\n", "\n", "response = response[\"Body\"].read()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, parse the protobuf response and convert the lists of integers back to phoneme strings." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def _parse_proto_response(received_bytes):\n", "    # Write the response to a temporary file so that read_next can read it record by record\n", "    output_file = tempfile.NamedTemporaryFile()\n", "    output_file.write(received_bytes)\n", "    output_file.flush()\n", "    target_sentences = []\n", "    with open(output_file.name, 'rb') as datum:\n", "        next_record = True\n", "        while next_record:\n", "            next_record = read_next(datum)\n", "            if next_record:\n", "                rec = Record()\n", "                rec.ParseFromString(next_record)\n", "                target = list(rec.features[\"target\"].int32_tensor.values)\n", "                target_sentences.append(target)\n", "            else:\n", "                break\n", "    return target_sentences" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "targets = _parse_proto_response(response)\n", "resp = [\" \".join([target_rev.get(token, \"\") for token in phone_infr])\n", "        for phone_infr in targets]\n", "print(resp)" ]
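}, { "cell_type": "markdown", "metadata": {}, "source": [ "As a rough sanity check, the sketch below compares each protobuf-mode prediction with the CMUdict pronunciation for the words that appear in the (filtered) dictionary. It assumes the `words`, `phones` and `resp` variables from earlier cells are still in memory. The word list is repeated because `words_infr` was overwritten with integer IDs above. This is a simple exact-match comparison, not a standard evaluation metric, and dictionary words may also have been in the training split, so a match does not by itself show that the model generalizes." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# The same words that were sent in the protobuf request, in the same order\n", "query_words = ['car', 'cat', 'tapeworm', 'tapdance', '%', '345',\n", "               'supercalifragilistic', 'expialidocious', 'Otorhinolaryngologist']\n", "\n", "# CMUdict lookup built from the filtered dictionary entries (word -> phoneme string)\n", "reference = dict(zip(words, phones))\n", "\n", "for query, predicted in zip(query_words, resp):\n", "    expected = reference.get(query.upper())\n", "    if expected is None:\n", "        print('{:<22} predicted: {}  (not in CMUdict)'.format(query, predicted))\n", "    else:\n", "        match = 'match' if predicted.split() == expected.split() else 'differs'\n", "        print('{:<22} predicted: {}  CMUdict: {}  [{}]'.format(query, predicted, expected, match))" ]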
}, { "cell_type": "markdown", "metadata": {}, "source": [ "## Stop / Close the Endpoint\n", "\n", "Finally, delete the endpoint before closing the notebook. An endpoint that is left running continues to incur charges." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sage.delete_endpoint(EndpointName=endpoint_name)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# End" ] } ], "metadata": { "kernelspec": { "display_name": "conda_python3", "language": "python", "name": "conda_python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.2" }, "notice": "Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License." }, "nbformat": 4, "nbformat_minor": 2 }