{ "cells": [ { "cell_type": "markdown", "id": "timely-realtor", "metadata": { "papermill": { "duration": 0.018003, "end_time": "2021-06-03T00:09:48.368659", "exception": false, "start_time": "2021-06-03T00:09:48.350656", "status": "completed" }, "tags": [] }, "source": [ "# A/B Testing with Amazon SageMaker\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "16cdb56b", "metadata": {}, "source": [ "---\n", "\n", "This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook. \n", "\n", "![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-west-2/sagemaker_endpoints|a_b_testing|a_b_testing.ipynb)\n", "\n", "---" ] }, { "cell_type": "markdown", "id": "f97bf717", "metadata": { "papermill": { "duration": 0.018003, "end_time": "2021-06-03T00:09:48.368659", "exception": false, "start_time": "2021-06-03T00:09:48.350656", "status": "completed" }, "tags": [] }, "source": [ "\n", "In production ML workflows, data scientists and data engineers frequently try to improve their models in various ways, such as by performing [automatic model tuning](https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning.html), training on additional or more recent data, and improving feature selection. Performing A/B testing between a new model and an old model with production traffic can be an effective final step in the validation process for a new model. In A/B testing, you test different variants of your models and compare how each variant performs relative to the others. 
You then choose the best-performing model: if the new version delivers better performance than the previously existing version, it replaces the old model.\n", "\n", "Amazon SageMaker enables you to test multiple models or model versions behind the same endpoint using production variants. Each production variant identifies a machine learning (ML) model and the resources deployed for hosting the model. You can distribute endpoint invocation requests across multiple production variants by providing the traffic distribution for each variant, or you can invoke a specific variant directly for each request.\n", "\n", "In this notebook we'll:\n", "* Evaluate models by invoking specific variants\n", "* Gradually release a new model by specifying traffic distribution" ] }, { "cell_type": "markdown", "id": "czech-deadline", "metadata": { "papermill": { "duration": 0.018034, "end_time": "2021-06-03T00:09:48.404973", "exception": false, "start_time": "2021-06-03T00:09:48.386939", "status": "completed" }, "tags": [] }, "source": [ "### Prerequisites\n", "\n", "First we upgrade the AWS CLI, which also pulls in a recent version of botocore that includes the latest SageMaker features:" ] }, { "cell_type": "code", "execution_count": 1, "id": "significant-budget", "metadata": { "collapsed": true, "execution": { "iopub.execute_input": "2021-06-03T00:09:48.445522Z", "iopub.status.busy": "2021-06-03T00:09:48.445063Z", "iopub.status.idle": "2021-06-03T00:10:01.618823Z", "shell.execute_reply": "2021-06-03T00:10:01.618403Z" }, "jupyter": { "outputs_hidden": true }, "papermill": { "duration": 13.194939, "end_time": "2021-06-03T00:10:01.618937", "exception": false, "start_time": "2021-06-03T00:09:48.423998", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Collecting awscli\n", " Downloading awscli-1.19.86.tar.gz (1.4 MB)\n", "\u001b[K |████████████████████████████████| 1.4 MB 8.0 MB/s eta 0:00:01\n", "\u001b[?25hCollecting botocore==1.20.86\n", " 
Downloading botocore-1.20.86-py2.py3-none-any.whl (7.6 MB)\n", "\u001b[K |████████████████████████████████| 7.6 MB 14.8 MB/s eta 0:00:01\n", "\u001b[?25hCollecting docutils<0.16,>=0.10\n", " Downloading docutils-0.15.2-py3-none-any.whl (547 kB)\n", "\u001b[K |████████████████████████████████| 547 kB 100.4 MB/s eta 0:00:01\n", "\u001b[?25hCollecting s3transfer<0.5.0,>=0.4.0\n", " Downloading s3transfer-0.4.2-py2.py3-none-any.whl (79 kB)\n", "\u001b[K |████████████████████████████████| 79 kB 17.2 MB/s eta 0:00:01\n", "\u001b[?25hRequirement already satisfied, skipping upgrade: PyYAML<5.5,>=3.10 in /home/ubuntu/anaconda3/envs/python3/lib/python3.6/site-packages (from awscli) (5.3.1)\n", "Requirement already satisfied, skipping upgrade: colorama<0.4.4,>=0.2.5 in /home/ubuntu/anaconda3/envs/python3/lib/python3.6/site-packages (from awscli) (0.4.3)\n", "Requirement already satisfied, skipping upgrade: rsa<4.8,>=3.1.2 in /home/ubuntu/anaconda3/envs/python3/lib/python3.6/site-packages (from awscli) (4.6)\n", "Requirement already satisfied, skipping upgrade: jmespath<1.0.0,>=0.7.1 in /home/ubuntu/anaconda3/envs/python3/lib/python3.6/site-packages (from botocore==1.20.86->awscli) (0.10.0)\n", "Requirement already satisfied, skipping upgrade: python-dateutil<3.0.0,>=2.1 in /home/ubuntu/anaconda3/envs/python3/lib/python3.6/site-packages (from botocore==1.20.86->awscli) (2.8.1)\n", "Requirement already satisfied, skipping upgrade: urllib3<1.27,>=1.25.4 in /home/ubuntu/anaconda3/envs/python3/lib/python3.6/site-packages (from botocore==1.20.86->awscli) (1.25.10)\n", "Requirement already satisfied, skipping upgrade: pyasn1>=0.1.3 in /home/ubuntu/anaconda3/envs/python3/lib/python3.6/site-packages (from rsa<4.8,>=3.1.2->awscli) (0.4.8)\n", "Requirement already satisfied, skipping upgrade: six>=1.5 in /home/ubuntu/anaconda3/envs/python3/lib/python3.6/site-packages (from python-dateutil<3.0.0,>=2.1->botocore==1.20.86->awscli) (1.14.0)\n", "Building wheels for collected packages: 
awscli\n", " Building wheel for awscli (setup.py) ... \u001b[?25ldone\n", "\u001b[?25h Created wheel for awscli: filename=awscli-1.19.86-py2.py3-none-any.whl size=3627022 sha256=937748476fe20b446ff72e96e916129e6fe88a365e4332400821c598cd86fdb0\n", " Stored in directory: /home/ubuntu/.cache/pip/wheels/87/ec/5f/35e096e78927f9844f5d13163d29f38e57474aabfed2fe76e9\n", "Successfully built awscli\n", "\u001b[31mERROR: boto3 1.16.36 has requirement botocore<1.20.0,>=1.19.36, but you'll have botocore 1.20.86 which is incompatible.\u001b[0m\n", "\u001b[31mERROR: boto3 1.16.36 has requirement s3transfer<0.4.0,>=0.3.0, but you'll have s3transfer 0.4.2 which is incompatible.\u001b[0m\n", "Installing collected packages: botocore, docutils, s3transfer, awscli\n", " Attempting uninstall: botocore\n", " Found existing installation: botocore 1.19.36\n", " Uninstalling botocore-1.19.36:\n", " Successfully uninstalled botocore-1.19.36\n", " Attempting uninstall: docutils\n", " Found existing installation: docutils 0.16\n", " Uninstalling docutils-0.16:\n", " Successfully uninstalled docutils-0.16\n", " Attempting uninstall: s3transfer\n", " Found existing installation: s3transfer 0.3.3\n", " Uninstalling s3transfer-0.3.3:\n", " Successfully uninstalled s3transfer-0.3.3\n", "Successfully installed awscli-1.19.86 botocore-1.20.86 docutils-0.15.2 s3transfer-0.4.2\n" ] } ], "source": [ "!pip install -U awscli" ] }, { "cell_type": "markdown", "id": "narrow-direction", "metadata": { "papermill": { "duration": 0.0286, "end_time": "2021-06-03T00:10:01.676783", "exception": false, "start_time": "2021-06-03T00:10:01.648183", "status": "completed" }, "tags": [] }, "source": [ "## Configuration" ] }, { "cell_type": "markdown", "id": "flush-pencil", "metadata": { "papermill": { "duration": 0.028203, "end_time": "2021-06-03T00:10:01.733636", "exception": false, "start_time": "2021-06-03T00:10:01.705433", "status": "completed" }, "tags": [] }, "source": [ "Let's set up some required imports and basic 
initial variables:" ] }, { "cell_type": "code", "execution_count": 2, "id": "iraqi-drain", "metadata": { "execution": { "iopub.execute_input": "2021-06-03T00:10:01.796772Z", "iopub.status.busy": "2021-06-03T00:10:01.796260Z", "iopub.status.idle": "2021-06-03T00:10:03.618225Z", "shell.execute_reply": "2021-06-03T00:10:03.618603Z" }, "isConfigCell": true, "papermill": { "duration": 1.857027, "end_time": "2021-06-03T00:10:03.618747", "exception": false, "start_time": "2021-06-03T00:10:01.761720", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 966 ms, sys: 110 ms, total: 1.08 s\n", "Wall time: 1.82 s\n" ] } ], "source": [ "%%time\n", "%matplotlib inline\n", "from datetime import datetime, timedelta\n", "import time\n", "import os\n", "import boto3\n", "import re\n", "import json\n", "from sagemaker import get_execution_role, session\n", "from sagemaker.s3 import S3Downloader, S3Uploader\n", "\n", "region = boto3.Session().region_name\n", "role = get_execution_role()\n", "sm_session = session.Session(boto3.Session())\n", "sm = boto3.Session().client(\"sagemaker\")\n", "sm_runtime = boto3.Session().client(\"sagemaker-runtime\")\n", "\n", "# You can use a different bucket, but make sure the role you chose for this notebook\n", "# has the s3:PutObject permissions. This is the bucket into which the model artifacts will be uploaded\n", "bucket = sm_session.default_bucket()\n", "prefix = \"sagemaker/DEMO-VariantTargeting\"" ] }, { "cell_type": "markdown", "id": "floral-evidence", "metadata": { "papermill": { "duration": 0.028436, "end_time": "2021-06-03T00:10:03.675664", "exception": false, "start_time": "2021-06-03T00:10:03.647228", "status": "completed" }, "tags": [] }, "source": [ "## Step 1: Create and deploy the models\n", "\n", "### First, we upload our pre-trained models to Amazon S3\n", "This code uploads two pre-trained XGBoost models that are ready for you to deploy. 
These models were trained using the XGB Churn Prediction Notebook in SageMaker. You can also use your own pre-trained models in this step. If you already have a pretrained model in Amazon S3, you can add it by specifying the s3_key.\n", "\n", "The models in this example are used to predict the probability of a mobile customer leaving their current mobile operator. The dataset we use is publicly available and was mentioned in the book [Discovering Knowledge in Data](https://www.amazon.com/dp/0470908742/) by Daniel T. Larose. It is attributed by the author to the University of California Irvine Repository of Machine Learning Datasets." ] }, { "cell_type": "code", "execution_count": 3, "id": "accessory-batch", "metadata": { "execution": { "iopub.execute_input": "2021-06-03T00:10:03.744991Z", "iopub.status.busy": "2021-06-03T00:10:03.744242Z", "iopub.status.idle": "2021-06-03T00:10:04.088276Z", "shell.execute_reply": "2021-06-03T00:10:04.088661Z" }, "papermill": { "duration": 0.384792, "end_time": "2021-06-03T00:10:04.088803", "exception": false, "start_time": "2021-06-03T00:10:03.704011", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "text/plain": [ "('s3://sagemaker-us-west-2-688520471316/sagemaker/DEMO-VariantTargeting/xgb-churn-prediction-model.tar.gz',\n", " 's3://sagemaker-us-west-2-688520471316/sagemaker/DEMO-VariantTargeting/xgb-churn-prediction-model2.tar.gz')" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_url = S3Uploader.upload(\n", " local_path=\"model/xgb-churn-prediction-model.tar.gz\", desired_s3_uri=f\"s3://{bucket}/{prefix}\"\n", ")\n", "model_url2 = S3Uploader.upload(\n", " local_path=\"model/xgb-churn-prediction-model2.tar.gz\", desired_s3_uri=f\"s3://{bucket}/{prefix}\"\n", ")\n", "model_url, model_url2" ] }, { "cell_type": "markdown", "id": "superior-estate", "metadata": { "papermill": { "duration": 0.02864, "end_time": "2021-06-03T00:10:04.146237", "exception": false, 
"start_time": "2021-06-03T00:10:04.117597", "status": "completed" }, "tags": [] }, "source": [ "### Next, we create our model definitions\n", "Start with deploying the pre-trained churn prediction models. Here, you create the model objects with the image and model data." ] }, { "cell_type": "code", "execution_count": 5, "id": "tracked-pocket", "metadata": { "execution": { "iopub.execute_input": "2021-06-03T00:10:04.217051Z", "iopub.status.busy": "2021-06-03T00:10:04.216261Z", "iopub.status.idle": "2021-06-03T00:10:05.148114Z", "shell.execute_reply": "2021-06-03T00:10:05.148476Z" }, "papermill": { "duration": 0.973816, "end_time": "2021-06-03T00:10:05.148613", "exception": false, "start_time": "2021-06-03T00:10:04.174797", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "text/plain": [ "'DEMO-xgb-churn-pred2-2021-06-03-18-12-17'" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sagemaker.image_uris import retrieve\n", "\n", "\n", "model_name = f\"DEMO-xgb-churn-pred-{datetime.now():%Y-%m-%d-%H-%M-%S}\"\n", "model_name2 = f\"DEMO-xgb-churn-pred2-{datetime.now():%Y-%m-%d-%H-%M-%S}\"\n", "image_uri = retrieve(\"xgboost\", boto3.Session().region_name, \"0.90-1\")\n", "image_uri2 = retrieve(\"xgboost\", boto3.Session().region_name, \"0.90-2\")\n", "\n", "sm_session.create_model(\n", " name=model_name, role=role, container_defs={\"Image\": image_uri, \"ModelDataUrl\": model_url}\n", ")\n", "\n", "sm_session.create_model(\n", " name=model_name2, role=role, container_defs={\"Image\": image_uri2, \"ModelDataUrl\": model_url2}\n", ")" ] }, { "cell_type": "markdown", "id": "compressed-payment", "metadata": { "papermill": { "duration": 0.029456, "end_time": "2021-06-03T00:10:05.207592", "exception": false, "start_time": "2021-06-03T00:10:05.178136", "status": "completed" }, "tags": [] }, "source": [ "### Create variants\n", "\n", "We now create two variants, each with its own different model (these could also have 
different instance types and counts).\n", "\n", "We set an initial_weight of “1” for both variants: this means 50% of our requests go to Variant1, and the remaining 50% of all requests to Variant2. (The sum of weights across both variants is 2 and each variant has a weight assignment of 1. This implies each variant receives 1/2, or 50%, of the total traffic.)" ] }, { "cell_type": "code", "execution_count": 6, "id": "certified-iceland", "metadata": { "execution": { "iopub.execute_input": "2021-06-03T00:10:05.271661Z", "iopub.status.busy": "2021-06-03T00:10:05.271037Z", "iopub.status.idle": "2021-06-03T00:10:05.273524Z", "shell.execute_reply": "2021-06-03T00:10:05.273874Z" }, "papermill": { "duration": 0.037206, "end_time": "2021-06-03T00:10:05.274004", "exception": false, "start_time": "2021-06-03T00:10:05.236798", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "text/plain": [ "({'ModelName': 'DEMO-xgb-churn-pred-2021-06-03-18-12-17',\n", " 'InstanceType': 'ml.m5.xlarge',\n", " 'InitialInstanceCount': 1,\n", " 'VariantName': 'Variant1',\n", " 'InitialVariantWeight': 1},\n", " {'ModelName': 'DEMO-xgb-churn-pred2-2021-06-03-18-12-17',\n", " 'InstanceType': 'ml.m5.xlarge',\n", " 'InitialInstanceCount': 1,\n", " 'VariantName': 'Variant2',\n", " 'InitialVariantWeight': 1})" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sagemaker.session import production_variant\n", "\n", "variant1 = production_variant(\n", "    model_name=model_name,\n", "    instance_type=\"ml.c5.4xlarge\",\n", "    initial_instance_count=1,\n", "    variant_name=\"Variant1\",\n", "    initial_weight=1,\n", ")\n", "variant2 = production_variant(\n", "    model_name=model_name2,\n", "    instance_type=\"ml.c5.4xlarge\",\n", "    initial_instance_count=1,\n", "    variant_name=\"Variant2\",\n", "    initial_weight=1,\n", ")\n", "\n", "(variant1, variant2)" ] }, { "cell_type": "markdown", "id": "different-lending", "metadata": { "papermill": { "duration": 0.029624, 
"end_time": "2021-06-03T00:10:05.333520", "exception": false, "start_time": "2021-06-03T00:10:05.303896", "status": "completed" }, "tags": [] }, "source": [ "### Deploy\n", "\n", "Let's go ahead and deploy our two variants to a SageMaker endpoint:" ] }, { "cell_type": "code", "execution_count": 7, "id": "wrong-building", "metadata": { "execution": { "iopub.execute_input": "2021-06-03T00:10:05.398312Z", "iopub.status.busy": "2021-06-03T00:10:05.397790Z", "iopub.status.idle": "2021-06-03T00:13:06.360172Z", "shell.execute_reply": "2021-06-03T00:13:06.359350Z" }, "papermill": { "duration": 180.997109, "end_time": "2021-06-03T00:13:06.360381", "exception": true, "start_time": "2021-06-03T00:10:05.363272", "status": "failed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "EndpointName=DEMO-xgb-churn-pred-2021-06-03-18-13-07\n", "-------------!" ] }, { "data": { "text/plain": [ "'DEMO-xgb-churn-pred-2021-06-03-18-13-07'" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "endpoint_name = f\"DEMO-xgb-churn-pred-{datetime.now():%Y-%m-%d-%H-%M-%S}\"\n", "print(f\"EndpointName={endpoint_name}\")\n", "\n", "sm_session.endpoint_from_production_variants(\n", " name=endpoint_name, production_variants=[variant1, variant2]\n", ")" ] }, { "cell_type": "markdown", "id": "authentic-visitor", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "source": [ "## Step 2: Invoke the deployed models\n", "\n", "You can now send data to this endpoint to get inferences in real time.\n", "\n", "This step invokes the endpoint with included sample data for about 2 minutes. 
" ] }, { "cell_type": "code", "execution_count": 8, "id": "latest-woman", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sending test traffic to the endpoint DEMO-xgb-churn-pred-2021-06-03-18-13-07. \n", "Please wait...\n", "........................................................................................................................Done!\n" ] } ], "source": [ "# get a subset of test data for a quick test\n", "!tail -120 test_data/test-dataset-input-cols.csv > test_data/test_sample_tail_input_cols.csv\n", "print(f\"Sending test traffic to the endpoint {endpoint_name}. \\nPlease wait...\")\n", "\n", "with open(\"test_data/test_sample_tail_input_cols.csv\", \"r\") as f:\n", "    for row in f:\n", "        print(\".\", end=\"\", flush=True)\n", "        payload = row.rstrip(\"\\n\")\n", "        sm_runtime.invoke_endpoint(EndpointName=endpoint_name, ContentType=\"text/csv\", Body=payload)\n", "        time.sleep(0.5)\n", "\n", "print(\"Done!\")" ] }, { "cell_type": "markdown", "id": "resistant-detective", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "source": [ "### Invocations per variant\n", "\n", "Amazon SageMaker emits metrics such as Latency and Invocations (full list of metrics [here](https://docs.aws.amazon.com/sagemaker/latest/dg/monitoring-cloudwatch.html)) for each variant in Amazon CloudWatch. 
Let’s query CloudWatch to get the number of invocations per variant, to show how invocations are split across variants:" ] }, { "cell_type": "code", "execution_count": null, "id": "dimensional-crest", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "outputs": [], "source": [ "import pandas as pd\n", "\n", "cw = boto3.Session().client(\"cloudwatch\")\n", "\n", "\n", "def get_invocation_metrics_for_endpoint_variant(endpoint_name, variant_name, start_time, end_time):\n", "    metrics = cw.get_metric_statistics(\n", "        Namespace=\"AWS/SageMaker\",\n", "        MetricName=\"Invocations\",\n", "        StartTime=start_time,\n", "        EndTime=end_time,\n", "        Period=60,\n", "        Statistics=[\"Sum\"],\n", "        Dimensions=[\n", "            {\"Name\": \"EndpointName\", \"Value\": endpoint_name},\n", "            {\"Name\": \"VariantName\", \"Value\": variant_name},\n", "        ],\n", "    )\n", "    return (\n", "        pd.DataFrame(metrics[\"Datapoints\"])\n", "        .sort_values(\"Timestamp\")\n", "        .set_index(\"Timestamp\")\n", "        .drop(\"Unit\", axis=1)\n", "        .rename(columns={\"Sum\": variant_name})\n", "    )\n", "\n", "\n", "def plot_endpoint_metrics(start_time=None):\n", "    start_time = start_time or datetime.now() - timedelta(minutes=60)\n", "    end_time = datetime.now()\n", "    metrics_variant1 = get_invocation_metrics_for_endpoint_variant(\n", "        endpoint_name, variant1[\"VariantName\"], start_time, end_time\n", "    )\n", "    metrics_variant2 = get_invocation_metrics_for_endpoint_variant(\n", "        endpoint_name, variant2[\"VariantName\"], start_time, end_time\n", "    )\n", "    metrics_variants = metrics_variant1.join(metrics_variant2, how=\"outer\")\n", "    metrics_variants.plot()\n", "    return metrics_variants" ] }, { "cell_type": "code", "execution_count": null, "id": "framed-foundation", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "outputs": [], "source": [ "print(\"Waiting 
a minute for initial metric creation...\")\n", "time.sleep(60)\n", "plot_endpoint_metrics()" ] }, { "cell_type": "markdown", "id": "entitled-colors", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "source": [ "### Invoke a specific variant\n", "\n", "Now, let’s invoke a specific variant directly. To do this, we pass the TargetVariant parameter to invoke_endpoint, naming the ProductionVariant we want to invoke. Let us use this to invoke Variant1 for all requests." ] }, { "cell_type": "code", "execution_count": null, "id": "received-damage", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "outputs": [], "source": [ "import numpy as np\n", "\n", "predictions = \"\"\n", "\n", "print(f\"Sending test traffic to the endpoint {endpoint_name}. \\nPlease wait...\")\n", "with open(\"test_data/test_sample_tail_input_cols.csv\", \"r\") as f:\n", "    for row in f:\n", "        print(\".\", end=\"\", flush=True)\n", "        payload = row.rstrip(\"\\n\")\n", "        response = sm_runtime.invoke_endpoint(\n", "            EndpointName=endpoint_name,\n", "            ContentType=\"text/csv\",\n", "            Body=payload,\n", "            TargetVariant=variant1[\"VariantName\"],\n", "        )\n", "        predictions = \",\".join([predictions, response[\"Body\"].read().decode(\"utf-8\")])\n", "        time.sleep(0.5)\n", "\n", "# Convert our predictions to a numpy array\n", "pred_np = np.fromstring(predictions[1:], sep=\",\")\n", "\n", "# Convert the prediction probabilities to binary predictions of either 1 or 0\n", "threshold = 0.5\n", "preds = np.where(pred_np > threshold, 1, 0)\n", "print(\"Done!\")" ] }, { "cell_type": "markdown", "id": "assured-twenty", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "source": [ "When we again check the traffic per variant, 
this time we see that the number of invocations only incremented for Variant1, because all invocations were targeted at that variant:" ] }, { "cell_type": "code", "execution_count": null, "id": "passive-nation", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "outputs": [], "source": [ "time.sleep(20) # let metrics catch up\n", "plot_endpoint_metrics()" ] }, { "cell_type": "markdown", "id": "closed-square", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "source": [ "## Step 3: Evaluate variant performance\n", "\n", "### Evaluating Variant 1\n", "\n", "Using the new targeting feature, let us evaluate the accuracy, precision, recall, F1 score, and ROC/AUC for Variant1:" ] }, { "cell_type": "code", "execution_count": null, "id": "listed-hormone", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import pandas as pd\n", "from sklearn import metrics\n", "from sklearn.metrics import roc_auc_score\n", "\n", "# Let's get the labels of our test set; we will use these to evaluate our predictions\n", "!tail -121 test_data/test-dataset.csv > test_data/test_dataset_sample_tail.csv\n", "df_with_labels = pd.read_csv(\"test_data/test_dataset_sample_tail.csv\")\n", "test_labels = df_with_labels.iloc[:, 0]\n", "labels = test_labels.to_numpy()\n", "\n", "# Calculate accuracy\n", "accuracy = sum(preds == labels) / len(labels)\n", "print(f\"Accuracy: {accuracy}\")\n", "\n", "# Calculate precision\n", "precision = sum(preds[preds == 1] == labels[preds == 1]) / len(preds[preds == 1])\n", "print(f\"Precision: {precision}\")\n", "\n", "# Calculate recall\n", "recall = sum(preds[preds == 1] == labels[preds == 1]) / len(labels[labels == 1])\n", 
"print(f\"Recall: {recall}\")\n", "\n", "# Calculate F1 score\n", "f1_score = 2 * (precision * recall) / (precision + recall)\n", "print(f\"F1 Score: {f1_score}\")\n", "\n", "# Calculate AUC\n", "auc = round(roc_auc_score(labels, preds), 4)\n", "print(\"AUC is \" + repr(auc))\n", "\n", "fpr, tpr, _ = metrics.roc_curve(labels, preds)\n", "\n", "plt.title(\"ROC Curve\")\n", "plt.plot(fpr, tpr, \"b\", label=\"AUC = %0.2f\" % auc)\n", "plt.legend(loc=\"lower right\")\n", "plt.plot([0, 1], [0, 1], \"r--\")\n", "plt.xlim([-0.1, 1.1])\n", "plt.ylim([-0.1, 1.1])\n", "plt.ylabel(\"True Positive Rate\")\n", "plt.xlabel(\"False Positive Rate\")\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "confirmed-belfast", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "source": [ "### Next, we collect data for Variant2" ] }, { "cell_type": "code", "execution_count": null, "id": "interim-turner", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "outputs": [], "source": [ "predictions2 = \"\"\n", "print(f\"Sending test traffic to the endpoint {endpoint_name}. 
\\nPlease wait...\")\n", "with open(\"test_data/test_sample_tail_input_cols.csv\", \"r\") as f:\n", "    for row in f:\n", "        print(\".\", end=\"\", flush=True)\n", "        payload = row.rstrip(\"\\n\")\n", "        response = sm_runtime.invoke_endpoint(\n", "            EndpointName=endpoint_name,\n", "            ContentType=\"text/csv\",\n", "            Body=payload,\n", "            TargetVariant=variant2[\"VariantName\"],\n", "        )\n", "        predictions2 = \",\".join([predictions2, response[\"Body\"].read().decode(\"utf-8\")])\n", "        time.sleep(0.5)\n", "\n", "# Convert to numpy array\n", "pred_np2 = np.fromstring(predictions2[1:], sep=\",\")\n", "\n", "# Convert to binary predictions\n", "threshold = 0.5\n", "preds2 = np.where(pred_np2 > threshold, 1, 0)\n", "\n", "print(\"Done!\")" ] }, { "cell_type": "markdown", "id": "ethical-cardiff", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "source": [ "When we again check the traffic per variant, this time we see that the number of invocations only incremented for Variant2, because all invocations were targeted at that variant:" ] }, { "cell_type": "code", "execution_count": null, "id": "dense-emperor", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "outputs": [], "source": [ "time.sleep(60) # give metrics time to catch up\n", "plot_endpoint_metrics()" ] }, { "cell_type": "markdown", "id": "useful-shoot", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "source": [ "### Evaluating Variant2 " ] }, { "cell_type": "code", "execution_count": null, "id": "creative-positive", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "outputs": [], "source": [ "# Calculate accuracy\n", "accuracy2 = sum(preds2 == labels) / 
len(labels)\n", "print(f\"Accuracy: {accuracy2}\")\n", "\n", "# Calculate precision\n", "precision2 = sum(preds2[preds2 == 1] == labels[preds2 == 1]) / len(preds2[preds2 == 1])\n", "print(f\"Precision: {precision2}\")\n", "\n", "# Calculate recall\n", "recall2 = sum(preds2[preds2 == 1] == labels[preds2 == 1]) / len(labels[labels == 1])\n", "print(f\"Recall: {recall2}\")\n", "\n", "# Calculate F1 score\n", "f1_score2 = 2 * (precision2 * recall2) / (precision2 + recall2)\n", "print(f\"F1 Score: {f1_score2}\")\n", "\n", "auc2 = round(roc_auc_score(labels, preds2), 4)\n", "print(\"AUC is \" + repr(auc2))\n", "\n", "fpr2, tpr2, _ = metrics.roc_curve(labels, preds2)\n", "\n", "plt.title(\"ROC Curve\")\n", "plt.plot(fpr2, tpr2, \"b\", label=\"AUC = %0.2f\" % auc2)\n", "plt.legend(loc=\"lower right\")\n", "plt.plot([0, 1], [0, 1], \"r--\")\n", "plt.xlim([-0.1, 1.1])\n", "plt.ylim([-0.1, 1.1])\n", "plt.ylabel(\"True Positive Rate\")\n", "plt.xlabel(\"False Positive Rate\")\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "colored-attention", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "source": [ "We see that Variant2 is performing better for most of our defined metrics, so this is the one we’re likely to choose to dial up in production." ] }, { "cell_type": "markdown", "id": "legislative-fence", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "source": [ "## Step 4: Dialing up our chosen variant in production\n", "\n", "Now that we have determined Variant2 to be better as compared to Variant1, we will shift more traffic to it. \n", "\n", "We can continue to use TargetVariant to continue invoking a chosen variant. A simpler approach is to update the weights assigned to each variant using UpdateEndpointWeightsAndCapacities. 
This changes the traffic distribution to your production variants without requiring updates to your endpoint. \n", "\n", "Recall our variant weights are as follows:" ] }, { "cell_type": "code", "execution_count": null, "id": "acute-henry", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "outputs": [], "source": [ "{\n", "    variant[\"VariantName\"]: variant[\"CurrentWeight\"]\n", "    for variant in sm.describe_endpoint(EndpointName=endpoint_name)[\"ProductionVariants\"]\n", "}" ] }, { "cell_type": "markdown", "id": "changed-sleep", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "source": [ "We'll first write a helper function to invoke our endpoint (a copy of what we did previously):" ] }, { "cell_type": "code", "execution_count": null, "id": "moving-homework", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "outputs": [], "source": [ "def invoke_endpoint_for_two_minutes():\n", "    with open(\"test_data/test-dataset-input-cols.csv\", \"r\") as f:\n", "        for row in f:\n", "            print(\".\", end=\"\", flush=True)\n", "            payload = row.rstrip(\"\\n\")\n", "            response = sm_runtime.invoke_endpoint(\n", "                EndpointName=endpoint_name, ContentType=\"text/csv\", Body=payload\n", "            )\n", "            response[\"Body\"].read()\n", "            time.sleep(1)" ] }, { "cell_type": "markdown", "id": "wound-biology", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "source": [ "We invoke our endpoint for a bit, to show the even split in invocations:" ] }, { "cell_type": "code", "execution_count": null, "id": "funky-replication", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": 
"pending" }, "tags": [] }, "outputs": [], "source": [ "invocation_start_time = datetime.now()\n", "invoke_endpoint_for_two_minutes()\n", "time.sleep(20) # give metrics time to catch up\n", "plot_endpoint_metrics(invocation_start_time)" ] }, { "cell_type": "markdown", "id": "prescribed-fourth", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "source": [ "Now let us shift 75% of the traffic to Variant2 by assigning new weights to each variant using UpdateEndpointWeightsAndCapacities. Amazon SageMaker will now send 75% of the inference requests to Variant2 and the remaining 25% of requests to Variant1. " ] }, { "cell_type": "code", "execution_count": null, "id": "tired-throat", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "outputs": [], "source": [ "sm.update_endpoint_weights_and_capacities(\n", "    EndpointName=endpoint_name,\n", "    DesiredWeightsAndCapacities=[\n", "        {\"DesiredWeight\": 25, \"VariantName\": variant1[\"VariantName\"]},\n", "        {\"DesiredWeight\": 75, \"VariantName\": variant2[\"VariantName\"]},\n", "    ],\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "configured-quantum", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "outputs": [], "source": [ "print(\"Waiting for update to complete\")\n", "while True:\n", "    status = sm.describe_endpoint(EndpointName=endpoint_name)[\"EndpointStatus\"]\n", "    if status in [\"InService\", \"Failed\"]:\n", "        print(\"Done\")\n", "        break\n", "    print(\".\", end=\"\", flush=True)\n", "    time.sleep(1)\n", "\n", "{\n", "    variant[\"VariantName\"]: variant[\"CurrentWeight\"]\n", "    for variant in sm.describe_endpoint(EndpointName=endpoint_name)[\"ProductionVariants\"]\n", "}" ] }, { "cell_type": "markdown", "id": "different-tennis", 
"metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "source": [ "Now let's check how that has impacted invocation metrics:" ] }, { "cell_type": "code", "execution_count": null, "id": "acute-doctrine", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "outputs": [], "source": [ "invoke_endpoint_for_two_minutes()\n", "time.sleep(20) # give metrics time to catch up\n", "plot_endpoint_metrics(invocation_start_time)" ] }, { "cell_type": "markdown", "id": "quantitative-column", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "source": [ "We can continue to monitor our metrics, and when we're satisfied with a variant's performance, we can route 100% of the traffic to that variant. We again use UpdateEndpointWeightsAndCapacities to update the traffic assignments for the variants. Here, the weight for Variant1 is set to 0 and the weight for Variant2 is set to 1, so Amazon SageMaker will send 100% of all inference requests to Variant2." 
] }, { "cell_type": "code", "execution_count": null, "id": "higher-designation", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "outputs": [], "source": [ "sm.update_endpoint_weights_and_capacities(\n", "    EndpointName=endpoint_name,\n", "    DesiredWeightsAndCapacities=[\n", "        {\"DesiredWeight\": 0, \"VariantName\": variant1[\"VariantName\"]},\n", "        {\"DesiredWeight\": 1, \"VariantName\": variant2[\"VariantName\"]},\n", "    ],\n", ")\n", "print(\"Waiting for update to complete\")\n", "while True:\n", "    status = sm.describe_endpoint(EndpointName=endpoint_name)[\"EndpointStatus\"]\n", "    if status in [\"InService\", \"Failed\"]:\n", "        print(\"Done\")\n", "        break\n", "    print(\".\", end=\"\", flush=True)\n", "    time.sleep(1)\n", "\n", "{\n", "    variant[\"VariantName\"]: variant[\"CurrentWeight\"]\n", "    for variant in sm.describe_endpoint(EndpointName=endpoint_name)[\"ProductionVariants\"]\n", "}" ] }, { "cell_type": "code", "execution_count": null, "id": "fitted-enzyme", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "outputs": [], "source": [ "invoke_endpoint_for_two_minutes()\n", "time.sleep(20) # give metrics time to catch up\n", "plot_endpoint_metrics(invocation_start_time)" ] }, { "cell_type": "markdown", "id": "loose-constraint", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "source": [ "The Amazon CloudWatch metrics for the total invocations of each variant show that all inference requests are now being processed by Variant2 and none are processed by Variant1.\n", "\n", "You can now safely update your endpoint and delete Variant1 from your endpoint. 
You can also continue testing new models in production by adding new variants to your endpoint and following steps 2 - 4. " ] }, { "cell_type": "markdown", "id": "binding-hardwood", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "source": [ "## Delete the endpoint\n", "\n", "If you do not plan to use this endpoint further, you should delete the endpoint to avoid incurring additional charges." ] }, { "cell_type": "code", "execution_count": null, "id": "supreme-samoa", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "scrolled": true, "tags": [] }, "outputs": [], "source": [ "sm_session.delete_endpoint(endpoint_name)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "4e9fe3c6", "metadata": {}, "source": [ "## Notebook CI Test Results\n", "\n", "This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.\n", "\n", "![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-east-1/sagemaker_endpoints|a_b_testing|a_b_testing.ipynb)\n", "\n", "![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-east-2/sagemaker_endpoints|a_b_testing|a_b_testing.ipynb)\n", "\n", "![This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-west-1/sagemaker_endpoints|a_b_testing|a_b_testing.ipynb)\n", "\n", "![This ca-central-1 badge failed to load. 
Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ca-central-1/sagemaker_endpoints|a_b_testing|a_b_testing.ipynb)\n", "\n", "![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/sa-east-1/sagemaker_endpoints|a_b_testing|a_b_testing.ipynb)\n", "\n", "![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-1/sagemaker_endpoints|a_b_testing|a_b_testing.ipynb)\n", "\n", "![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-2/sagemaker_endpoints|a_b_testing|a_b_testing.ipynb)\n", "\n", "![This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-3/sagemaker_endpoints|a_b_testing|a_b_testing.ipynb)\n", "\n", "![This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-central-1/sagemaker_endpoints|a_b_testing|a_b_testing.ipynb)\n", "\n", "![This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-north-1/sagemaker_endpoints|a_b_testing|a_b_testing.ipynb)\n", "\n", "![This ap-southeast-1 badge failed to load. 
Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-southeast-1/sagemaker_endpoints|a_b_testing|a_b_testing.ipynb)\n", "\n", "![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-southeast-2/sagemaker_endpoints|a_b_testing|a_b_testing.ipynb)\n", "\n", "![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-northeast-1/sagemaker_endpoints|a_b_testing|a_b_testing.ipynb)\n", "\n", "![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-northeast-2/sagemaker_endpoints|a_b_testing|a_b_testing.ipynb)\n", "\n", "![This ap-south-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-south-1/sagemaker_endpoints|a_b_testing|a_b_testing.ipynb)\n" ] } ], "metadata": { "anaconda-cloud": {}, "instance_type": "ml.t3.medium", "kernelspec": { "display_name": "Python 3 (Data Science 2.0)", "language": "python", "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-west-2:236514542706:image/sagemaker-data-science-38" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.13" }, "notice": "Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the \"License\"). 
You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.", "papermill": { "default_parameters": {}, "duration": 199.476853, "end_time": "2021-06-03T00:13:06.967499", "environment_variables": {}, "exception": true, "input_path": "a_b_testing.ipynb", "output_path": "/opt/ml/processing/output/a_b_testing-2021-06-03-00-05-59.ipynb", "parameters": { "kms_key": "arn:aws:kms:us-west-2:521695447989:key/6e9984db-50cf-4c7e-926c-877ec47a8b25" }, "start_time": "2021-06-03T00:09:47.490646", "version": "2.3.3" } }, "nbformat": 4, "nbformat_minor": 5 }