# Deploying Serverless Endpoints From SageMaker Model Registry

## SageMaker XGBoost Algorithm Regression Example

Amazon SageMaker Serverless Inference is a purpose-built inference option that makes it easy to deploy and scale ML models. It is ideal for workloads that have idle periods between traffic spurts and can tolerate cold starts. Serverless endpoints automatically launch compute resources and scale them in and out depending on traffic, eliminating the need to choose instance types or manage scaling policies.

[SageMaker Model Registry](https://docs.aws.amazon.com/sagemaker/latest/dg/model-registry.html) can be used to catalog and manage different model versions. Model Registry now supports deploying registered models to serverless endpoints. In this notebook we take the existing [XGBoost Serverless example](https://github.com/aws/amazon-sagemaker-examples/blob/main/serverless-inference/Serverless-Inference-Walkthrough.ipynb) and integrate it with the Model Registry. From there we take our trained model and deploy it to a serverless endpoint using the Boto3 Python SDK. Note that the SageMaker SDK does not support deploying Model Registry models to serverless endpoints at the moment.

`Please reach out to Yuyao Zhang ozhang@amazon.com or Melanie Li mmelli@amazon.com with any issues or questions.`

<b>Notebook Setting</b>
- <b>SageMaker Studio</b>: Python 3 (Data Science)
- <b>Regions Available</b>: SageMaker Serverless Inference is currently available in preview in the following regions: US East (Northern Virginia), US East (Ohio), US West (Oregon), EU (Ireland), Asia Pacific (Tokyo), and Asia Pacific (Sydney). After general availability it should be available in all commercial regions. To verify availability, check the [documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/serverless-endpoints.html), which lists all supported regions.



## Table of Contents
- Setup
- Model Training
- Model Registry
- Deployment
    - Model Creation
    - Endpoint Configuration Creation
    - Serverless Endpoint Creation
    - Endpoint Invocation
- Cleanup


## Setup

For testing you need to properly configure your Notebook Role to have <b>SageMaker Full Access</b>.

In [None]:
! pip install sagemaker botocore boto3 awscli --upgrade

In [None]:
# Setup clients
import boto3

client = boto3.client(service_name="sagemaker")
runtime = boto3.client(service_name="sagemaker-runtime")

In [None]:
import sagemaker
from sagemaker.estimator import Estimator

boto_session = boto3.session.Session()
region = boto_session.region_name
print(region)

sagemaker_session = sagemaker.Session()
base_job_prefix = "xgboost-example"
role = sagemaker.get_execution_role()
print(role)

default_bucket = sagemaker_session.default_bucket()
s3_prefix = base_job_prefix

training_instance_type = "ml.m5.xlarge"

In [None]:
# retrieve data
! curl https://sagemaker-sample-files.s3.amazonaws.com/datasets/tabular/uci_abalone/train_csv/abalone_dataset1_train.csv > abalone_dataset1_train.csv

In [None]:
# upload data to S3
!aws s3 cp abalone_dataset1_train.csv s3://{default_bucket}/xgboost-regression/train.csv

## Model Training

Now, we train an ML model using the [SageMaker XGBoost Algorithm](https://docs.aws.amazon.com/sagemaker/latest/dg/xgboost.html). In this example, we use a SageMaker-provided XGBoost container image and configure an estimator to train our model.

In [None]:
from sagemaker.inputs import TrainingInput

training_path = f"s3://{default_bucket}/xgboost-regression/train.csv"
train_input = TrainingInput(training_path, content_type="text/csv")

In [None]:
model_path = f"s3://{default_bucket}/{s3_prefix}/xgb_model"

# retrieve xgboost image
image_uri = sagemaker.image_uris.retrieve(
    framework="xgboost",
    region=region,
    version="1.0-1",
    py_version="py3",
    instance_type=training_instance_type,
)

# Configure Training Estimator
xgb_train = Estimator(
    image_uri=image_uri,
    instance_type=training_instance_type,
    instance_count=1,
    output_path=model_path,
    sagemaker_session=sagemaker_session,
    role=role,
    # Replace with a KMS key ARN from your own account
    output_kms_key="arn:aws:kms:us-east-1:631450739534:key/db5f4c09-996e-4db9-bdde-9745f44d7b06",
)

# Set Hyperparameters
xgb_train.set_hyperparameters(
    objective="reg:linear",
    num_round=50,
    max_depth=5,
    eta=0.2,
    gamma=4,
    min_child_weight=6,
    subsample=0.7,
    silent=0,
)

In [None]:
# Fit model
xgb_train.fit({"train": train_input})

In [None]:
# Retrieve model data from training job
model_artifacts = xgb_train.model_data
model_artifacts

In [None]:
# model_artifacts = "s3://sagemaker-us-east-1-631450739534/xgboost-example/xgb_model/sagemaker-xgboost-2022-05-15-09-46-44-963/output1/model.tar.gz"

## Model Registry

In [None]:
# Create a Model Package Group: https://docs.aws.amazon.com/sagemaker/latest/dg/model-registry-model-group.html
import time
from time import gmtime, strftime

model_package_group_name = "xgboost-abalone" + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
model_package_group_input_dict = {
    "ModelPackageGroupName": model_package_group_name,
    "ModelPackageGroupDescription": "Model package group for xgboost regression model with Abalone dataset",
}

create_model_pacakge_group_response = client.create_model_package_group(
    **model_package_group_input_dict
)
print(
    "ModelPackageGroup Arn : {}".format(create_model_pacakge_group_response["ModelPackageGroupArn"])
)

In [None]:
model_package_group_arn = create_model_pacakge_group_response["ModelPackageGroupArn"]
modelpackage_inference_specification = {
    "InferenceSpecification": {
        "Containers": [
            {
                "Image": image_uri,
            }
        ],
        "SupportedContentTypes": ["text/csv"],
        "SupportedResponseMIMETypes": ["text/csv"],
    }
}

# Specify the model source
model_url = model_artifacts

# Specify the model data
modelpackage_inference_specification["InferenceSpecification"]["Containers"][0][
    "ModelDataUrl"
] = model_url

create_model_package_input_dict = {
    "ModelPackageGroupName": model_package_group_arn,
    "ModelPackageDescription": "Model for regression with the Abalone dataset",
    "ModelApprovalStatus": "PendingManualApproval",
}
create_model_package_input_dict.update(modelpackage_inference_specification)

# Create cross-account model package
create_model_package_response = client.create_model_package(**create_model_package_input_dict)
model_package_arn = create_model_package_response["ModelPackageArn"]
print("ModelPackage Version ARN : {}".format(model_package_arn))

In [None]:
client.list_model_packages(ModelPackageGroupName=model_package_group_name)

In [None]:
model_package_arn = client.list_model_packages(ModelPackageGroupName=model_package_group_name)[
    "ModelPackageSummaryList"
][0]["ModelPackageArn"]
model_package_arn

In [None]:
client.describe_model_package(ModelPackageName=model_package_arn)

In [None]:
# Approve the model package
model_package_update_input_dict = {
    "ModelPackageArn": model_package_arn,
    "ModelApprovalStatus": "Approved",
}
model_package_update_response = client.update_model_package(**model_package_update_input_dict)
print(model_package_update_response)
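
The cells above pick the first entry returned by `list_model_packages`. When a group holds several versions, it can help to select the newest approved one explicitly. The following is a minimal sketch, assuming a response with the same shape as the `list_model_packages` output; `latest_approved_arn` is a hypothetical helper and the sample response uses made-up ARNs for illustration only.

```python
from datetime import datetime, timezone


def latest_approved_arn(list_response):
    """Return the ARN of the newest Approved package, or None if there is none.

    `list_response` is assumed to have the shape returned by
    client.list_model_packages(ModelPackageGroupName=...).
    """
    approved = [
        summary
        for summary in list_response.get("ModelPackageSummaryList", [])
        if summary.get("ModelApprovalStatus") == "Approved"
    ]
    if not approved:
        return None
    # Pick the approved entry with the most recent creation time
    newest = max(approved, key=lambda summary: summary["CreationTime"])
    return newest["ModelPackageArn"]


# Made-up response for illustration only (these are not real ARNs)
sample_response = {
    "ModelPackageSummaryList": [
        {
            "ModelPackageArn": "arn:example:model-package/demo/1",
            "ModelApprovalStatus": "Approved",
            "CreationTime": datetime(2022, 5, 1, tzinfo=timezone.utc),
        },
        {
            "ModelPackageArn": "arn:example:model-package/demo/2",
            "ModelApprovalStatus": "Approved",
            "CreationTime": datetime(2022, 5, 10, tzinfo=timezone.utc),
        },
        {
            "ModelPackageArn": "arn:example:model-package/demo/3",
            "ModelApprovalStatus": "PendingManualApproval",
            "CreationTime": datetime(2022, 5, 15, tzinfo=timezone.utc),
        },
    ]
}
print(latest_approved_arn(sample_response))
```

In the notebook, the same function could be applied to `client.list_model_packages(ModelPackageGroupName=model_package_group_name)`.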

In [None]:
import json

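# The cross-account ID to grant access to (replace with your target account)
cross_account_id = "682604156941"

# The account that owns the model package group (replace with your own account ID)
account = "631450739534"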

# Create a policy for accessing the S3 bucket
bucket_policy = {
    'Version': '2012-10-17',
    'Statement': [{
        'Sid': 'AddPerm',
        'Effect': 'Allow',
        'Principal': {
            'AWS': f'arn:aws:iam::{cross_account_id}:root'
        },
        'Action': 's3:*',
        'Resource': f'arn:aws:s3:::{default_bucket}/*'
    }]
}

# Convert the policy from JSON dict to string
bucket_policy = json.dumps(bucket_policy)

# Set the new policy
s3 = boto3.client('s3')
response = s3.put_bucket_policy(
    Bucket = default_bucket,
    Policy = bucket_policy)

# Create the KMS grant for encryption in the source account to the
# model registry account model package group
kms_client = boto3.client('kms')

response = kms_client.create_grant(
    GranteePrincipal=f'arn:aws:iam::{cross_account_id}:root',
    KeyId="db5f4c09-996e-4db9-bdde-9745f44d7b06",
    Operations=[
        'Decrypt',
        'GenerateDataKey',
    ],
)

# Create a policy for access to the model package group.
model_package_group_policy = {
    'Version': '2012-10-17',
    'Statement': [{
        'Sid': 'AddPermModelPackageGroup',
        'Effect': 'Allow',
        'Principal': {
            'AWS': f'arn:aws:iam::{cross_account_id}:root'
        },
        'Action': ['sagemaker:DescribeModelPackageGroup'],
        'Resource': f'arn:aws:sagemaker:{region}:{account}:model-package-group/{model_package_group_name}'
    },{
        'Sid': 'AddPermModelPackageVersion',
        'Effect': 'Allow',
        'Principal': {
            'AWS': f'arn:aws:iam::{cross_account_id}:root'
        },
        'Action': ["sagemaker:DescribeModelPackage",
                   "sagemaker:ListModelPackages",
                   "sagemaker:UpdateModelPackage",
                   "sagemaker:CreateModel"],
        'Resource': f'arn:aws:sagemaker:{region}:{account}:model-package/{model_package_group_name}/*'
    }]
}

# Convert the policy from JSON dict to string
model_package_group_policy = json.dumps(model_package_group_policy)

# Set the policy to the model package group
response = client.put_model_package_group_policy(
    ModelPackageGroupName = model_package_group_name,
    ResourcePolicy = model_package_group_policy)

print('ModelPackageGroupArn : {}'.format(create_model_pacakge_group_response['ModelPackageGroupArn']))
print("First Versioned ModelPackageArn: " + model_package_arn)

print("Success! You are all set to proceed for cross-account deployment.")
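
The bucket policy above grants `s3:*` on every object in the bucket to the other account. A narrower policy may be enough in some setups. The sketch below builds a read-only variant; `build_read_only_bucket_policy` is a hypothetical helper, the bucket name and account ID are placeholders, and whether read-only access is sufficient for your cross-account deployment is an assumption you should verify.

```python
import json


def build_read_only_bucket_policy(bucket, principal_account_id):
    """Build a bucket policy that lets another account read, but not write, objects."""
    principal = f"arn:aws:iam::{principal_account_id}:root"
    return {
        "Version": "2012-10-17",
        "Statement": [
            {
                # Object-level read access
                "Sid": "CrossAccountReadObjects",
                "Effect": "Allow",
                "Principal": {"AWS": principal},
                "Action": ["s3:GetObject"],
                "Resource": f"arn:aws:s3:::{bucket}/*",
            },
            {
                # Bucket-level listing; note the resource has no trailing /*
                "Sid": "CrossAccountListBucket",
                "Effect": "Allow",
                "Principal": {"AWS": principal},
                "Action": ["s3:ListBucket"],
                "Resource": f"arn:aws:s3:::{bucket}",
            },
        ],
    }


# Placeholder bucket name and account ID, for illustration only
policy_json = json.dumps(
    build_read_only_bucket_policy("example-bucket", "111122223333"), indent=2
)
print(policy_json)
```

The resulting string could be passed as `Policy` to `s3.put_bucket_policy` in place of the broader policy.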

## Deployment

### Model Creation

In [None]:
model_name = "xgboost-serverless-model" + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
print("Model name : {}".format(model_name))
container_list = [{"ModelPackageName": model_package_arn}]

create_model_response = client.create_model(
    ModelName=model_name, ExecutionRoleArn=role, Containers=container_list
)
print("Model arn : {}".format(create_model_response["ModelArn"]))

### Endpoint Configuration Creation
This is where you can adjust the [Serverless Configuration](https://docs.aws.amazon.com/sagemaker/latest/dg/serverless-endpoints-create.html) for your endpoint. `MaxConcurrency`, the maximum number of concurrent invocations for a single endpoint, can currently be any value from 1 to 50. `MemorySizeInMB` can be any of the following: 1024 MB, 2048 MB, 3072 MB, 4096 MB, 5120 MB, or 6144 MB.

In [None]:
endpoint_config_name = "xgboost-serverless-epc" + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
print(endpoint_config_name)
create_endpoint_config_response = client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        {
            "ServerlessConfig": {"MemorySizeInMB": 1024, "MaxConcurrency": 10},
            "ModelName": model_name,
            "VariantName": "AllTraffic",
        }
    ],
)
print("Endpoint Configuration Arn: " + create_endpoint_config_response["EndpointConfigArn"])
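
Because the serverless settings only accept certain values, it can be useful to check them before calling `create_endpoint_config`. The following is a minimal sketch, assuming the limits quoted in this section (they may change, so check the linked documentation); `build_serverless_config` is a hypothetical helper, not part of Boto3.

```python
# Limits as stated in this notebook; verify against the current documentation
ALLOWED_MEMORY_SIZES_MB = (1024, 2048, 3072, 4096, 5120, 6144)
MAX_CONCURRENCY_LIMIT = 50


def build_serverless_config(memory_size_in_mb, max_concurrency):
    """Validate the settings and return a dict usable as "ServerlessConfig"."""
    if memory_size_in_mb not in ALLOWED_MEMORY_SIZES_MB:
        raise ValueError(
            f"MemorySizeInMB must be one of {ALLOWED_MEMORY_SIZES_MB}, got {memory_size_in_mb}"
        )
    if not 1 <= max_concurrency <= MAX_CONCURRENCY_LIMIT:
        raise ValueError(
            f"MaxConcurrency must be between 1 and {MAX_CONCURRENCY_LIMIT}, got {max_concurrency}"
        )
    return {"MemorySizeInMB": memory_size_in_mb, "MaxConcurrency": max_concurrency}


serverless_config = build_serverless_config(1024, 10)
print(serverless_config)
```

The returned dictionary has the same keys as the `ServerlessConfig` entry used in the production variant above.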

### Serverless Endpoint Creation
Now that we have an endpoint configuration, we can create a serverless endpoint and deploy our model to it. When creating the endpoint, provide the name of your endpoint configuration and a name for the new endpoint.

In [None]:
endpoint_name = "xgboost-serverless-ep" + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
print("EndpointName={}".format(endpoint_name))

create_endpoint_response = client.create_endpoint(
    EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name
)
print(create_endpoint_response["EndpointArn"])

Wait until the endpoint status is InService before invoking the endpoint.

In [None]:
import time

describe_endpoint_response = client.describe_endpoint(EndpointName=endpoint_name)

while describe_endpoint_response["EndpointStatus"] == "Creating":
    describe_endpoint_response = client.describe_endpoint(EndpointName=endpoint_name)
    print(describe_endpoint_response["EndpointStatus"])
    time.sleep(15)

describe_endpoint_response
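
### Endpoint Invocation
Once the endpoint is `InService`, it can be invoked through the `sagemaker-runtime` client created in the Setup section. The following is a minimal sketch, assuming the model accepts one CSV row per request (the content type registered in the model package above). `wait_for_endpoint` and `invoke_csv` are hypothetical helpers; the first adds a timeout to the polling loop so it cannot run forever.

```python
import time


def wait_for_endpoint(sm_client, endpoint_name, poll_seconds=15, timeout_seconds=900):
    """Poll describe_endpoint until the endpoint leaves "Creating" or time runs out."""
    deadline = time.monotonic() + timeout_seconds
    while True:
        status = sm_client.describe_endpoint(EndpointName=endpoint_name)["EndpointStatus"]
        if status != "Creating":
            return status
        if time.monotonic() >= deadline:
            raise TimeoutError(f"{endpoint_name} still Creating after {timeout_seconds}s")
        time.sleep(poll_seconds)


def invoke_csv(runtime_client, endpoint_name, features):
    """Send one CSV row to the endpoint and return the decoded response body."""
    payload = ",".join(str(value) for value in features)
    response = runtime_client.invoke_endpoint(
        EndpointName=endpoint_name,
        Body=payload.encode("utf-8"),
        ContentType="text/csv",
    )
    return response["Body"].read().decode("utf-8")


# Example usage with the notebook's clients (feature_row is a placeholder you supply):
# if wait_for_endpoint(client, endpoint_name) == "InService":
#     print(invoke_csv(runtime, endpoint_name, feature_row))
```

The first request after an idle period may be slower because of a cold start.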

In [None]:
from sagemaker.lineage.artifact import Artifact, ModelArtifact, DatasetArtifact
artifact1 = Artifact.create(
    source_uri="arn:aws:sagemaker:us-east-1:631450739534:model/xgboost-serverless-model2022-05-15-11-51-34",
    source_types=[{
        "SourceIdType":"Custom",
        "Value":"xgboost-serverless-model2022-05-15-11-51-34",
    }],
    artifact_type="Model",
)

In [None]:
!pip install pyvis -q

In [None]:
from pyvis.network import Network
import os

In [None]:
class Visualizer:
    def __init__(self):
        self.directory = "generated"
        if not os.path.exists(self.directory):
            os.makedirs(self.directory)

    def render(self, query_lineage_response, scenario_name):
        net = self.get_network()

        for vertex in query_lineage_response["Vertices"]:
            arn = vertex["Arn"]
            # "Type" is optional on a vertex; fall back to None when it is missing
            label = vertex.get("Type")
            lineage_type = vertex["LineageType"]
            name = self.get_name(arn, label)
            title = self.get_title(arn, label, lineage_type)
            net.add_node(vertex["Arn"], label=name, title=title, shape="box", physics=False)

        for edge in query_lineage_response["Edges"]:
            source = edge["SourceArn"]
            dest = edge["DestinationArn"]
            net.add_edge(dest, source)

        return net.show(f"{self.directory}/{scenario_name}.html")

    def get_title(self, arn, label, lineage_type):
        return f"Arn: {arn}\nType: {label}\nLineage Type: {lineage_type}"

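    def get_name(self, arn, label):
        # Use the resource name from the ARN; append the type only when one is present
        resource_name = arn.split("/")[1]
        return f"{resource_name} {label}" if label else resource_name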

    def get_network(self):
        net = Network(height="800px", width="1000px", directed=True, notebook=True)
        net.set_options(
            """
            var options = {
              "nodes": {
                "borderWidth": 1,
                "shadow": {
                  "enabled": true
                },
                "shapeProperties": {
                  "borderRadius": 0
                },
                "size": 40,
                "shape": "circle"
              },
              "edges": {
                "arrows": {
                  "to": {
                    "enabled": true
                  }
                },
                "color": {
                  "inherit": true
                },
                "smooth": false
              },
              "layout": {
                "hierarchical": {
                  "enabled": false,
                  "direction": "LR",
                  "sortMethod": "directed"
                }
              }
            }
        """
        )
        return net
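
The visualizer draws one node per vertex and one arrow per edge. To inspect the same graph without rendering it, the edges can be grouped into an adjacency mapping. The following is a minimal sketch, assuming a response with the `Edges` list that `query_lineage` returns when `IncludeEdges=True`; `lineage_adjacency` is a hypothetical helper and the sample response uses made-up ARNs.

```python
from collections import defaultdict


def lineage_adjacency(query_lineage_response):
    """Map each source ARN to a list of (destination ARN, association type) pairs."""
    adjacency = defaultdict(list)
    for edge in query_lineage_response.get("Edges", []):
        adjacency[edge["SourceArn"]].append(
            (edge["DestinationArn"], edge.get("AssociationType"))
        )
    return dict(adjacency)


# Made-up response for illustration only (these are not real ARNs)
sample_lineage = {
    "Vertices": [
        {"Arn": "arn:example:artifact/model-a", "Type": "Model", "LineageType": "Artifact"},
        {"Arn": "arn:example:action/deploy-a", "Type": "ModelDeployment", "LineageType": "Action"},
    ],
    "Edges": [
        {
            "SourceArn": "arn:example:artifact/model-a",
            "DestinationArn": "arn:example:action/deploy-a",
            "AssociationType": "ContributedTo",
        }
    ],
}
for source, targets in lineage_adjacency(sample_lineage).items():
    print(source, "->", targets)
```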

In [None]:
from sagemaker.lineage import context, artifact, association, action
from sagemaker.lineage.artifact import Artifact, ModelArtifact, DatasetArtifact
from sagemaker.lineage.query import (
    LineageQuery,
    LineageFilter,
    LineageSourceEnum,
    LineageEntityEnum,
    LineageQueryDirectionEnum,
)
import json

In [None]:
model_artifact_summary = list(Artifact.list(source_uri=model_package_arn))[0]
model_artifact = ModelArtifact.load(artifact_arn=model_artifact_summary.artifact_arn)
query_filter = LineageFilter(
    entities=[LineageEntityEnum.CONTEXT],
    sources=[LineageSourceEnum.ENDPOINT, LineageSourceEnum.MODEL],
)



In [None]:
query_result = LineageQuery(sagemaker_session).query(
    start_arns=[model_artifact.artifact_arn],  # Model is the starting artifact
    query_filter=query_filter,
    # Find all the entities that descend from the model, i.e. the endpoint
    direction=LineageQueryDirectionEnum.DESCENDANTS,
    include_edges=True,
)
associations = []
for vertex in query_result.vertices:
    associations.append(vertex.__dict__)
print(associations)

In [None]:
action_resource = action.Action.create(
    action_name=model_name,
    source_uri=model_artifacts,
    source_type="Model",
    description="createModel",
    properties={"model":model_name,"accountId":"631450739534"},
    action_type="createModel"
)

In [None]:
associate1 = association.Association.create(
    source_arn=model_artifact_summary.artifact_arn,
    destination_arn=action_resource.action_arn,
    association_type="AssociatedWith"
)

In [None]:
query_response = client.query_lineage(
    StartArns=[model_artifact.artifact_arn], Direction="Both", IncludeEdges=True
)

viz = Visualizer()
viz.render(query_response, "ModelPackageVersion")

In [None]:
print(json.dumps(query_response, indent=2))

In [None]:
# associate1.delete()

In [None]:
# action_resource.delete(disassociate=True)

In [None]:
!aws sagemaker delete-association --source-arn "arn:aws:sagemaker:us-east-1:631450739534:artifact/bcbc2dfc504f748168cc1b2204e255c9" --destination-arn 'arn:aws:sagemaker:us-east-1:631450739534:action/xgboost-serverless-model2022-05-15-10-43-09'

In [None]:
!aws sagemaker delete-action --action-name "xgboost-serverless-model2022-05-15-10-43-09"

## Cross-Account Model Creation

In [None]:
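# Hardcoded example values; replace them with your own model name and target account ID.
# Note: this reassigns `model_name`, so the Cleanup cell at the end will act on this value.
model_name = "xgboost-serverless-model2022-05-15-10-43-09"
account_id = "682604156941"
# Source URI built by concatenating the model name and the account ID
source_uri = model_name + account_id
source_uri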

In [None]:

cross_act_action_resource = action.Action.create(
    action_name=model_name,
    source_uri=source_uri,
    source_type="Model",
    description="createModel",
    properties={
        "model":model_name,
        "accountId":account_id,
        "endpoint":"xgboost-serverless-ep2022-05-15-10-43-11",
    },
    action_type="ModelDeployment"
)

In [None]:
from sagemaker.lineage.artifact import Artifact, ModelArtifact, DatasetArtifact
from sagemaker.lineage.query import (
    LineageQuery,
    LineageFilter,
    LineageSourceEnum,
    LineageEntityEnum,
    LineageQueryDirectionEnum,
)

model_artifact_summary = list(Artifact.list(source_uri=model_package_arn))[0]
model_artifact_summary.artifact_arn

In [None]:
association.Association.create(
    source_arn=model_artifact_summary.artifact_arn,
    destination_arn=cross_act_action_resource.action_arn,
    association_type="ContributedTo"
)

## Cleanup
Delete any resources you created in this notebook that you no longer wish to use.

In [None]:
# Delete the endpoint first, then its configuration, then the model
client.delete_endpoint(EndpointName=endpoint_name)
client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
client.delete_model(ModelName=model_name)
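
If one of the deletions fails, for example because a resource was already removed, the remaining calls in the cell above never run. The following is a minimal sketch of a more forgiving cleanup; `delete_endpoint_resources` is a hypothetical helper that takes the SageMaker client and the three names used in this notebook, and it catches every exception on purpose so that later deletions still run.

```python
def delete_endpoint_resources(sm_client, endpoint_name, endpoint_config_name, model_name):
    """Delete the endpoint, its configuration, and the model; return any failures."""
    steps = [
        ("endpoint", sm_client.delete_endpoint, {"EndpointName": endpoint_name}),
        (
            "endpoint config",
            sm_client.delete_endpoint_config,
            {"EndpointConfigName": endpoint_config_name},
        ),
        ("model", sm_client.delete_model, {"ModelName": model_name}),
    ]
    failures = {}
    for label, delete_call, kwargs in steps:
        try:
            delete_call(**kwargs)
        except Exception as error:  # broad on purpose so the remaining deletions still run
            failures[label] = str(error)
    return failures


# Example usage with the notebook's client and names:
# failures = delete_endpoint_resources(client, endpoint_name, endpoint_config_name, model_name)
# print(failures or "All resources deleted")
```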