{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Fairseq in Amazon SageMaker: Translation task - English to French" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "\n", "This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook. \n", "\n", "![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-west-2/advanced_functionality|fairseq_translation|fairseq_sagemaker_translate_en2fr.ipynb)\n", "\n", "---" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this notebook, we show you how to train an English-to-French translation model with a fully convolutional architecture, using the [Fairseq toolkit](https://github.com/pytorch/fairseq).\n", "\n", "## Permissions\n", "\n", "Running this notebook requires permissions in addition to the regular `AmazonSageMakerFullAccess` permissions, because it creates new repositories in Amazon ECR. The easiest way to add these permissions is to attach the managed policy `AmazonEC2ContainerRegistryFullAccess` to the role that you used to start your notebook instance. There's no need to restart your notebook instance when you do this; the new permissions are available immediately." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Prepare dataset\n", "\n", "To train the model, we use the WMT'14 dataset, as described [here](https://github.com/pytorch/fairseq/tree/master/examples/translation#prepare-wmt14en2frsh).\n", "\n", "First, we download the dataset and start pre-processing. Among other steps, this pre-processing cleans the tokens and applies byte-pair encoding (BPE), as you can see [here](https://github.com/pytorch/fairseq/blob/master/examples/translation/prepare-wmt14en2fr.sh)."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%sh\n", "cd data\n", "chmod +x prepare-wmt14en2fr.sh\n", "\n", "# Download dataset and start pre-processing\n", "./prepare-wmt14en2fr.sh" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we apply the second pre-processing stage, which binarizes the dataset for the source and target languages. Full information on this script is available [here](https://github.com/pytorch/fairseq/blob/master/preprocess.py)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%sh\n", "\n", "# Clone the fairseq repository to get access to its scripts\n", "git clone https://github.com/pytorch/fairseq.git fairseq-git\n", "cd fairseq-git\n", "\n", "# Binarize the dataset:\n", "TEXT=../data/wmt14_en_fr\n", "python preprocess.py --source-lang en --target-lang fr \\\n", "    --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \\\n", "    --destdir ../data/wmt14_en_fr --thresholdtgt 0 --thresholdsrc 0" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The dataset is now ready for training with one of the Fairseq translation models. The next step is to upload the data to Amazon S3 to make it available for training."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Upload data to Amazon S3" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sagemaker\n", "\n", "sagemaker_session = sagemaker.Session()\n", "# Region and account ID are used later to build the ECR image URI\n", "region = sagemaker_session.boto_session.region_name\n", "account = sagemaker_session.boto_session.client(\"sts\").get_caller_identity().get(\"Account\")\n", "\n", "bucket = sagemaker_session.default_bucket()\n", "prefix = \"sagemaker/DEMO-pytorch-fairseq/datasets/wmt14_en_fr\"\n", "\n", "role = sagemaker.get_execution_role()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "inputs = sagemaker_session.upload_data(path=\"data/wmt14_en_fr\", bucket=bucket, key_prefix=prefix)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Build Fairseq Translation task container" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we build a Docker image that contains the Fairseq code and push it to Amazon ECR. Amazon SageMaker pulls this image at training and inference time, both to train the model and to serve predictions." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "%%sh\n", "chmod +x create_container.sh\n", "\n", "./create_container.sh pytorch-fairseq" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The Fairseq image is now in Amazon ECR, the registry from which Amazon SageMaker pulls the image to launch both training and prediction." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training on Amazon SageMaker" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we set the hyperparameters of the model we want to train.
Here we use the values recommended in the [Fairseq example](https://github.com/pytorch/fairseq/blob/master/examples/translation/prepare-wmt14en2fr.sh)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "hyperparameters = {\n", "    \"lr\": 0.5,\n", "    \"clip-norm\": 0.1,\n", "    \"dropout\": 0.1,\n", "    \"max-tokens\": 3000,\n", "    \"criterion\": \"label_smoothed_cross_entropy\",\n", "    \"label-smoothing\": 0.1,\n", "    \"lr-scheduler\": \"fixed\",\n", "    \"force-anneal\": 50,\n", "    \"arch\": \"fconv_wmt_en_fr\",\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We are now ready to define the Estimator, which encapsulates all the parameters needed to launch training on Amazon SageMaker. Fairseq recommends training on GPU instances, such as the `ml.p3` instance family [available in Amazon SageMaker](https://aws.amazon.com/sagemaker/pricing/instance-types/)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sagemaker.estimator import Estimator\n", "\n", "algorithm_name = \"pytorch-fairseq\"\n", "# URI of the image pushed to Amazon ECR by create_container.sh\n", "image = \"{}.dkr.ecr.{}.amazonaws.com/{}:latest\".format(account, region, algorithm_name)\n", "\n", "# train_instance_count/train_instance_type are SageMaker Python SDK v1 parameter names\n", "# (SDK v2 renames them to instance_count/instance_type)\n", "estimator = Estimator(\n", "    image,\n", "    role,\n", "    train_instance_count=1,\n", "    train_instance_type=\"ml.p3.8xlarge\",\n", "    output_path=\"s3://{}/output\".format(bucket),\n", "    sagemaker_session=sagemaker_session,\n", "    hyperparameters=hyperparameters,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Calling `fit` launches the training job and regularly reports performance metrics such as the loss values." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "estimator.fit(inputs=inputs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Once the model has finished training, we can test its translation capabilities by deploying it to an endpoint.\n", "\n", "## Hosting the model\n", "\n", "We first define a `JSONPredictor` class that serializes requests to JSON and deserializes the JSON responses, so that we can send prediction requests to the model once it is hosted on an Amazon SageMaker endpoint." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# RealTimePredictor, json_serializer and json_deserializer are SageMaker Python SDK v1 APIs\n", "from sagemaker.predictor import RealTimePredictor, json_serializer, json_deserializer\n", "\n", "\n", "class JSONPredictor(RealTimePredictor):\n", "    \"\"\"Predictor that sends JSON requests and parses JSON responses.\"\"\"\n", "\n", "    def __init__(self, endpoint_name, sagemaker_session):\n", "        super(JSONPredictor, self).__init__(\n", "            endpoint_name,\n", "            sagemaker_session,\n", "            serializer=json_serializer,\n", "            deserializer=json_deserializer,\n", "        )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can now use the estimator object to deploy the model artifacts (the trained model). We deploy them to a CPU instance, since we no longer need a GPU just to run inference. Let's use an `ml.m5.xlarge`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "predictor = estimator.deploy(\n", "    initial_instance_count=1, instance_type=\"ml.m5.xlarge\", predictor_cls=JSONPredictor\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now it's your turn to play. Enter a sentence in English and get its French translation by calling `predict`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import html\n", "\n", "text_input = \"Hey, how are you doing?\"\n", "\n", "result = predictor.predict(text_input)\n", "# Some characters in the response are HTML-escaped, so unescape them before printing\n", "print(html.unescape(result))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When you are done getting predictions, remember to delete the endpoint so that it stops incurring charges.\n", "\n", "## Delete endpoint" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sagemaker_session.delete_endpoint(predictor.endpoint)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Voilà! For more information, check out the [Fairseq toolkit homepage](https://github.com/pytorch/fairseq)." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Notebook CI Test Results\n", "\n", "This notebook was tested in multiple regions. The test results are as follows, except for us-west-2, which is shown at the top of the notebook.\n", "\n", "![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-east-1/advanced_functionality|fairseq_translation|fairseq_sagemaker_translate_en2fr.ipynb)\n", "\n", "![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-east-2/advanced_functionality|fairseq_translation|fairseq_sagemaker_translate_en2fr.ipynb)\n", "\n", "![This us-west-1 badge failed to load.
Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-west-1/advanced_functionality|fairseq_translation|fairseq_sagemaker_translate_en2fr.ipynb)\n", "\n", "![This ca-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ca-central-1/advanced_functionality|fairseq_translation|fairseq_sagemaker_translate_en2fr.ipynb)\n", "\n", "![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/sa-east-1/advanced_functionality|fairseq_translation|fairseq_sagemaker_translate_en2fr.ipynb)\n", "\n", "![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-1/advanced_functionality|fairseq_translation|fairseq_sagemaker_translate_en2fr.ipynb)\n", "\n", "![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-2/advanced_functionality|fairseq_translation|fairseq_sagemaker_translate_en2fr.ipynb)\n", "\n", "![This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-3/advanced_functionality|fairseq_translation|fairseq_sagemaker_translate_en2fr.ipynb)\n", "\n", "![This eu-central-1 badge failed to load. 
Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-central-1/advanced_functionality|fairseq_translation|fairseq_sagemaker_translate_en2fr.ipynb)\n", "\n", "![This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-north-1/advanced_functionality|fairseq_translation|fairseq_sagemaker_translate_en2fr.ipynb)\n", "\n", "![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-southeast-1/advanced_functionality|fairseq_translation|fairseq_sagemaker_translate_en2fr.ipynb)\n", "\n", "![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-southeast-2/advanced_functionality|fairseq_translation|fairseq_sagemaker_translate_en2fr.ipynb)\n", "\n", "![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-northeast-1/advanced_functionality|fairseq_translation|fairseq_sagemaker_translate_en2fr.ipynb)\n", "\n", "![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-northeast-2/advanced_functionality|fairseq_translation|fairseq_sagemaker_translate_en2fr.ipynb)\n", "\n", "![This ap-south-1 badge failed to load. 
Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-south-1/advanced_functionality|fairseq_translation|fairseq_sagemaker_translate_en2fr.ipynb)\n" ] } ], "metadata": { "kernelspec": { "display_name": "conda_pytorch_p36", "language": "python", "name": "conda_pytorch_p36" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" } }, "nbformat": 4, "nbformat_minor": 2 }