{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Machine Translation English-German Example Using SageMaker Seq2Seq\n", "\n", "1. [Introduction](#Introduction)\n", "2. [Setup](#Setup)\n", "3. [Inference](#Inference)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Introduction\n", "\n", "Welcome to our Machine Translation end-to-end example! In this demo, we will use a pre-trained English-German translation model and deploy it for an internet-facing app. \n", "\n", "The SageMaker Seq2Seq algorithm is built on top of [Sockeye](https://github.com/awslabs/sockeye), a sequence-to-sequence framework for Neural Machine Translation based on MXNet. SageMaker Seq2Seq implements state-of-the-art encoder-decoder architectures which can also be used for tasks like Abstractive Summarization in addition to Machine Translation.\n", "\n", "To get started, we need to set up the environment with a few prerequisite steps covering permissions and configuration." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup\n", "\n", "Let's start by specifying:\n", "- The S3 bucket and prefix that you want to use for training and model data. **This should be within the same region as the Notebook Instance, training, and hosting.**\n", "- The IAM role ARN used to give training and hosting access to your data. See the documentation for how to create these. Note: if more than one role is required for notebook instances, training, and/or hosting, replace the `get_execution_role()` call in the cell below with the appropriate full IAM role ARN string(s).\n", "\n", "### Lab time\n", "This notebook will take about 12 to 15 minutes to complete." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import timeit\n", "start_time = timeit.default_timer()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "isConfigCell": true }, "outputs": [], "source": [ "import boto3\n", "import re\n", "import sagemaker\n", "from sagemaker import get_execution_role\n", "\n", "role = get_execution_role()\n", "\n", "# S3 bucket and prefix\n", "bucket = '' # set this to the name of an existing S3 bucket in the same region\n", "prefix = 'sagemaker/seq2seq/eng-german' # key prefix for the uploaded model data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we'll import the Python libraries we'll need for the remainder of the exercise." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from time import gmtime, strftime\n", "import time\n", "import numpy as np\n", "import os\n", "import json\n", "\n", "# For plotting attention matrix later on\n", "import matplotlib\n", "%matplotlib inline\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def upload_to_s3(bucket, prefix, channel, file):\n", "    # Upload a local file to s3://<bucket>/<prefix>/<channel>/<file name>\n", "    s3 = boto3.resource('s3')\n", "    key = prefix + \"/\" + channel + '/' + os.path.basename(file)\n", "    with open(file, \"rb\") as data:\n", "        s3.Bucket(bucket).put_object(Key=key, Body=data)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "region_name = boto3.Session().region_name" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sagemaker.amazon.amazon_estimator import get_image_uri\n", "container = get_image_uri(region_name, 'seq2seq', \"latest\")\n", "\n", "print('Using SageMaker Seq2Seq container: {} ({})'.format(container, region_name))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Inference" ] }, { "cell_type": "markdown", 
"metadata": {}, "source": [ "### Use a pretrained model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "use_pretrained_model = True\n", "model_name = \"pretrained-en-de-model\"\n", "!curl https://s3.ap-northeast-2.amazonaws.com/pilho-immersionday-public-material/download/model.tar.gz > /tmp/model.tar.gz\n", "!curl https://s3.ap-northeast-2.amazonaws.com/pilho-immersionday-public-material/download/vocab.src.json > /tmp/vocab.src.json\n", "!curl https://s3.ap-northeast-2.amazonaws.com/pilho-immersionday-public-material/download/vocab.trg.json > /tmp/vocab.trg.json" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "upload_to_s3(bucket, prefix, 'pretrained_model', '/tmp/model.tar.gz')\n", "model_data = \"s3://{}/{}/pretrained_model/model.tar.gz\".format(bucket, prefix)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "\n", "sage = boto3.client('sagemaker')\n", "\n", "if not use_pretrained_model:\n", " info = sage.describe_training_job(TrainingJobName=job_name)\n", " model_name=job_name\n", " model_data = info['ModelArtifacts']['S3ModelArtifacts']\n", "\n", "print(model_name)\n", "print(model_data)\n", "\n", "primary_container = {\n", " 'Image': container,\n", " 'ModelDataUrl': model_data\n", "}\n", "\n", "create_model_response = sage.create_model(\n", " ModelName = model_name,\n", " ExecutionRoleArn = role,\n", " PrimaryContainer = primary_container)\n", "\n", "print(create_model_response['ModelArn'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create endpoint configuration\n", "Use the model to create an endpoint configuration. 
The endpoint configuration also contains information about the type and number of EC2 instances to use when hosting the model.\n", "\n", "Because SageMaker Seq2Seq is based on neural networks, we could use an ml.p2.xlarge (GPU) instance, but for this example we will use a free-tier-eligible ml.m4.xlarge." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from time import gmtime, strftime\n", "\n", "endpoint_config_name = 'Seq2SeqEndpointConfig-' + strftime(\"%Y-%m-%d-%H-%M-%S\", gmtime())\n", "print(endpoint_config_name)\n", "create_endpoint_config_response = sage.create_endpoint_config(\n", " EndpointConfigName = endpoint_config_name,\n", " ProductionVariants=[{\n", " 'InstanceType':'ml.m4.xlarge',\n", " 'InitialInstanceCount':1,\n", " 'ModelName':model_name,\n", " 'VariantName':'AllTraffic'}])\n", "\n", "print(\"Endpoint Config Arn: \" + create_endpoint_config_response['EndpointConfigArn'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create endpoint\n", "Lastly, we create the endpoint that serves the model by specifying the endpoint name and the configuration defined above. The end result is an endpoint that can be validated and incorporated into production applications. This takes 10-15 minutes to complete." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "import time\n", "\n", "endpoint_name = 'Seq2SeqEndpoint-' + strftime(\"%Y-%m-%d-%H-%M-%S\", gmtime())\n", "print(endpoint_name)\n", "create_endpoint_response = sage.create_endpoint(\n", " EndpointName=endpoint_name,\n", " EndpointConfigName=endpoint_config_name)\n", "print(create_endpoint_response['EndpointArn'])\n", "\n", "resp = sage.describe_endpoint(EndpointName=endpoint_name)\n", "status = resp['EndpointStatus']\n", "print(\"Status: \" + status)\n", "\n", "# wait until the status has changed\n", "sage.get_waiter('endpoint_in_service').wait(EndpointName=endpoint_name)\n", "\n", "# print the status of the endpoint\n", "endpoint_response = sage.describe_endpoint(EndpointName=endpoint_name)\n", "status = endpoint_response['EndpointStatus']\n", "print('Endpoint creation ended with EndpointStatus = {}'.format(status))\n", "\n", "if status != 'InService':\n", " raise Exception('Endpoint creation failed.')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you see the message,\n", "> Endpoint creation ended with EndpointStatus = InService\n", "\n", "then congratulations! You now have a functioning inference endpoint. You can confirm the endpoint configuration and status by navigating to the \"Endpoints\" tab in the AWS SageMaker console. \n", "\n", "We will finally create a runtime object from which we can invoke the endpoint." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "runtime = boto3.client(service_name='runtime.sagemaker') " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Perform Inference" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Using JSON format for inference (Suggested for a single or small number of data instances)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Note that in JSON mode you don't have to convert the input text to integer token IDs using the vocabulary mapping" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sentences = [\"you are so good !\",\n", " \"can you drive a car ?\",\n", " \"i want to watch a movie .\"\n", " ]\n", "\n", "payload = {\"instances\" : []}\n", "for sent in sentences:\n", " payload[\"instances\"].append({\"data\" : sent})\n", "\n", "response = runtime.invoke_endpoint(EndpointName=endpoint_name, \n", " ContentType='application/json', \n", " Body=json.dumps(payload))\n", "\n", "response = response[\"Body\"].read().decode(\"utf-8\")\n", "response = json.loads(response)\n", "print(response)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Retrieving the Attention Matrix" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Passing `\"attention_matrix\":\"true\"` in the `configuration` of a data instance makes the endpoint return the attention matrix along with the translation." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sentence = 'can you drive a car ?'\n", "\n", "payload = {\"instances\" : [{\n", " \"data\" : sentence,\n", " \"configuration\" : {\"attention_matrix\":\"true\"}\n", " }\n", " ]}\n", "\n", "response = runtime.invoke_endpoint(EndpointName=endpoint_name, \n", " ContentType='application/json', \n", " Body=json.dumps(payload))\n", "\n", "response = response[\"Body\"].read().decode(\"utf-8\")\n", "response = json.loads(response)['predictions'][0]\n", "\n", "source = sentence\n", "target = response[\"target\"]\n", "attention_matrix = np.array(response[\"matrix\"])\n", "\n", "print(\"Source: %s \\nTarget: %s\" % (source, target))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Define a function for plotting the attention matrix\n", "def plot_matrix(attention_matrix, target, source):\n", "    source_tokens = source.split()\n", "    target_tokens = target.split()\n", "    # Rows of the matrix correspond to target tokens\n", "    assert attention_matrix.shape[0] == len(target_tokens)\n", "    plt.imshow(attention_matrix.transpose(), interpolation=\"nearest\", cmap=\"Greys\")\n", "    plt.xlabel(\"target\")\n", "    plt.ylabel(\"source\")\n", "    plt.gca().set_xticks(list(range(len(target_tokens))))\n", "    plt.gca().set_yticks(list(range(len(source_tokens))))\n", "    plt.gca().set_xticklabels(target_tokens)\n", "    plt.gca().set_yticklabels(source_tokens)\n", "    plt.tight_layout()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plot_matrix(attention_matrix, target, source)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Stop / Close the Endpoint (Optional)\n", "\n", "Finally, delete the endpoint before closing the notebook so that it stops incurring charges. Uncomment the line in the next cell to do so." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#sage.delete_endpoint(EndpointName=endpoint_name)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Report the total notebook run time in minutes\n", "elapsed = timeit.default_timer() - start_time\n", "print(elapsed/60)" ] } ], "metadata": { "kernelspec": { "display_name": "conda_python3", "language": "python", "name": "conda_python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" }, "notice": "Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.", "widgets": { "application/vnd.jupyter.widget-state+json": { "state": {}, "version_major": 2, "version_minor": 0 } } }, "nbformat": 4, "nbformat_minor": 2 }