{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.\n", "#SPDX-License-Identifier: MIT-0" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "#install additional libraries\n", "!pip install nltk\n", "!pip install jsonlines\n", "!pip install pandarallel" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#import libraries\n", "import os\n", "import uuid\n", "import datetime\n", "import time\n", "import logging\n", "import glob\n", "\n", "import boto3\n", "import sagemaker\n", "\n", "from search_utils import helpers" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#Define common variables\n", "\n", "#Creating a sagemaker session\n", "sagemaker_session = sagemaker.Session()\n", "\n", "#We'll be using the sagemaker default bucket\n", "#Feel free to change this to another bucket name and make sure it's the same across all four notebooks\n", "bucket_name = sagemaker_session.default_bucket()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def generate_unique_id():\n", " return str(uuid.uuid4())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 1. Building the docker image " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First we'll build a custom docker container in order to use it with the SageMaker processing jobs.\n", "\n", "Within the docker we'll install the libraries defined in the requirements.txt file.\n", "\n", "We'll also upload the source code (helper functions, processing functions etc) under \"/opt/source_code/\" so they are accessible during runtime." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%bash\n", "cd ../\n", "sh build_and_push.sh" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#Make sure you replace the following variable with your account id and region\n", "#You can also copy past the ECR uri from the logs of the previous cell\n", "ecr_uri = \"-dkr.ecr-.amazonaws.com/sm-search:latest\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#uploading the search_utils files so they are accessible during runtime\n", "s3_client = boto3.client(\"s3\")\n", "for file_name in glob.glob(\"../src/search_utils/*.py\"):\n", " s3_client.upload_file(file_name, bucket_name, f\"search_knn_blog/code/{file_name.split('/')[-1]}\" )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 2. 
Preprocessing " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sagemaker.processing import ScriptProcessor\n", "script_processor = ScriptProcessor(\n", " image_uri=ecr_uri,\n", " role=sagemaker.get_execution_role(),\n", " instance_count=1,\n", " instance_type='ml.m5.4xlarge',\n", " command=[\"python3\"],\n", " volume_size_in_gb=50)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "jupyter": { "outputs_hidden": true } }, "outputs": [], "source": [ "from sagemaker.processing import ProcessingInput, ProcessingOutput\n", "\n", "now = datetime.datetime.utcnow()\n", "now_string = now.strftime('%y%m%d%H%M%S%f')\n", "run_id = now_string[:-2]\n", "print(f\"run id : {run_id}\")\n", "preprocess_job_name = f\"search-preprocess-{run_id}\"\n", "\n", "s3_input_path=f\"s3://{bucket_name}/search_knn_blog/data/processed_data/\"\n", "s3_code_path=f\"s3://{bucket_name}/search_knn_blog/code/\"\n", "\n", "script_processor.run(job_name=preprocess_job_name,\n", " code='../src/preprocessing_main.py',\n", " inputs=[ProcessingInput(\n", " source=s3_input_path,\n", " destination='/opt/ml/processing/input'),\n", " ProcessingInput(\n", " source=s3_code_path,\n", " destination='/opt/ml/processing/input/code/search_utils/')],\n", " outputs=[\n", "ProcessingOutput(destination=f\"s3://{bucket_name}/search_knn_blog/sagemaker-runs/{preprocess_job_name}/\",\n", " output_name='train_textual',\n", " source='/opt/ml/processing/train_textual'),\n", "ProcessingOutput(destination=f\"s3://{bucket_name}/search_knn_blog/sagemaker-runs/{preprocess_job_name}/\",\n", " output_name='test_textual',\n", " source='/opt/ml/processing/test_textual'),\n", "ProcessingOutput(destination=f\"s3://{bucket_name}/search_knn_blog/sagemaker-runs/{preprocess_job_name}/\",\n", " output_name='train_numerical',\n", " source='/opt/ml/processing/train_numerical'),\n", "ProcessingOutput(destination=f\"s3://{bucket_name}/search_knn_blog/sagemaker-runs/{preprocess_job_name}/\",\n", " output_name='test_numerical',\n", " source='/opt/ml/processing/test_numerical'),\n", "ProcessingOutput(destination=f\"s3://{bucket_name}/search_knn_blog/sagemaker-runs/{preprocess_job_name}/\",\n", " output_name='vocab',\n", " source='/opt/ml/processing/vocab'),\n", "ProcessingOutput(destination=f\"s3://{bucket_name}/search_knn_blog/sagemaker-runs/{preprocess_job_name}/\",\n", " output_name='raw_vocab',\n", " source='/opt/ml/processing/raw_vocab')\n", " ],\n", " arguments=['--train-test-split-ratio', '0.2','--total-nb-of-records', '10000'],wait=False)\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "jupyter": { "outputs_hidden": true } }, "outputs": [], "source": [ "status = boto3.client(\"sagemaker\").describe_processing_job(ProcessingJobName=preprocess_job_name)[\"ProcessingJobStatus\"]\n", "\n", "while status == 'InProgress': \n", " status = boto3.client(\"sagemaker\").describe_processing_job(ProcessingJobName=preprocess_job_name)[\"ProcessingJobStatus\"]\n", " print(status)\n", " time.sleep(30)\n", " continue" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 3. Glove embedding" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will be using the glove embedding to initiate the values of the word tokens. 
{ "cell_type": "markdown", "metadata": {}, "source": [ "# 3. GloVe embeddings" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "We will use the GloVe embeddings to initialize the word token embeddings. The GloVe embeddings are downloaded from https://nlp.stanford.edu/projects/glove/\n", "\n", "This data is made available under the Public Domain Dedication and License v1.0, whose full text can be found at: http://www.opendatacommons.org/licenses/pddl/1.0/.\n", "\n", "\n", "Let's start by pulling the GloVe embeddings locally, then pushing them to S3 using the following commands:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%bash\n", "# This might take a few minutes\n", "mkdir -p /tmp/GloVe\n", "curl -Lo /tmp/GloVe/glove.840B.zip http://nlp.stanford.edu/data/glove.840B.300d.zip\n", "unzip /tmp/GloVe/glove.840B.zip -d /tmp/GloVe/\n", "rm /tmp/GloVe/glove.840B.zip" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "boto3.client(\"s3\").upload_file(\"/tmp/GloVe/glove.840B.300d.txt\",\n", "                               bucket_name, \"search_knn_blog/artefacts/glove.840B.300d.txt\")" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "We can now create a processing job that parses the vocabulary generated in the previous section and outputs a trimmed version of the GloVe embeddings based on our vocabulary.\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sagemaker.processing import ScriptProcessor\n", "\n", "script_processor = ScriptProcessor(\n", "    image_uri=ecr_uri,\n", "    role=sagemaker.get_execution_role(),\n", "    instance_count=1,\n", "    instance_type='ml.m5.xlarge',\n", "    command=[\"python3\"])" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sagemaker.processing import ProcessingInput, ProcessingOutput\n", "\n", "now = datetime.datetime.utcnow()\n", "now_string = now.strftime('%y%m%d%H%M%S%f')\n", "run_id = now_string[:-2]\n", "print(f\"run id : {run_id}\")\n", "glove_job_name = f\"search-glove-{run_id}\"\n", "\n", "s3_code_path = f\"s3://{bucket_name}/search_knn_blog/code/\"\n", "\n", "script_processor.run(job_name=glove_job_name,\n", "    code='../src/glove_embeddings_main.py',\n", "    inputs=[ProcessingInput(\n", "                source=s3_code_path,\n", "                destination='/opt/ml/processing/input/code/search_utils/'),\n", "            ProcessingInput(\n", "                source=f\"s3://{bucket_name}/search_knn_blog/sagemaker-runs/{preprocess_job_name}/\",\n", "                destination='/opt/ml/processing/input_vocabulary'),\n", "            ProcessingInput(\n", "                source=f\"s3://{bucket_name}/search_knn_blog/artefacts/glove.840B.300d.txt\",\n", "                destination='/opt/ml/processing/input_glove')],\n", "    outputs=[\n", "        ProcessingOutput(destination=f\"s3://{bucket_name}/search_knn_blog/sagemaker-runs/{glove_job_name}/\",\n", "                         output_name='trimmed_glove',\n", "                         source='/opt/ml/processing/trimmed_glove'),\n", "        ProcessingOutput(destination=f\"s3://{bucket_name}/search_knn_blog/sagemaker-runs/{glove_job_name}/\",\n", "                         output_name='vocab',\n", "                         source='/opt/ml/processing/vocab')],\n", "    arguments=['--train-test-split-ratio', '0.2'],\n", "    wait=False)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "jupyter": { "outputs_hidden": true } }, "outputs": [], "source": [ "sm_client = boto3.client(\"sagemaker\")\n", "status = sm_client.describe_processing_job(ProcessingJobName=glove_job_name)[\"ProcessingJobStatus\"]\n", "\n", "while status == 'InProgress':\n", "    print(status)\n", "    time.sleep(30)\n", "    status = sm_client.describe_processing_job(ProcessingJobName=glove_job_name)[\"ProcessingJobStatus\"]\n", "print(status)" ] },
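{ "cell_type": "markdown", "metadata": {}, "source": [ "For reference, the trimming performed by the job above boils down to keeping only the GloVe lines whose token appears in our vocabulary, which shrinks the embedding file the training job has to load. The cell below is a minimal illustrative sketch; the helper function and its file handling are hypothetical, not the job's actual code." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#Illustrative sketch of the trimming idea (hypothetical helper, not the job's code)\n", "def trim_glove(glove_path, vocab, trimmed_path):\n", "    #Each GloVe line is \"<token> <300 floats>\"; keep it only if the token is in vocab\n", "    with open(glove_path) as src, open(trimmed_path, \"w\") as dst:\n", "        for line in src:\n", "            token = line.split(\" \", 1)[0]\n", "            if token in vocab:\n", "                dst.write(line)" ] },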
"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(f\"This is the processing job name you will need during inference : {glove_job_name}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 4. Training " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "word_to_id = helpers.read_json_from_s3(bucket_name, f\"search_knn_blog/sagemaker-runs/{glove_job_name}/vocab.json\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "jupyter": { "outputs_hidden": true } }, "outputs": [], "source": [ "from sagemaker.amazon.amazon_estimator import get_image_uri\n", "from sagemaker.session import s3_input\n", "\n", "now = datetime.datetime.utcnow()\n", "now_string = now.strftime('%y%m%d%H%M%S%f')\n", "run_id = now_string[:-2]\n", "print(f\"run id : {run_id}\")\n", "\n", "training_job_name = f\"search-training-{run_id}\"\n", "output_path = os.path.join(f\"s3://{bucket_name}/search_knn_blog/sagemaker-runs\", training_job_name)\n", "\n", "regressor = sagemaker.estimator.Estimator(get_image_uri(boto3.Session().region_name, 'object2vec'),\n", " sagemaker.get_execution_role(), \n", " train_instance_count=1, \n", " train_instance_type='ml.p3.8xlarge',\n", " output_path=output_path,\n", " sagemaker_session=sagemaker.Session())\n", "\n", "\n", "hyperparameters = {\n", " \"enc_dim\": 512, #The dimension of the output of the embedding layer.\n", " \"mlp_dim\": 256, #The dimension of the output from MLP layers.\n", " \"mlp_activation\": \"linear\",\n", " \"mlp_layers\": 2,\n", " \n", " \"output_layer\" : \"softmax\",#classification task\n", " \"num_classes\": 2,#0 and 1\n", "\n", " \"optimizer\" : \"adam\",\n", " \"learning_rate\" : 0.0004,\n", " \"mini_batch_size\": 256,\n", " \"epochs\" : 20,\n", "\n", " \"enc0_max_seq_len\": 200,\n", " \"enc1_max_seq_len\": 200,\n", "\n", " \"enc0_network\": \"bilstm\", #The network model for the enc0 encoder.\n", " \"enc1_network\": \"enc0\", #same as enc0_network\n", "\n", " \"enc0_token_embedding_dim\": 300, #The output dimension of the enc0 token embedding layer.\n", " \"enc1_token_embedding_dim\": 300, #The output dimension of the enc1 token embedding layer.\n", " \n", " \"enc0_vocab_file\" : \"vocab.json\", #The vocabulary file for mapping pretrained enc0 token embedding vectors to numerical vocabulary IDs.\n", " \"enc1_vocab_file\" : \"vocab.json\", #same as enc0_vocab_file\n", "\n", " \"enc0_vocab_size\" : len(word_to_id),#The vocabulary size of enc0 tokens.\n", " \"enc1_vocab_size\" : len(word_to_id),#The vocabulary size of enc1 tokens.\n", " \n", " \"enc0_pretrained_embedding_file\" : \"trimmed_glove.txt\",\n", " \"enc1_pretrained_embedding_file\" : \"trimmed_glove.txt\"\n", " \n", "}\n", "\n", "input_channels = {}\n", "s3_client = boto3.client('s3')\n", "\n", "input_channels[\"train\"] = s3_input(os.path.join(f\"s3://{bucket_name}/search_knn_blog/sagemaker-runs\",\\\n", " preprocess_job_name,\n", " \"numerical_train_data.jsonl\"),\n", " distribution='FullyReplicated', \n", " content_type='application/jsonlines')\n", "\n", "input_channels[\"test\"] = s3_input(os.path.join(f\"s3://{bucket_name}/search_knn_blog/sagemaker-runs\",\\\n", " preprocess_job_name,\n", " \"numerical_test_data.jsonl\"),\n", " distribution='FullyReplicated', \n", " content_type='application/jsonlines')\n", "\n", "input_channels['auxiliary'] = s3_input(os.path.join(f\"s3://{bucket_name}/search_knn_blog/sagemaker-runs\",\\\n", " glove_job_name), \n", " distribution='FullyReplicated', 
{ "cell_type": "code", "execution_count": null, "metadata": { "jupyter": { "outputs_hidden": true } }, "outputs": [], "source": [ "from sagemaker.amazon.amazon_estimator import get_image_uri\n", "from sagemaker.session import s3_input\n", "\n", "now = datetime.datetime.utcnow()\n", "now_string = now.strftime('%y%m%d%H%M%S%f')\n", "run_id = now_string[:-2]\n", "print(f\"run id : {run_id}\")\n", "\n", "training_job_name = f\"search-training-{run_id}\"\n", "output_path = os.path.join(f\"s3://{bucket_name}/search_knn_blog/sagemaker-runs\", training_job_name)\n", "\n", "regressor = sagemaker.estimator.Estimator(get_image_uri(boto3.Session().region_name, 'object2vec'),\n", "                                          sagemaker.get_execution_role(),\n", "                                          train_instance_count=1,\n", "                                          train_instance_type='ml.p3.8xlarge',\n", "                                          output_path=output_path,\n", "                                          sagemaker_session=sagemaker.Session())\n", "\n", "hyperparameters = {\n", "    \"enc_dim\": 512,  #The dimension of the output of the embedding layer\n", "    \"mlp_dim\": 256,  #The dimension of the output from the MLP layers\n", "    \"mlp_activation\": \"linear\",\n", "    \"mlp_layers\": 2,\n", "\n", "    \"output_layer\": \"softmax\",  #classification task\n", "    \"num_classes\": 2,  #0 and 1\n", "\n", "    \"optimizer\": \"adam\",\n", "    \"learning_rate\": 0.0004,\n", "    \"mini_batch_size\": 256,\n", "    \"epochs\": 20,\n", "\n", "    \"enc0_max_seq_len\": 200,\n", "    \"enc1_max_seq_len\": 200,\n", "\n", "    \"enc0_network\": \"bilstm\",  #The network model for the enc0 encoder\n", "    \"enc1_network\": \"enc0\",  #use the same network as enc0\n", "\n", "    \"enc0_token_embedding_dim\": 300,  #The output dimension of the enc0 token embedding layer\n", "    \"enc1_token_embedding_dim\": 300,  #The output dimension of the enc1 token embedding layer\n", "\n", "    \"enc0_vocab_file\": \"vocab.json\",  #The vocabulary file mapping pretrained enc0 token embedding vectors to vocabulary IDs\n", "    \"enc1_vocab_file\": \"vocab.json\",  #same as enc0_vocab_file\n", "\n", "    \"enc0_vocab_size\": len(word_to_id),  #The vocabulary size of enc0 tokens\n", "    \"enc1_vocab_size\": len(word_to_id),  #The vocabulary size of enc1 tokens\n", "\n", "    \"enc0_pretrained_embedding_file\": \"trimmed_glove.txt\",\n", "    \"enc1_pretrained_embedding_file\": \"trimmed_glove.txt\"\n", "}\n", "\n", "input_channels = {}\n", "\n", "input_channels[\"train\"] = s3_input(os.path.join(f\"s3://{bucket_name}/search_knn_blog/sagemaker-runs\",\n", "                                                preprocess_job_name,\n", "                                                \"numerical_train_data.jsonl\"),\n", "                                   distribution='FullyReplicated',\n", "                                   content_type='application/jsonlines')\n", "\n", "input_channels[\"test\"] = s3_input(os.path.join(f\"s3://{bucket_name}/search_knn_blog/sagemaker-runs\",\n", "                                               preprocess_job_name,\n", "                                               \"numerical_test_data.jsonl\"),\n", "                                  distribution='FullyReplicated',\n", "                                  content_type='application/jsonlines')\n", "\n", "input_channels['auxiliary'] = s3_input(os.path.join(f\"s3://{bucket_name}/search_knn_blog/sagemaker-runs\",\n", "                                                    glove_job_name),\n", "                                       distribution='FullyReplicated',\n", "                                       content_type='application/json')" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "regressor.set_hyperparameters(**hyperparameters)\n", "regressor.fit(input_channels, job_name=training_job_name, wait=False)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "jupyter": { "outputs_hidden": true } }, "outputs": [], "source": [ "sm_client = boto3.client(\"sagemaker\")\n", "status = sm_client.describe_training_job(TrainingJobName=training_job_name)[\"TrainingJobStatus\"]\n", "\n", "while status == 'InProgress':\n", "    print(status)\n", "    time.sleep(30)\n", "    status = sm_client.describe_training_job(TrainingJobName=training_job_name)[\"TrainingJobStatus\"]\n", "print(status)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(f\"This is the training job name you will need during inference: {training_job_name}\")" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "final_metrics = boto3.client(\"sagemaker\").describe_training_job(TrainingJobName=training_job_name)[\"FinalMetricDataList\"]\n", "dict_metrics = {metric[\"MetricName\"]: metric[\"Value\"] for metric in final_metrics}" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dict_metrics" ] } ], "metadata": { "kernelspec": { "display_name": "conda_python3", "language": "python", "name": "conda_python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.10" } }, "nbformat": 4, "nbformat_minor": 4 }