# Leverage deployment guardrails to update a Hugging Face SageMaker Inference endpoint using canary traffic shifting

***
This notebook is designed to run on the `Python 3 (Data Science 2.0)` kernel in Amazon SageMaker Studio.
***

This notebook is based on the [SageMaker GitHub examples](https://github.com/aws/amazon-sagemaker-examples/tree/main/sagemaker-inference-deployment-guardrails).

We will perform the following steps:
1. [Introduction](#Introduction)
2. [Setup](#Setup)
3. [Step 1: Deploy the models created in the previous notebooks](#Step-1:-Deploy-the-models-created-in-the-previous-notebooks)
4. [Step 2: Invoke Endpoint](#Step-2:-Invoke-Endpoint)
5. [Step 3: Create CloudWatch alarms to monitor Endpoint performance](#Step-3:-Create-CloudWatch-alarms-to-monitor-Endpoint-performance)
6. [Step 4: Update Endpoint with deployment configurations](#Step-4:-Update-Endpoint-with-deployment-configurations)

## Introduction

Deployment guardrails are a set of model deployment options in Amazon SageMaker Inference to update your machine learning models in production. Using the fully managed deployment guardrails, you can control the switch from the current model in production to a new one. Traffic shifting modes, such as canary and linear, give you granular control over the traffic shifting process from your current model to the new one during the course of the update. There are also built-in safeguards such as auto-rollbacks that help you catch issues early and take corrective action before they impact production.

SageMaker supports blue/green deployments with multiple traffic shifting modes. A traffic shifting mode is a configuration that specifies how endpoint traffic is routed to a new fleet containing your updates. The following traffic shifting modes provide you with different levels of control over the endpoint update process:

* **All-At-Once Traffic Shifting** : shifts all of your endpoint traffic from the blue fleet to the green fleet. Once the traffic has shifted to the green fleet, your pre-specified Amazon CloudWatch alarms begin monitoring the green fleet for a set amount of time (the “baking period”). If no alarms are triggered during the baking period, then the blue fleet is terminated.
* **Canary Traffic Shifting** : lets you shift one small portion of your traffic (a “canary”) to the green fleet and monitor it for a baking period. If the canary succeeds on the green fleet, then the rest of the traffic is shifted from the blue fleet to the green fleet before terminating the blue fleet.
* **Linear Traffic Shifting** : provides even more customization over how many traffic-shifting steps to make and what percentage of traffic to shift for each step. While canary shifting lets you shift traffic in two steps, linear shifting extends this to n linearly spaced steps.
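To make the difference between the three modes concrete, the sketch below computes the cumulative percentage of traffic on the green fleet after each shift. `green_traffic_schedule` is a hypothetical helper, not a SageMaker API, and it uses a simplified model: the canary step and the linear step size are both given as a percentage, and the last linear step simply shifts whatever traffic remains.

```python
from typing import List


def green_traffic_schedule(mode: str, step_percent: int = 100) -> List[int]:
    """Cumulative % of traffic on the green fleet after each shift (simplified model)."""
    if not 0 < step_percent <= 100:
        raise ValueError("step_percent must be in (0, 100]")
    if mode == "ALL_AT_ONCE":
        # a single flip: all traffic moves to the green fleet at once
        return [100]
    if mode == "CANARY":
        # one canary step, then the remaining traffic in a second step
        return [step_percent, 100] if step_percent < 100 else [100]
    if mode == "LINEAR":
        # equal steps until all traffic is on green; the last step may be smaller
        schedule, current = [], 0
        while current < 100:
            current = min(current + step_percent, 100)
            schedule.append(current)
        return schedule
    raise ValueError(f"unknown traffic shifting mode: {mode}")


print(green_traffic_schedule("CANARY", 10))  # [10, 100]
print(green_traffic_schedule("LINEAR", 30))  # [30, 60, 90, 100]
```

In the real service, each step is also followed by a baking period during which the CloudWatch alarms are evaluated before the next shift.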


Deployment guardrails for Amazon SageMaker Inference endpoints also let you specify CloudWatch alarms based on endpoint invocation metrics, so that model performance regressions are detected and trigger an automatic rollback.

In this notebook, we'll update the endpoint with the following deployment configurations:
 * Blue/Green update policy with **Canary traffic shifting option**
 * Configure CloudWatch alarms to monitor model performance and trigger auto-rollback action.
 
To demonstrate Canary deployments and the auto-rollback feature, we will update an Endpoint with an incompatible model version and deploy it as a Canary fleet, taking a small percentage of the traffic. Requests sent to this Canary fleet will result in errors, which will be used to trigger a rollback using pre-specified CloudWatch alarms. Finally, we will also demonstrate a success scenario where no alarms are tripped and the update succeeds. 

This notebook is organized into four steps:
* Step 1 creates, from the models built in the previous notebooks, the endpoint configurations required for the three scenarios: the baseline, the update containing the incompatible model version, and the update containing the correct model version.
* Step 2 invokes the baseline Endpoint prior to the update. 
* Step 3 specifies the CloudWatch alarms used to trigger the rollbacks. 
* Finally, Step 4 updates the endpoint to trigger a rollback and then demonstrates a successful update.

## Setup 
Ensure that you have an up-to-date version of boto3 that includes the latest SageMaker features, then import the libraries and create the SageMaker and CloudWatch clients:

In [None]:
%matplotlib inline

import time
import os
import sys
import boto3
import datetime
import sagemaker
from sagemaker import get_execution_role

# make the shared utils module in the parent directory importable
p = os.path.abspath('..')
if p not in sys.path:
    sys.path.append(p)
import utils

sm_session = sagemaker.Session()
role = get_execution_role()
region = sm_session.boto_region_name
bucket = sm_session.default_bucket()
sm_client = sm_session.sagemaker_client
sm_runtime = sm_session.sagemaker_runtime_client
cw = boto3.Session().client("cloudwatch")
prefix = "sagemaker/huggingface-pytorch-sentiment-analysis"

time_now = f'{datetime.datetime.now():%Y-%m-%d-%H-%M-%S}'

In [None]:
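# list the variables stored by the previous notebooks
%store
# restore them, e.g. roberta_model_name, distilbert_model_name,
# roberta_mme_model_name and deploy_instance_type, which are used below
%store -r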

## Step 1: Deploy the models created in the previous notebooks

### First, we create endpoint configurations based on the previously created models 


The models in this example are used to analyse the sentiment of a given sentence. The dataset we use is based on a subset of a Kaggle dataset available [here](https://www.kaggle.com/datasets/sbhatti/financial-sentiment-analysis). 

We now create three endpoint configurations, corresponding to the three models created in the previous notebooks.

In [None]:
ep_config_name_roberta = f"hf-EpConfig-roberta-{time_now}"
ep_config_name_distilbert = f"hf-EpConfig-distilbert-{time_now}"
ep_config_name_roberta_mme = f"hf-EpConfig-roberta-mme-{time_now}"

print(f"Endpoint Config 1: {ep_config_name_roberta}")
print(f"Endpoint Config 2: {ep_config_name_distilbert}")
print(f"Endpoint Config 3: {ep_config_name_roberta_mme}")

# (endpoint config name, model name) pairs; each config serves a single
# "AllTraffic" variant on 3 instances
ep_configs = [
    (ep_config_name_roberta, roberta_model_name),
    (ep_config_name_distilbert, distilbert_model_name),
    (ep_config_name_roberta_mme, roberta_mme_model_name),
]

for config_name, model_name in ep_configs:
    resp = sm_client.create_endpoint_config(
        EndpointConfigName=config_name,
        ProductionVariants=[
            {
                "VariantName": "AllTraffic",
                "ModelName": model_name,
                "InstanceType": deploy_instance_type,
                "InitialInstanceCount": 3,
            }
        ],
    )
    print(f"Created Endpoint Config: {resp}")
    time.sleep(5)

### Create Endpoint

Deploy the baseline model to a new SageMaker endpoint, using the `ep_config_name_roberta_mme` endpoint configuration:

In [None]:
endpoint_name = f"hf-deployment-guardrails-canary-{time_now}"
print(f"Endpoint Name: {endpoint_name}")

resp = sm_client.create_endpoint(EndpointName=endpoint_name, EndpointConfigName=ep_config_name_roberta_mme)
print(f"\nCreated Endpoint: {resp}")

Wait for the endpoint creation to complete.

In [None]:
%%time
utils.endpoint_creation_wait(endpoint_name)
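The `utils` module is not shown in this notebook. As a rough idea of what waiting for an endpoint involves, here is a minimal polling sketch built on the real `DescribeEndpoint` call. `wait_for_endpoint_status` is a hypothetical helper, not the `utils` implementation; the poll interval and timeout defaults are assumptions.

```python
import time


def wait_for_endpoint_status(client, endpoint_name, target="InService",
                             poll_seconds=30, timeout_seconds=1800):
    """Poll DescribeEndpoint until EndpointStatus equals `target`."""
    failed_states = {"Failed", "UpdateRollbackFailed"}
    deadline = time.monotonic() + timeout_seconds
    while True:
        status = client.describe_endpoint(EndpointName=endpoint_name)["EndpointStatus"]
        if status == target:
            return status
        if status in failed_states:
            # terminal states: waiting longer will not help
            raise RuntimeError(f"Endpoint {endpoint_name} entered {status}")
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Endpoint {endpoint_name} is still {status}")
        print(f"Endpoint status: {status}; checking again in {poll_seconds}s")
        time.sleep(poll_seconds)
```

Usage: `wait_for_endpoint_status(sm_client, endpoint_name)`. For the creation case, boto3 also ships a built-in waiter: `sm_client.get_waiter("endpoint_in_service").wait(EndpointName=endpoint_name)`.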

## Step 2: Invoke Endpoint

You can now send data to this endpoint to get inferences in real time.

This step invokes the endpoint with the included sample data, up to a maximum number of invocations, with a wait interval between requests.

In [None]:
utils.invoke_endpoint_max_invocations(endpoint_name, max_invocations=100)
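For illustration, the sketch below shows one way such an invocation loop can be written with the SageMaker runtime `InvokeEndpoint` call, printing `.` for each successful request and `E` for each error. `invoke_with_progress` is hypothetical and not the `utils` implementation; the `{"inputs": ...}` JSON body assumes the request format of the Hugging Face inference toolkit.

```python
import json
import time


def invoke_with_progress(runtime_client, endpoint_name, payloads,
                         max_invocations=100, wait_seconds=0.5):
    """Send up to `max_invocations` requests; print '.' per success, 'E' per error."""
    errors = 0
    for i in range(max_invocations):
        text = payloads[i % len(payloads)]  # cycle through the sample sentences
        try:
            runtime_client.invoke_endpoint(
                EndpointName=endpoint_name,
                ContentType="application/json",
                Body=json.dumps({"inputs": text}),
            )
            print(".", end="", flush=True)
        except Exception:  # in practice, narrow this to botocore.exceptions.ClientError
            errors += 1
            print("E", end="", flush=True)
        time.sleep(wait_seconds)
    print()
    return errors
```

Usage, with a made-up sample sentence: `invoke_with_progress(sm_runtime, endpoint_name, ["Revenue grew strongly this quarter."], max_invocations=10)`.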

### Invocations Metrics

Amazon SageMaker emits metrics such as latency and invocations per variant and endpoint configuration to Amazon CloudWatch (see the full list of metrics [here](https://docs.aws.amazon.com/sagemaker/latest/dg/monitoring-cloudwatch.html)).

Query CloudWatch to get the number of invocations and the latency metrics per variant and endpoint configuration.
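As a sketch of what such a query looks like, the hypothetical helper below builds the keyword arguments for the CloudWatch `GetMetricStatistics` call, selecting a metric in the `AWS/SageMaker` namespace by endpoint and variant. The one-hour window and one-minute period are assumptions, and the per-endpoint-configuration filtering that `utils.plot_endpoint_invocation_metrics` performs is not reproduced here.

```python
import datetime


def invocation_metric_query(endpoint_name, variant_name, metric_name, statistic,
                            minutes=60, period_seconds=60):
    """Build kwargs for CloudWatch GetMetricStatistics (hypothetical helper)."""
    end = datetime.datetime.now(datetime.timezone.utc)
    return {
        "Namespace": "AWS/SageMaker",
        "MetricName": metric_name,
        "Dimensions": [
            {"Name": "EndpointName", "Value": endpoint_name},
            {"Name": "VariantName", "Value": variant_name},
        ],
        "StartTime": end - datetime.timedelta(minutes=minutes),
        "EndTime": end,
        "Period": period_seconds,  # one datapoint per period
        "Statistics": [statistic],
    }
```

Usage: `cw.get_metric_statistics(**invocation_metric_query(endpoint_name, "AllTraffic", "Invocations", "Sum"))["Datapoints"]` returns datapoints with a `Timestamp` and a `Sum` value, in no guaranteed order, so sort them by timestamp before plotting.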

### Plot endpoint invocation metrics:

Below, we plot the `Invocations`, `Invocation4XXErrors`, `Invocation5XXErrors`, `ModelLatency`, and `OverheadLatency` metrics for the endpoint.

You should observe a flat line for `Invocation4XXErrors` and `Invocation5XXErrors`, because we are using the correct invocation data, model, and configurations. `ModelLatency` and `OverheadLatency` should also start to decrease over time.

In [None]:
invocation_metrics = utils.plot_endpoint_invocation_metrics(
 endpoint_name, ep_config_name_roberta_mme, "AllTraffic", "Invocations", "Sum"
)
invocation_4xx_metrics = utils.plot_endpoint_invocation_metrics(
 endpoint_name, None, "AllTraffic", "Invocation4XXErrors", "Sum"
)
invocation_5xx_metrics = utils.plot_endpoint_invocation_metrics(
 endpoint_name, None, "AllTraffic", "Invocation5XXErrors", "Sum"
)
model_latency_metrics = utils.plot_endpoint_invocation_metrics(
 endpoint_name, None, "AllTraffic", "ModelLatency", "Average"
)
overhead_latency_metrics = utils.plot_endpoint_invocation_metrics(
 endpoint_name, None, "AllTraffic", "OverheadLatency", "Average"
)

## Step 3: Create CloudWatch alarms to monitor Endpoint performance

Create CloudWatch alarms to monitor endpoint performance with the following metrics:
* Invocation4XXErrors
* ModelLatency

The following metric dimensions are used to select the metric per endpoint and variant:
* EndpointName
* VariantName

In [None]:
error_alarm = f"TestAlarm-4XXErrors-{endpoint_name}"
latency_alarm = f"TestAlarm-ModelLatency-{endpoint_name}"

# alarm on one or more 4xx errors within 1 minute
utils.create_auto_rollback_alarm(
    error_alarm, endpoint_name, "AllTraffic", "Invocation4XXErrors", "Sum", 1
)
# alarm on model latency >= 200 ms for 1 minute (ModelLatency is reported in microseconds)
utils.create_auto_rollback_alarm(
    latency_alarm, endpoint_name, "AllTraffic", "ModelLatency", "Average", 200000
)
time.sleep(60)
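For reference, an alarm like these can be created with the CloudWatch `PutMetricAlarm` call. The builder below is a hypothetical sketch of the request parameters, not the `utils.create_auto_rollback_alarm` implementation; the one-minute period, single evaluation period, `>=` comparison and `notBreaching` treatment of missing data are assumptions.

```python
def auto_rollback_alarm_params(alarm_name, endpoint_name, variant_name,
                               metric_name, statistic, threshold,
                               period_seconds=60, evaluation_periods=1):
    """Build kwargs for CloudWatch PutMetricAlarm (hypothetical; utils may differ)."""
    return {
        "AlarmName": alarm_name,
        "AlarmDescription": f"Auto-rollback alarm on {metric_name} for {endpoint_name}",
        "Namespace": "AWS/SageMaker",
        "MetricName": metric_name,
        "Statistic": statistic,
        "Dimensions": [
            {"Name": "EndpointName", "Value": endpoint_name},
            {"Name": "VariantName", "Value": variant_name},
        ],
        "Period": period_seconds,
        "EvaluationPeriods": evaluation_periods,
        "Threshold": threshold,
        "ComparisonOperator": "GreaterThanOrEqualToThreshold",
        # periods without traffic produce no datapoints; don't treat them as breaches
        "TreatMissingData": "notBreaching",
    }
```

Usage: `cw.put_metric_alarm(**auto_rollback_alarm_params(latency_alarm, endpoint_name, "AllTraffic", "ModelLatency", "Average", 200_000))`, where 200,000 microseconds corresponds to 200 ms.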

In [None]:
# cw.describe_alarms(AlarmNames=[error_alarm, latency_alarm])

## Step 4: Update Endpoint with deployment configurations

Update the endpoint with deployment configurations and monitor the performance from CloudWatch metrics.


### Blue/Green update policy with Canary traffic shifting

We define the following deployment configuration to perform a Blue/Green update with Canary traffic shifting from the old fleet to the new one. The Canary traffic shifting option reduces the blast radius of a regressive update to the endpoint. With the All-At-Once option, by contrast, all invocation requests start failing as soon as traffic is flipped to a faulty model. In Canary mode, requests are shifted to the new model version gradually, which prevents errors from affecting 100% of your traffic. In addition, the auto-rollback alarms monitor the metrics during the canary stage.
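A quick back-of-the-envelope estimate shows how the canary limits the blast radius. The sketch assumes that traffic is split in proportion to capacity, so a canary of k instances out of a variant's n instances receives roughly k / n of the requests; this is a simplification rather than a documented routing guarantee, and `canary_blast_radius` is a hypothetical helper.

```python
def canary_blast_radius(canary_type, canary_value, instance_count, requests):
    """Estimate (fraction, request count) hitting the canary fleet.

    Simplifying assumption: traffic splits in proportion to capacity.
    """
    if instance_count <= 0:
        raise ValueError("instance_count must be positive")
    if canary_type == "INSTANCE_COUNT":
        fraction = canary_value / instance_count
    elif canary_type == "CAPACITY_PERCENT":
        fraction = canary_value / 100
    else:
        raise ValueError(f"unknown canary size type: {canary_type}")
    return fraction, round(fraction * requests)


# this notebook's settings: a 1-instance canary on a 3-instance variant, 300 requests
print(canary_blast_radius("INSTANCE_COUNT", 1, 3, 300))
```

Under this assumption, roughly a third of the requests reach the faulty canary, compared with all of them under All-At-Once shifting, and the alarms can stop the rollout before the canary stage ends.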

### Rollback Case 
![Rollback case](images/scenario-canary-rollback.png)

Update the endpoint with a model version that is incompatible with the test data format, to simulate errors and trigger a rollback.

In [None]:
canary_deployment_config = {
    "BlueGreenUpdatePolicy": {
        "TrafficRoutingConfiguration": {
            "Type": "CANARY",
            "CanarySize": {
                "Type": "INSTANCE_COUNT",  # or "CAPACITY_PERCENT" with a Value such as 30 or 50
                "Value": 1,
            },
            "WaitIntervalInSeconds": 300,  # wait for 5 minutes before enabling traffic on the rest of the fleet
        },
        "TerminationWaitInSeconds": 120,  # wait for 2 minutes before terminating the old fleet
        "MaximumExecutionTimeoutInSeconds": 1800,  # maximum timeout for the deployment
    },
    "AutoRollbackConfiguration": {
        "Alarms": [{"AlarmName": error_alarm}, {"AlarmName": latency_alarm}],
    },
}

# update the endpoint to the new endpoint config, passing the canary DeploymentConfig
sm_client.update_endpoint(
    EndpointName=endpoint_name,
    EndpointConfigName=ep_config_name_roberta,
    DeploymentConfig=canary_deployment_config,
)
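The same `DeploymentConfig` shape also accepts the other routing types: `ALL_AT_ONCE`, and `LINEAR` with a `LinearStepSize` instead of a `CanarySize`. The builder below is a hypothetical convenience sketch that reproduces the canary configuration above by default; its parameter names and defaults are assumptions, and it only assembles the dictionary without validating service-side limits.

```python
def blue_green_deployment_config(routing_type, alarm_names, step_type="INSTANCE_COUNT",
                                 step_value=1, wait_interval_s=300,
                                 termination_wait_s=120, max_timeout_s=1800):
    """Build a DeploymentConfig dict for UpdateEndpoint (sketch)."""
    routing = {"Type": routing_type, "WaitIntervalInSeconds": wait_interval_s}
    if routing_type == "CANARY":
        routing["CanarySize"] = {"Type": step_type, "Value": step_value}
    elif routing_type == "LINEAR":
        routing["LinearStepSize"] = {"Type": step_type, "Value": step_value}
    elif routing_type != "ALL_AT_ONCE":
        raise ValueError(f"unsupported routing type: {routing_type}")
    config = {
        "BlueGreenUpdatePolicy": {
            "TrafficRoutingConfiguration": routing,
            "TerminationWaitInSeconds": termination_wait_s,
            "MaximumExecutionTimeoutInSeconds": max_timeout_s,
        }
    }
    if alarm_names:  # only add auto-rollback when there are alarms to watch
        config["AutoRollbackConfiguration"] = {
            "Alarms": [{"AlarmName": name} for name in alarm_names]
        }
    return config
```

For example, `blue_green_deployment_config("LINEAR", [error_alarm, latency_alarm], step_type="CAPACITY_PERCENT", step_value=25)` would shift traffic in 25% increments, each followed by the 5-minute wait interval.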

In [None]:
sm_client.describe_endpoint(EndpointName=endpoint_name)

### Invoke the endpoint while the update operation is in progress

**Note: This notebook invokes the endpoint from a single thread. To stop sending requests, stop the cell execution.**

The E's denote the errors generated from the incompatible model version in the canary fleet.

The purpose of the cell below is to simulate errors in the canary fleet. Because requests are routed to the canary fleet probabilistically, wait until you start seeing errors, then stop the execution of the cell. If not interrupted, the cell runs for 300 invocations.

In [None]:
utils.invoke_endpoint_max_invocations(endpoint_name, max_invocations=300)

Wait for the update operation to complete and verify the automatic rollback.

In [None]:
utils.endpoint_update_wait(endpoint_name)

sm_client.describe_endpoint(EndpointName=endpoint_name)
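To verify the outcome programmatically rather than by reading the full response, you can compare the endpoint's current `EndpointConfigName` with the configuration you tried to deploy. `summarize_update` is a hypothetical helper that relies only on the `EndpointStatus` and `EndpointConfigName` fields of the `DescribeEndpoint` response.

```python
def summarize_update(describe_response, attempted_config):
    """Report whether the endpoint ended on `attempted_config` or rolled back (sketch)."""
    status = describe_response["EndpointStatus"]
    current = describe_response["EndpointConfigName"]
    if status != "InService":
        # e.g. Updating or RollingBack: the deployment has not settled yet
        return f"not settled yet: {status}"
    if current == attempted_config:
        return f"update succeeded: serving {current}"
    return f"rolled back: still serving {current}"
```

Usage: `summarize_update(sm_client.describe_endpoint(EndpointName=endpoint_name), ep_config_name_roberta)` should report a rollback to the baseline configuration here; passing `ep_config_name_distilbert` after the success case below should report a successful update.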

Collect the endpoint metrics during the deployment:

Below, we plot the `Invocations`, `Invocation4XXErrors`, and `ModelLatency` metrics for the endpoint.

As the new endpoint configuration (`ep_config_name_roberta`, with the incompatible model version) starts being deployed, it encounters failures, and the endpoint rolls back to the baseline configuration (`ep_config_name_roberta_mme`). You can see this in the graphs below as an increase in `Invocation4XXErrors` and `ModelLatency` during the rollback phase.


In [None]:
invocation_metrics = utils.plot_endpoint_invocation_metrics(
 endpoint_name, None, "AllTraffic", "Invocations", "Sum"
)
metrics_epc_roberta = utils.plot_endpoint_invocation_metrics(
 endpoint_name, ep_config_name_roberta, "AllTraffic", "Invocations", "Sum"
)
metrics_epc_roberta_mme = utils.plot_endpoint_invocation_metrics(
 endpoint_name, ep_config_name_roberta_mme, "AllTraffic", "Invocations", "Sum"
)

metrics_all = invocation_metrics.join([metrics_epc_roberta, metrics_epc_roberta_mme], how="outer")
metrics_all.plot(title="Invocations-Sum")

invocation_4xx_metrics = utils.plot_endpoint_invocation_metrics(
 endpoint_name, None, "AllTraffic", "Invocation4XXErrors", "Sum"
)
model_latency_metrics = utils.plot_endpoint_invocation_metrics(
 endpoint_name, None, "AllTraffic", "ModelLatency", "Average"
)

You can check the alarm history with the CloudWatch `DescribeAlarmHistory` API call. Note, however, that this notebook's execution role does not have an IAM policy that allows this action. You can add the IAM policy below to the SageMaker execution role of your Studio user profile from the IAM console.
```json
{
    "Version": "2012-10-17",
    "Statement": [
        {
            "Sid": "VisualEditor0",
            "Effect": "Allow",
            "Action": "cloudwatch:DescribeAlarmHistory",
            "Resource": "*"
        }
    ]
}
```

Alternatively, you can open the CloudWatch Alarms console page to view the alarm state: [CloudWatch console](https://ap-southeast-2.console.aws.amazon.com/cloudwatch/home?region=ap-southeast-2#alarmsV2:). This link points to the ap-southeast-2 Region; switch to your own Region if it differs.

In [None]:
time.sleep(60)
cw.describe_alarm_history(AlarmName=error_alarm)
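The raw alarm history also contains configuration updates and actions. The hypothetical helper below keeps only the state changes, using the `AlarmHistoryItems`, `HistoryItemType`, `Timestamp`, and `HistorySummary` fields of the `DescribeAlarmHistory` response.

```python
def state_transitions(history_response):
    """Return (timestamp, summary) pairs for alarm state changes (sketch)."""
    return [
        (item["Timestamp"], item["HistorySummary"])
        for item in history_response.get("AlarmHistoryItems", [])
        if item.get("HistoryItemType") == "StateUpdate"
    ]
```

Usage: `state_transitions(cw.describe_alarm_history(AlarmName=error_alarm, HistoryItemType="StateUpdate"))`. Passing `HistoryItemType` lets CloudWatch filter server-side; the client-side filter still helps when you fetch the unfiltered history.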

Let's take a look at the Success case where we use the same Canary deployment configuration but a valid endpoint configuration.

### Success Case
![Success case](images/scenario-canary-success.png)

Update the endpoint to the valid `ep_config_name_distilbert` endpoint configuration, reusing the same Canary deployment configuration as in the rollback case:

In [None]:
# update the endpoint to a valid endpoint config, reusing the previous DeploymentConfig via RetainDeploymentConfig

sm_client.update_endpoint(
 EndpointName=endpoint_name,
 EndpointConfigName=ep_config_name_distilbert,
 RetainDeploymentConfig=True,
)

In [None]:
sm_client.describe_endpoint(EndpointName=endpoint_name)

Invoke the endpoint while the update operation is in progress:

In [None]:
utils.invoke_endpoint_max_invocations(endpoint_name, max_invocations=300)

Wait for the update operation to complete.

While waiting, you can check the status of the endpoint in the SageMaker console, where the endpoint configuration status changes as shown in the screenshot below.
![during_update](images/during_endpoint_update.png)

In [None]:
utils.endpoint_update_wait(endpoint_name)

sm_client.describe_endpoint(EndpointName=endpoint_name)

Once the endpoint is back in service, you can see that its endpoint configuration has changed to the DistilBERT endpoint configuration, as shown below:
![final_state](images/final_state_config.png)

Collect the endpoint metrics during the deployment:

Below, we plot the `Invocations`, `Invocation4XXErrors`, and `ModelLatency` metrics for the endpoint.

As the new endpoint configuration (`ep_config_name_distilbert`, with the correct model version) is deployed, it takes over without errors from the baseline configuration (`ep_config_name_roberta_mme`) that the endpoint rolled back to. You can see this in the graphs below as `Invocation4XXErrors` and `ModelLatency` decrease during this transition phase.

In [None]:
invocation_metrics = utils.plot_endpoint_invocation_metrics(
 endpoint_name, None, "AllTraffic", "Invocations", "Sum"
)
metrics_epc_1 = utils.plot_endpoint_invocation_metrics(
 endpoint_name, ep_config_name_roberta, "AllTraffic", "Invocations", "Sum"
)
metrics_epc_2 = utils.plot_endpoint_invocation_metrics(
 endpoint_name, ep_config_name_distilbert, "AllTraffic", "Invocations", "Sum"
)
metrics_epc_3 = utils.plot_endpoint_invocation_metrics(
 endpoint_name, ep_config_name_roberta_mme, "AllTraffic", "Invocations", "Sum"
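)

metrics_all = invocation_metrics.join([metrics_epc_1, metrics_epc_2, metrics_epc_3], how="outer")
metrics_all.plot(title="Invocations-Sum")
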
invocation_4xx_metrics = utils.plot_endpoint_invocation_metrics(
 endpoint_name, None, "AllTraffic", "Invocation4XXErrors", "Sum"
)
model_latency_metrics = utils.plot_endpoint_invocation_metrics(
 endpoint_name, None, "AllTraffic", "ModelLatency", "Average"
)

The Amazon CloudWatch metrics for the total invocations of each endpoint configuration show how invocation requests shift from the old version to the new version during deployment.
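To quantify that shift, you can compute, for each period, the fraction of invocations served by the new configuration. The sketch below assumes two lists of per-period `Invocations` sums (for example, one value per minute) for the old and new endpoint configurations, already aligned on timestamp; `new_config_share` is a hypothetical helper.

```python
def new_config_share(old_counts, new_counts):
    """Per-period fraction of invocations served by the new endpoint config (sketch)."""
    shares = []
    for old, new in zip(old_counts, new_counts):
        total = old + new
        # a period with no traffic at all gets a share of 0.0
        shares.append(new / total if total else 0.0)
    return shares


print(new_config_share([10, 7, 0], [0, 3, 10]))  # [0.0, 0.3, 1.0]
```

During a canary update, you would expect the share to sit near the canary fraction for the wait interval and then jump to 1.0 once the remaining traffic is shifted.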

You can now safely update your endpoint, monitor for model regressions during deployment, and trigger an automatic rollback when they occur.

## Cleanup

If you do not plan to use this endpoint further, you should delete the endpoint to avoid incurring additional charges and clean up other resources created in this notebook.

In [None]:
sm_client.delete_endpoint(EndpointName=endpoint_name)

In [None]:
sm_client.delete_endpoint_config(EndpointConfigName=ep_config_name_roberta)
sm_client.delete_endpoint_config(EndpointConfigName=ep_config_name_distilbert)
sm_client.delete_endpoint_config(EndpointConfigName=ep_config_name_roberta_mme)

In [None]:
cw.delete_alarms(AlarmNames=[error_alarm, latency_alarm])