## SageMaker Serverless Inference Provisioned Concurrency XGBoost

Amazon SageMaker Serverless Inference is a purpose-built inference option that makes it easy for customers to deploy and scale ML models. Serverless Inference is ideal for workloads which have idle periods between traffic spurts and can tolerate cold starts. Serverless endpoints also automatically launch compute resources and scale them in and out depending on traffic, eliminating the need to choose instance types or manage scaling policies. 

Serverless Inference can, however, be prone to cold starts: if your serverless endpoint receives no traffic for a while and then suddenly receives new requests, it can take some time to spin up the compute resources needed to process them. In this notebook we explore Provisioned Concurrency, a feature of Serverless Inference that helps mitigate this issue. With Provisioned Concurrency, the compute environment is kept initialized, so the endpoint stays ready and cold starts are reduced.

For this notebook we'll be working with the SageMaker XGBoost Algorithm to train a model and then deploy a serverless endpoint. We will be using the public S3 Abalone regression dataset for this example.

### Notebook Setting

- SageMaker Classic Notebook Instance: ml.m5.xlarge Notebook Instance & conda_python3 Kernel
- SageMaker Studio: Python 3 (Data Science)

### Setup

In [None]:
! pip install sagemaker botocore boto3 awscli --upgrade

In [None]:
! pwd

### SageMaker Setup

In [None]:
# Setup clients
import boto3

client = boto3.client(service_name="sagemaker")
runtime = boto3.client(service_name="sagemaker-runtime")

In [None]:
import sagemaker
from sagemaker.estimator import Estimator

boto_session = boto3.session.Session()
region = boto_session.region_name
print(region)

sagemaker_session = sagemaker.Session()
base_job_prefix = "xgboost-example"
role = sagemaker.get_execution_role()
print(role)

default_bucket = sagemaker_session.default_bucket()
s3_prefix = base_job_prefix

training_instance_type = "ml.m5.xlarge"

### Model Training

We will run a training job on the Abalone regression dataset with the built-in XGBoost algorithm. We will then use the trained model artifacts to deploy a serverless endpoint. If you already have pre-trained model artifacts, you can deploy them directly to Serverless Inference and skip this portion.

In [None]:
# retrieve data
!aws s3 cp s3://sagemaker-sample-files/datasets/tabular/uci_abalone/train_csv/abalone_dataset1_train.csv .

In [None]:
# upload data to S3
!aws s3 cp abalone_dataset1_train.csv s3://{default_bucket}/xgboost-regression/train.csv

In [None]:
from sagemaker.inputs import TrainingInput

training_path = f"s3://{default_bucket}/xgboost-regression/train.csv"
train_input = TrainingInput(training_path, content_type="text/csv")

In [None]:
model_path = f"s3://{default_bucket}/{s3_prefix}/xgb_model"

# retrieve xgboost image
image_uri = sagemaker.image_uris.retrieve(
 framework="xgboost",
 region=region,
 version="1.0-1",
 py_version="py3",
 instance_type=training_instance_type,
)

# Configure Training Estimator
xgb_train = Estimator(
 image_uri=image_uri,
 instance_type=training_instance_type,
 instance_count=1,
 output_path=model_path,
 sagemaker_session=sagemaker_session,
 role=role,
)

# Set Hyperparameters
xgb_train.set_hyperparameters(
 objective="reg:squarederror",  # "reg:linear" is the deprecated alias for this objective
 num_round=50,
 max_depth=5,
 eta=0.2,
 gamma=4,
 min_child_weight=6,
 subsample=0.7,
 silent=0,
)

In [None]:
# Fit model
xgb_train.fit({"train": train_input})

### Retrieve Model Artifacts

If you have a pre-trained model, package its artifacts as a model.tar.gz, since SageMaker expects model artifacts as a gzipped tarball.

In [None]:
# Retrieve model data from training job
model_artifacts = xgb_train.model_data
model_artifacts
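
If you are bringing a pre-trained model rather than using the training job output, the packaging step might look like the following minimal sketch. The helper name `package_model` is hypothetical, and the archive member name `xgboost-model` is an assumption about what the built-in XGBoost container looks for; check the container documentation for the version you use before relying on it.

```python
import tarfile
from pathlib import Path


def package_model(model_file, output_path="model.tar.gz", arcname="xgboost-model"):
    """Write a single model file into a gzipped tarball at the archive root.

    `arcname` is the name the file gets inside the archive. The default is an
    assumption, not something confirmed by this notebook.
    """
    with tarfile.open(output_path, "w:gz") as tar:
        tar.add(model_file, arcname=arcname)
    return Path(output_path)
```

The resulting tarball would then be uploaded to S3 (for example with `aws s3 cp`) and its S3 URI used in place of `model_artifacts`.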

### SageMaker Model Creation

Here we specify the container image and the model artifacts that make up the SageMaker model.

In [None]:
from time import gmtime, strftime

model_name = "xgboost-serverless-pc" + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
print("Model name: " + model_name)

# dummy environment variables
byo_container_env_vars = {"SAGEMAKER_CONTAINER_LOG_LEVEL": "20", "SOME_ENV_VAR": "myEnvVar"}

create_model_response = client.create_model(
 ModelName=model_name,
 Containers=[
 {
 "Image": image_uri,
 "Mode": "SingleModel",
 "ModelDataUrl": model_artifacts,
 "Environment": byo_container_env_vars,
 }
 ],
 ExecutionRoleArn=role,
)

print("Model Arn: " + create_model_response["ModelArn"])

### SageMaker Endpoint Configuration

Here you can specify the ProvisionedConcurrency parameter. Ensure that it is less than or equal to the MaxConcurrency you specify for the endpoint. Since we are comparing performance between a vanilla (on-demand) serverless endpoint and a serverless endpoint with Provisioned Concurrency, we will create two endpoint configs: one with Provisioned Concurrency enabled and one without.

In [None]:
xgboost_epc_name_pc = "xgboost-serverless-epc-pc" + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
xgboost_epc_name_on_demand = "xgboost-serverless-epc-on-demand" + strftime("%Y-%m-%d-%H-%M-%S", gmtime())

endpoint_config_response_pc = client.create_endpoint_config(
 EndpointConfigName=xgboost_epc_name_pc,
 ProductionVariants=[
 {
 "VariantName": "byoVariant",
 "ModelName": model_name,
 "ServerlessConfig": {
 "MemorySizeInMB": 4096,
 "MaxConcurrency": 1,
 # Providing Provisioned Concurrency in EPC
 "ProvisionedConcurrency": 1
 },
 },
 ],
)

endpoint_config_response_on_demand = client.create_endpoint_config(
 EndpointConfigName=xgboost_epc_name_on_demand,
 ProductionVariants=[
 {
 "VariantName": "byoVariant",
 "ModelName": model_name,
 "ServerlessConfig": {
 "MemorySizeInMB": 4096,
 "MaxConcurrency": 1,
 },
 },
 ],
)

print("Endpoint Configuration Arn Provisioned Concurrency: " + endpoint_config_response_pc["EndpointConfigArn"])
print("Endpoint Configuration Arn On Demand Serverless: " + endpoint_config_response_on_demand["EndpointConfigArn"])
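
The rule that ProvisionedConcurrency must not exceed MaxConcurrency can be checked locally before calling `create_endpoint_config`. The sketch below is one way to do that; `build_serverless_config` is a hypothetical helper, not part of boto3 or the SageMaker SDK, and the requirement that both values be at least 1 is an assumption on top of what this notebook states.

```python
def build_serverless_config(memory_size_in_mb, max_concurrency, provisioned_concurrency=None):
    """Return a ServerlessConfig dict, validating the concurrency settings locally."""
    if max_concurrency < 1:
        raise ValueError("max_concurrency must be at least 1")

    config = {
        "MemorySizeInMB": memory_size_in_mb,
        "MaxConcurrency": max_concurrency,
    }

    # Only include ProvisionedConcurrency when it was requested, so the same
    # helper can produce both the on-demand and the provisioned variant.
    if provisioned_concurrency is not None:
        if not 1 <= provisioned_concurrency <= max_concurrency:
            raise ValueError(
                "provisioned_concurrency must be between 1 and max_concurrency "
                f"({max_concurrency}), got {provisioned_concurrency}"
            )
        config["ProvisionedConcurrency"] = provisioned_concurrency

    return config
```

The returned dict could be passed as the `"ServerlessConfig"` value of a production variant, for example `build_serverless_config(4096, 1, provisioned_concurrency=1)` for the provisioned config and `build_serverless_config(4096, 1)` for the on-demand one.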

### Endpoint Creation

In [None]:
endpoint_name_pc = "xgboost-serverless-ep-pc" + strftime("%Y-%m-%d-%H-%M-%S", gmtime())

create_endpoint_response = client.create_endpoint(
 EndpointName=endpoint_name_pc,
 EndpointConfigName=xgboost_epc_name_pc,
)

print("Endpoint Arn Provisioned Concurrency: " + create_endpoint_response["EndpointArn"])

In [None]:
# wait for endpoint to reach a terminal state (InService) using describe endpoint
import time

describe_endpoint_response = client.describe_endpoint(EndpointName=endpoint_name_pc)

while describe_endpoint_response["EndpointStatus"] == "Creating":
 describe_endpoint_response = client.describe_endpoint(EndpointName=endpoint_name_pc)
 print(describe_endpoint_response["EndpointStatus"])
 time.sleep(15)

describe_endpoint_response

In [None]:
endpoint_name_on_demand = "xgboost-serverless-ep-on-demand" + strftime("%Y-%m-%d-%H-%M-%S", gmtime())

create_endpoint_response = client.create_endpoint(
 EndpointName=endpoint_name_on_demand,
 EndpointConfigName=xgboost_epc_name_on_demand,
)

print("Endpoint Arn On Demand Serverless: " + create_endpoint_response["EndpointArn"])

In [None]:
# wait for endpoint to reach a terminal state (InService) using describe endpoint
import time

describe_endpoint_response = client.describe_endpoint(EndpointName=endpoint_name_on_demand)

while describe_endpoint_response["EndpointStatus"] == "Creating":
 describe_endpoint_response = client.describe_endpoint(EndpointName=endpoint_name_on_demand)
 print(describe_endpoint_response["EndpointStatus"])
 time.sleep(15)

describe_endpoint_response
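
The two polling loops above are identical apart from the endpoint name, and neither stops if creation hangs or reports a failure. A minimal sketch of a shared helper follows; `wait_for_endpoint` and its default timeout are assumptions made for illustration, and the only API call it relies on is `describe_endpoint`, which the notebook already uses.

```python
import time


def wait_for_endpoint(sm_client, endpoint_name, poll_seconds=15, timeout_seconds=1800):
    """Poll describe_endpoint until the endpoint leaves 'Creating'.

    Returns the final describe_endpoint response when the endpoint is InService.
    Raises RuntimeError for any other final status and TimeoutError if the
    endpoint is still being created after `timeout_seconds`.
    """
    deadline = time.monotonic() + timeout_seconds
    while True:
        response = sm_client.describe_endpoint(EndpointName=endpoint_name)
        status = response["EndpointStatus"]
        if status != "Creating":
            if status != "InService":
                reason = response.get("FailureReason", "no reason given")
                raise RuntimeError(f"Endpoint {endpoint_name} ended in status {status}: {reason}")
            return response
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Endpoint {endpoint_name} still Creating after {timeout_seconds}s")
        time.sleep(poll_seconds)
```

With this in place, each wait cell reduces to a single call such as `wait_for_endpoint(client, endpoint_name_pc)`. boto3 also ships a built-in waiter, `client.get_waiter("endpoint_in_service")`, which may be preferable if you do not need custom error handling.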

### Sample Inference

In [None]:
%%time

# On Demand Serverless Endpoint Test
response = runtime.invoke_endpoint(
 EndpointName=endpoint_name_on_demand,
 Body=b".345,0.224414,.131102,0.042329,.279923,-0.110329,-0.099358,0.0",
 ContentType="text/csv",
)

print(response["Body"].read())

In [None]:
%%time

# Provisioned Concurrency Endpoint Test
response = runtime.invoke_endpoint(
 EndpointName=endpoint_name_pc,
 Body=b".345,0.224414,.131102,0.042329,.279923,-0.110329,-0.099358,0.0",
 ContentType="text/csv",
)

print(response["Body"].read())
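
The request bodies above are hand-written CSV byte strings, and the response is printed as raw bytes. The sketch below shows one way to build the payload from a list of numbers and to parse the prediction back into floats. Both helper names are hypothetical, and the parser assumes the container returns plain-text numbers separated by commas or newlines, which may differ between container versions.

```python
def to_csv_payload(features):
    """Encode one row of numeric features as a CSV byte string for text/csv requests."""
    return ",".join(str(float(value)) for value in features).encode("utf-8")


def parse_prediction(body_bytes):
    """Decode a plain-text response body into a list of floats.

    Assumes values are separated by commas or newlines (an assumption about
    the container's output format).
    """
    text = body_bytes.decode("utf-8").strip()
    return [float(value) for value in text.replace("\n", ",").split(",") if value]
```

A call would then look like `runtime.invoke_endpoint(EndpointName=endpoint_name_pc, Body=to_csv_payload(row), ContentType="text/csv")`, followed by `parse_prediction(response["Body"].read())`.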

### Evaluate Performance of PC vs On Demand Serverless Inference

Note that the following cell takes roughly an hour to run: it makes five rounds of requests and waits 10 minutes before each round. The idle wait lets us compare cold-start times between an On Demand endpoint and the endpoint with Provisioned Concurrency enabled.

In [None]:
import time
import numpy as np
print("Testing cold start for serverless inference with PC vs no PC")

pc_times = []
non_pc_times = []

# ~50 minutes
for i in range(5):
 time.sleep(600)
 start_pc = time.time()
 pc_response = runtime.invoke_endpoint(
 EndpointName=endpoint_name_pc,
 Body=b".345,0.224414,.131102,0.042329,.279923,-0.110329,-0.099358,0.0",
 ContentType="text/csv",
 )
 end_pc = time.time() - start_pc
 pc_times.append(end_pc)

 start_no_pc = time.time()
 response = runtime.invoke_endpoint(
 EndpointName=endpoint_name_on_demand,
 Body=b".345,0.224414,.131102,0.042329,.279923,-0.110329,-0.099358,0.0",
 ContentType="text/csv",
 )
 end_no_pc = time.time() - start_no_pc
 non_pc_times.append(end_no_pc)

pc_cold_start = np.mean(pc_times)
non_pc_cold_start = np.mean(non_pc_times)

print("Provisioned Concurrency Serverless Inference Average Cold Start: {}".format(pc_cold_start))
print("On Demand Serverless Inference Average Cold Start: {}".format(non_pc_cold_start))
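
With only five samples per endpoint, the mean alone can hide a single slow cold start. As a rough sketch, the helpers below time an arbitrary callable and report the median and maximum alongside the mean. The names `time_call` and `summarize` are hypothetical, and the timings are client-side wall-clock seconds, so they include network latency as well as any cold start.

```python
import statistics
import time


def time_call(fn):
    """Run fn() once and return the elapsed wall-clock time in seconds."""
    start = time.perf_counter()
    fn()
    return time.perf_counter() - start


def summarize(latencies):
    """Return count, mean, median and max for a non-empty list of latencies."""
    return {
        "count": len(latencies),
        "mean": statistics.mean(latencies),
        "median": statistics.median(latencies),
        "max": max(latencies),
    }
```

For example, `summarize(pc_times)` and `summarize(non_pc_times)` could be printed side by side after the loop, and `time_call` could wrap each `invoke_endpoint` call in place of the manual `time.time()` arithmetic.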

In [None]:
import matplotlib.pyplot as plt

data = {'PC Cold-Start':pc_cold_start, 'On Demand Cold-Start':non_pc_cold_start}
cold_starts = list(data.keys())
values = list(data.values())
 
fig = plt.figure(figsize = (10, 5))
 
# creating the bar plot
plt.bar(cold_starts, values, color ='maroon',
 width = 0.4)
 
plt.xlabel("Serverless Inference Options")
plt.ylabel("Average Cold-Start Time (seconds)")
plt.title("Provisioned Concurrency vs On Demand Serverless Inference Cold-Start Times")
plt.show()

### Cleanup

In [None]:
# Delete the endpoints first, then the configs and the model they reference
client.delete_endpoint(EndpointName=endpoint_name_pc)
client.delete_endpoint(EndpointName=endpoint_name_on_demand)
client.delete_endpoint_config(EndpointConfigName=xgboost_epc_name_pc)
client.delete_endpoint_config(EndpointConfigName=xgboost_epc_name_on_demand)
client.delete_model(ModelName=model_name)