# Amazon SageMaker Multi-Model Endpoints using TorchServe


---

This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook. 

![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-west-2/inference|torchserve|mme-gpu|torchserve_multi_model_endpoint.ipynb)

---

## Introduction

With Amazon SageMaker multi-model endpoints, customers can create an endpoint that seamlessly hosts up to thousands of models. These endpoints are well suited to use cases where any one of many models, which can be served from a common inference container, needs to be called on-demand and where it is acceptable for infrequently invoked models to incur some additional latency. For applications which require consistently low inference latency, a traditional endpoint is still the best choice.

At a high level, Amazon SageMaker manages the loading and unloading of models for a multi-model endpoint, as they are needed. When an invocation request is made for a particular model, Amazon SageMaker routes the request to an instance assigned to that model, downloads the model artifacts from S3 onto that instance, and initiates loading of the model into the memory of the container. As soon as the loading is complete, Amazon SageMaker performs the requested invocation and returns the result. If the model is already loaded in memory on the selected instance, the downloading and loading steps are skipped, and the invocation is performed immediately.
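
The load-on-demand flow described above can be illustrated with a toy simulation. This is only a sketch: `LazyModelCache` and `fake_loader` are hypothetical names, the "models" are plain Python callables, and nothing here reflects how SageMaker implements model loading or unloading internally.

```python
from typing import Callable, Dict, List


class LazyModelCache:
    """Toy stand-in for one endpoint instance's load-on-demand behavior."""

    def __init__(self, loader: Callable[[str], Callable[[str], str]]):
        self._loader = loader  # simulates "download from S3 + load into memory"
        self._loaded: Dict[str, Callable[[str], str]] = {}
        self.load_log: List[str] = []  # records every cold load, in order

    def invoke(self, target_model: str, payload: str) -> str:
        if target_model not in self._loaded:
            # Cold path: fetch and load the model before its first invocation.
            self._loaded[target_model] = self._loader(target_model)
            self.load_log.append(target_model)
        # Warm path: the model is already in memory, so invoke immediately.
        return self._loaded[target_model](payload)


def fake_loader(name: str) -> Callable[[str], str]:
    # Hypothetical loader: returns a "model" that tags the payload with its name.
    return lambda payload: f"{name}:{payload}"


cache = LazyModelCache(fake_loader)
results = [
    cache.invoke("sam.tar.gz", "a"),  # cold: triggers a load
    cache.invoke("sam.tar.gz", "b"),  # warm: no load
    cache.invoke("sd.tar.gz", "c"),   # cold: a different model
]
print(results)
print(cache.load_log)
```

Only the first call for each model appears in `load_log`, which mirrors why infrequently invoked models incur extra latency on their first request.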

This notebook runs on the `conda_python3` kernel of a SageMaker notebook instance and demonstrates how to use TorchServe with SageMaker multi-model endpoints (MME). The example hosts three distinct models, each with its own dependencies, handler implementation, and model configuration.

In [None]:
!python --version

In [None]:
!pip install numpy
!pip install pillow
!pip install -U sagemaker

In [None]:
# Python Built-Ins:
from datetime import datetime
import os
import json
import logging
import time

# External Dependencies:
import boto3
from botocore.exceptions import ClientError
import sagemaker
from sagemaker.multidatamodel import MultiDataModel
from sagemaker.model import Model

sess = boto3.Session()
sm = sess.client("sagemaker")
region = sess.region_name
account = boto3.client("sts").get_caller_identity().get("Account")

smsess = sagemaker.Session(boto_session=sess)
role = sagemaker.get_execution_role()

# Configuration:
bucket_name = smsess.default_bucket()
prefix = "torchserve"
output_path = f"s3://{bucket_name}/{prefix}/mme"
print(f"account={account}, region={region}, role={role}")

## Build a BYOD TorchServe Docker container and push it to Amazon ECR
If you hit a "no space left" error while building the Docker image, you can follow [these instructions](https://medium.com/@samx81/how-to-move-docker-root-dir-to-ebs-on-aws-sagemaker-7d2560f7347d) to move the Docker root directory to EBS.

In [None]:
# Use SageMaker PyTorch DLC as base image
baseimage = sagemaker.image_uris.retrieve(
    framework="pytorch",
    region=region,
    py_version="py310",
    image_scope="inference",
    version="2.0.0",
    instance_type="ml.g5.2xlarge",
)
print(baseimage)

In [None]:
# Install our own dependencies
!cat workspace/docker/Dockerfile

In [None]:
%%capture build_output

reponame = "torchserve-mme-demo"
versiontag = "genai-0.1"

# Build our own docker image
!cd workspace/docker && ./build_and_push.sh {reponame} {versiontag} {baseimage} {region} {account}

In [None]:
if "Error response from daemon" in str(build_output):
    print(build_output)
    raise SystemExit("\n\n!!There was an error with the container build!!")
else:
    container = str(build_output).strip().split("\n")[-1]

print(container)
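
Because `container` is taken from the last line of the build output, a quick sanity check on its shape can catch a failed parse early. The sketch below assumes the build script prints the pushed image URI as its final line and that the image lives in the standard `aws` partition (China and GovCloud regions use different domain suffixes). The helper name and the account ID are placeholders.

```python
def expected_ecr_image_uri(account: str, region: str, repo: str, tag: str) -> str:
    """Build the standard-partition Amazon ECR image URI for a repository and tag."""
    return f"{account}.dkr.ecr.{region}.amazonaws.com/{repo}:{tag}"


# Placeholder account ID; in the notebook these come from `account`, `region`,
# `reponame`, and `versiontag`.
example_uri = expected_ecr_image_uri(
    account="123456789012",
    region="us-west-2",
    repo="torchserve-mme-demo",
    tag="genai-0.1",
)
print(example_uri)
```

In the notebook, comparing `container` against `expected_ecr_image_uri(account, region, reponame, versiontag)` is one way to confirm the parsed value before it is passed to `Model`.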

## Create Model Artifacts
This example creates a TorchServe model artifact for each model.
### Install torch-model-archiver

In [None]:
!pip install torch-model-archiver

### Model 1: Segment Anything Model (SAM)
SAM is a model from Meta that can segment any object in an image from a single click, with no additional training needed. Here we download one of its checkpoints.
#### Download Segment Anything Model (SAM)

In [None]:
model_file_name = "sam_vit_h_4b8939.pth"
download_path = f"https://huggingface.co/spaces/abhishek/StableSAM/resolve/main/{model_file_name}"

!wget $download_path -P workspace/sam

#### Implement customized handler
This step can be skipped if your model uses [TorchServe default handler](https://github.com/pytorch/serve/blob/ffa6847393cb7c36ae0122598152ca4614fe21f1/docs/default_handlers.md?plain=1#L1). Here we follow [TorchServe instruction](https://github.com/pytorch/serve/blob/ffa6847393cb7c36ae0122598152ca4614fe21f1/docs/custom_service.md?plain=1#L10) to create a customized handler for this model.

In [None]:
!cat workspace/sam/custom_handler.py

#### Config model

In [None]:
!cat workspace/sam/model-config.yaml

#### Create and upload `sam.tar.gz` file 

In [None]:
!torch-model-archiver --model-name sam --version 1.0 --serialized-file workspace/sam/sam_vit_h_4b8939.pth --handler workspace/sam/custom_handler.py --config-file workspace/sam/model-config.yaml --archive-format no-archive
!cd sam && tar cvzf sam.tar.gz .

In [None]:
!cd sam && aws s3 cp sam.tar.gz {output_path}/sam.tar.gz
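
For readers who prefer to stay in Python, the packaging step can be sketched with the standard-library `tarfile` module. This is a minimal sketch, not the notebook's method: `make_model_tarball` is a hypothetical helper, and the demo directory only mimics the layout of a `no-archive` model folder. Writing the tarball outside the model directory keeps the archive from including itself.

```python
import tarfile
import tempfile
from pathlib import Path


def make_model_tarball(model_dir, tarball_path):
    """Pack the contents of model_dir at the root of a gzip-compressed tarball."""
    model_dir = Path(model_dir)
    tarball_path = Path(tarball_path)
    with tarfile.open(tarball_path, "w:gz") as tar:
        for item in sorted(model_dir.rglob("*")):
            # Store paths relative to model_dir, so nothing is nested under "sam/".
            tar.add(item, arcname=item.relative_to(model_dir).as_posix(), recursive=False)
    return tarball_path


# Demo on a throwaway directory that mimics a no-archive model folder.
with tempfile.TemporaryDirectory() as tmp:
    demo_dir = Path(tmp) / "sam"
    (demo_dir / "MAR-INF").mkdir(parents=True)
    (demo_dir / "MAR-INF" / "MANIFEST.json").write_text("{}")
    (demo_dir / "custom_handler.py").write_text("# placeholder handler\n")

    tarball = make_model_tarball(demo_dir, Path(tmp) / "sam.tar.gz")
    with tarfile.open(tarball) as tar:
        names = sorted(tar.getnames())

print(names)
```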

### Model 2: Stable Diffusion In Paint (SD)
#### Import and Save Stable Diffusion Model

In [None]:
!pip install -U torch diffusers==0.13.0 transformers

In [None]:
import diffusers
import torch
import transformers

pipeline = diffusers.StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting", torch_dtype=torch.float16
)

In [None]:
sd_dir = "workspace/sd/model"
pipeline.save_pretrained(sd_dir)

#### Implement customized handler
This step can be skipped if your model uses [TorchServe default handler](https://github.com/pytorch/serve/blob/ffa6847393cb7c36ae0122598152ca4614fe21f1/docs/default_handlers.md?plain=1#L1). Here we follow [TorchServe instruction](https://github.com/pytorch/serve/blob/ffa6847393cb7c36ae0122598152ca4614fe21f1/docs/custom_service.md?plain=1#L10) to create a customized handler for this model.

In [None]:
!cat workspace/sd/custom_handler.py

#### Config model

In [None]:
!cat workspace/sd/model-config.yaml

#### Create `sd.tar.gz` file

In [None]:
!torch-model-archiver --model-name sd --version 1.0 --handler workspace/sd/custom_handler.py --extra-files workspace/sd/model --config-file workspace/sd/model-config.yaml --archive-format no-archive
!cd sd && tar cvzf sd.tar.gz .

### Model 3: Large Mask In Painting Model (Lama)
#### Download Pre-Trained Model

In [None]:
!pip install wldhx.yadisk-direct --quiet

In [None]:
!cd workspace/lama && curl -L $(yadisk-direct https://disk.yandex.ru/d/ouP6l8VJ0HpMZg) -o big-lama.zip && unzip big-lama.zip -d model

#### Clone Lama Repo

In [None]:
!cd workspace/lama && git clone https://github.com/advimman/lama.git lama-repo

#### Implement customized handler
This step can be skipped if your model uses [TorchServe default handler](https://github.com/pytorch/serve/blob/ffa6847393cb7c36ae0122598152ca4614fe21f1/docs/default_handlers.md?plain=1#L1). Here we follow [TorchServe instruction](https://github.com/pytorch/serve/blob/ffa6847393cb7c36ae0122598152ca4614fe21f1/docs/custom_service.md?plain=1#L10) to create a customized handler for this model.

In [None]:
!cat workspace/lama/custom_handler.py

#### Config model

In [None]:
!cat workspace/lama/model-config.yaml

#### Create `lama.tar.gz` file

In [None]:
!torch-model-archiver --model-name lama --version 1.0 --handler workspace/lama/custom_handler.py --extra-files workspace/lama/model,workspace/lama/lama-repo --config-file workspace/lama/model-config.yaml --archive-format no-archive
!cd lama && tar cvzf lama.tar.gz .

## Create the Multi-Model Endpoint with the SageMaker SDK

### Create the Amazon SageMaker MultiDataModel entity

We create the multi-model endpoint using the [```MultiDataModel```](https://sagemaker.readthedocs.io/en/stable/api/inference/multi_data_model.html) class.

You can create a MultiDataModel by passing in a `sagemaker.model.Model` object directly. In that case, once the MultiDataModel is deployed, the endpoint inherits the image to use along with any environment variables, network isolation settings, and so on.

In addition, a MultiDataModel can also be created without explicitly passing a `sagemaker.model.Model` object. Please refer to the documentation for additional details.

In [None]:
# This is where our MME will read models from on S3.
multi_model_s3uri = output_path

print(multi_model_s3uri)
model = Model(
    model_data=f"{multi_model_s3uri}/sam.tar.gz",
    image_uri=container,
    role=role,
    sagemaker_session=smsess,
    env={"TF_ENABLE_ONEDNN_OPTS": "0"},
)

mme = MultiDataModel(
    name="torchserve-mme-genai-" + datetime.now().strftime("%Y-%m-%d-%H-%M-%S"),
    model_data_prefix=multi_model_s3uri,
    model=model,
    sagemaker_session=smsess,
)
print(mme)

### Deploy the Multi-Model Endpoint

You need to consider the appropriate instance type and number of instances for the projected prediction workload across all the models you plan to host behind your multi-model endpoint. The number and size of the individual models will also drive memory requirements.

In [None]:
try:
    predictor.delete_endpoint(delete_endpoint_config=True)
    print("Deleting previous endpoint...")
    time.sleep(10)
except (NameError, ClientError):
    pass

mme.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.2xlarge",
    serializer=sagemaker.serializers.JSONSerializer(),
    deserializer=sagemaker.deserializers.JSONDeserializer(),
)

### Our endpoint has launched! Let's look at what models are available to the endpoint!

By "available" we mean the model artifacts currently stored under the S3 prefix we defined when setting up the `MultiDataModel` above (i.e. `model_data_prefix`).

Currently, only one artifact (`sam.tar.gz`) is stored under that prefix.

In [None]:
# Only sam.tar.gz visible!
list(mme.list_models())

### Dynamically deploying models to the endpoint

The `.add_model()` method of the `MultiDataModel` copies a model artifact from where it is currently stored (in this notebook, the local `sd/` and `lama/` directories created above) to the S3 prefix from which our endpoint sources model artifacts for inference requests.

Note that we can continue using this method, as shown below, to dynamically deploy more models to our live endpoint as required!

`model_data_source` is the location of the model artifact: either a local path or an S3 URI.

`model_data_path` is the **relative** path, under the S3 prefix we specified above (i.e. `model_data_prefix`), where our endpoint will source the model for inference requests. Since this is a **relative** path, we can simply pass the name we wish to give the model artifact at inference time. The cell below omits it, so each artifact keeps its file name (`sd.tar.gz`, `lama.tar.gz`), which is how the inference requests later refer to them.

In [None]:
models = ["sd/sd.tar.gz", "lama/lama.tar.gz"]
# Use a distinct loop variable so the `model` object created earlier is not overwritten.
for model_path in models:
    mme.add_model(model_data_source=model_path)
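
The relationship between `model_data_prefix`, `model_data_path`, and the final S3 location can be sketched as plain string handling. This is an illustration under stated assumptions, not the SDK's code: both helper names are hypothetical, the bucket name is a placeholder, and the default of reusing the source file name is inferred from how this notebook later invokes `sd.tar.gz` and `lama.tar.gz`.

```python
import posixpath


def default_model_data_path(model_data_source: str) -> str:
    """Assumed default: the artifact keeps its file name under the prefix."""
    return posixpath.basename(model_data_source)


def artifact_s3_uri(model_data_prefix: str, model_data_path: str) -> str:
    """Join the endpoint's S3 prefix with a relative artifact path."""
    return posixpath.join(model_data_prefix.rstrip("/"), model_data_path.lstrip("/"))


example_prefix = "s3://my-bucket/torchserve/mme"  # placeholder bucket
for source in ["sd/sd.tar.gz", "lama/lama.tar.gz"]:
    relative_path = default_model_data_path(source)
    print(relative_path, "->", artifact_s3_uri(example_prefix, relative_path))
```

The relative path printed on the left is the value that would be passed as `target_model` when invoking the endpoint.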

### Our models are ready to invoke!

We can see that the S3 prefix we specified when setting up `MultiDataModel` now has model artifacts listed. As such, the endpoint can now serve up inference requests for these models.

In [None]:
list(mme.list_models())

## Get predictions from the endpoint

The `mme.deploy()` call above did not capture a return value, so here we attach a [Real Time Predictor](https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/predictor.py#L35) to the endpoint by name and store it in a variable called `predictor`.

That `predictor` can then be used as usual to request inference, with the addition of specifying which model to call:

In [None]:
predictor = sagemaker.predictor.Predictor(endpoint_name=mme.endpoint_name, sagemaker_session=smsess)
print(predictor)
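
For callers that do not use the SageMaker Python SDK, the same request can be made through the `sagemaker-runtime` client's `invoke_endpoint` call, where `TargetModel` selects the model. The sketch below only builds the keyword arguments, so it runs without AWS credentials. The helper name, the endpoint name, and the payload are placeholders, and the `application/json` content type is an assumption about what the handlers accept.

```python
import json


def build_invoke_kwargs(endpoint_name: str, target_model: str, payload: dict) -> dict:
    """Assemble keyword arguments for sagemaker-runtime invoke_endpoint."""
    return {
        "EndpointName": endpoint_name,
        "TargetModel": target_model,        # which artifact under the S3 prefix to call
        "ContentType": "application/json",  # assumption: handlers expect a JSON body
        "Body": json.dumps(payload).encode("utf-8"),
    }


invoke_kwargs = build_invoke_kwargs(
    endpoint_name="my-mme-endpoint",  # placeholder; the notebook uses mme.endpoint_name
    target_model="sam.tar.gz",
    payload={"image": "<base64 image>", "gen_args": "{}"},
)
print(sorted(invoke_kwargs))

# With AWS credentials configured, the request itself would be:
#   boto3.client("sagemaker-runtime").invoke_endpoint(**invoke_kwargs)
```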

### Model Segment Anything Inference Request

In [None]:
# sam payload
import base64
import json
import io
import numpy as np
from PIL import Image


def encode_image(img):
    # Convert the image to bytes
    with io.BytesIO() as output:
        img.save(output, format="JPEG")
        img_bytes = output.getvalue()

    return base64.b64encode(img_bytes).decode("utf-8")


img_file = "workspace/test_data/sample1.png"
img_bytes = None
with Image.open(img_file) as f:
    img_bytes = encode_image(f)

gen_args = json.dumps(dict(point_coords=[750, 500], point_labels=1, dilate_kernel_size=15))

payload = json.dumps({"image": img_bytes, "gen_args": gen_args}).encode("utf-8")

response = predictor.predict(data=payload, target_model="/sam.tar.gz")
encoded_masks_string = json.loads(response.decode("utf-8"))["generated_image"]
base64_bytes_masks = base64.b64decode(encoded_masks_string)

with Image.open(io.BytesIO(base64_bytes_masks)) as f:
    generated_image_rgb = f.convert("RGB")
    generated_image_rgb.show()
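
The request and response formats used above can be exercised without an endpoint. The sketch below round-trips raw bytes through the same base64-in-JSON encoding using only the standard library. The two helper names are hypothetical, the "image" is a few placeholder bytes rather than a real JPEG, and the fake response simply echoes the input to stand in for a handler.

```python
import base64
import json


def build_demo_payload(image_bytes: bytes, gen_args: dict) -> bytes:
    """Encode raw image bytes and generation arguments as a JSON request body."""
    return json.dumps(
        {
            "image": base64.b64encode(image_bytes).decode("utf-8"),
            "gen_args": json.dumps(gen_args),
        }
    ).encode("utf-8")


def decode_generated_image(response_body: bytes) -> bytes:
    """Extract and decode the base64 image from a JSON response body."""
    return base64.b64decode(json.loads(response_body.decode("utf-8"))["generated_image"])


fake_image = b"\xff\xd8placeholder-bytes"  # not a real JPEG
demo_payload = build_demo_payload(fake_image, {"point_coords": [750, 500]})

# Hypothetical echo handler: returns the input image under "generated_image".
fake_response = json.dumps(
    {"generated_image": json.loads(demo_payload)["image"]}
).encode("utf-8")

round_tripped = decode_generated_image(fake_response)
print(round_tripped == fake_image)
```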

### Model Stable Diffusion In Paint Inference Request

In [None]:
# sd payload
import base64
import json
import io
import numpy as np
from PIL import Image


def encode_image(img):
    # Convert the image to bytes
    with io.BytesIO() as output:
        img.save(output, format="JPEG")
        img_bytes = output.getvalue()

    return base64.b64encode(img_bytes).decode("utf-8")


img_file = "workspace/test_data/sample1.png"
img_bytes = None
with Image.open(img_file) as f:
    img_bytes = encode_image(f)

mask_file = "workspace/test_data/sample1_mask.jpg"
mask_bytes = None
with Image.open(mask_file) as f:
    mask_bytes = encode_image(f)

prompt = "a teddy bear on a bench"
nprompt = "ugly"
gen_args = json.dumps(dict(num_inference_steps=50, guidance_scale=10, seed=1))

payload = json.dumps(
    {
        "image": img_bytes,
        "mask_image": mask_bytes,
        "prompt": prompt,
        "negative_prompt": nprompt,
        "gen_args": gen_args,
    }
).encode("utf-8")

response = predictor.predict(data=payload, target_model="/sd.tar.gz")
encoded_masks_string = json.loads(response.decode("utf-8"))["generated_image"]
base64_bytes_masks = base64.b64decode(encoded_masks_string)
with Image.open(io.BytesIO(base64_bytes_masks)) as f:
    generated_image_rgb = f.convert("RGB")
    generated_image_rgb.show()

### Large Mask In Painting Model Inference Request

In [None]:
# lama payload
import base64
import json
import io
import numpy as np
from PIL import Image


def encode_image(img):
    # Convert the image to bytes
    with io.BytesIO() as output:
        img.save(output, format="JPEG")
        img_bytes = output.getvalue()

    return base64.b64encode(img_bytes).decode("utf-8")


img_file = "workspace/test_data/sample1.png"
img_bytes = None
with Image.open(img_file) as f:
    img_bytes = encode_image(f)

mask_file = "workspace/test_data/sample1_mask.jpg"
mask_bytes = None
with Image.open(mask_file) as f:
    mask_bytes = encode_image(f)

# Note: prompt, nprompt, and gen_args are reused from the Stable Diffusion cell above,
# so that cell must be run first.
payload = json.dumps(
    {
        "image": img_bytes,
        "mask_image": mask_bytes,
        "prompt": prompt,
        "negative_prompt": nprompt,
        "gen_args": gen_args,
    }
).encode("utf-8")

response = predictor.predict(data=payload, target_model="/lama.tar.gz")
encoded_masks_string = json.loads(response.decode("utf-8"))["generated_image"]
base64_bytes_masks = base64.b64decode(encoded_masks_string)
with Image.open(io.BytesIO(base64_bytes_masks)) as f:
    generated_image_rgb = f.convert("RGB")
    generated_image_rgb.show()

## Updating a model

To update a model, follow the same approach as above and add it as a new model under a new name, for example `ModelA-2`.

You should avoid overwriting model artifacts in Amazon S3, because the old version of the model might still be loaded in the endpoint's running container(s) or stored on the storage volume of instances behind the endpoint. Invocations would then keep using the old version of the model.

Alternatively, you could delete the endpoint and re-deploy a fresh set of models.
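
One way to avoid overwriting an artifact is to derive the next free versioned name from the names already under the prefix. The sketch below is a hypothetical helper, not part of the SageMaker SDK. It assumes a `<base>-<n>.tar.gz` naming scheme in the spirit of `ModelA-2`, and treats an unsuffixed `<base>.tar.gz` as version 1.

```python
import re
from typing import Iterable


def next_versioned_name(existing: Iterable[str], base: str, suffix: str = ".tar.gz") -> str:
    """Return the next unused artifact name for `base` among `existing` names."""
    pattern = re.compile(rf"^{re.escape(base)}(?:-(\d+))?{re.escape(suffix)}$")
    versions = []
    for name in existing:
        match = pattern.match(name.lstrip("/"))
        if match:
            # An unsuffixed name such as "sam.tar.gz" counts as version 1.
            versions.append(int(match.group(1)) if match.group(1) else 1)
    next_version = max(versions, default=0) + 1
    if next_version == 1:
        return f"{base}{suffix}"
    return f"{base}-{next_version}{suffix}"


current = ["sam.tar.gz", "sd.tar.gz", "lama.tar.gz"]
new_name = next_versioned_name(current, "sam")
print(new_name)
```

In the notebook, the result could be passed as `model_data_path` to `mme.add_model()`, with `list(mme.list_models())` supplying the existing names.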

## Clean up

Endpoints should be deleted when no longer in use, since (per the [SageMaker pricing page](https://aws.amazon.com/sagemaker/pricing/)) they're billed by time deployed. Here we'll also delete the endpoint configuration - to keep things tidy.

In [None]:
predictor.delete_endpoint(delete_endpoint_config=True)

## Notebook CI Test Results

This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.

![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-east-1/inference|torchserve|mme-gpu|torchserve_multi_model_endpoint.ipynb)

![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-east-2/inference|torchserve|mme-gpu|torchserve_multi_model_endpoint.ipynb)

![This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-west-1/inference|torchserve|mme-gpu|torchserve_multi_model_endpoint.ipynb)

![This ca-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ca-central-1/inference|torchserve|mme-gpu|torchserve_multi_model_endpoint.ipynb)

![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/sa-east-1/inference|torchserve|mme-gpu|torchserve_multi_model_endpoint.ipynb)

![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-1/inference|torchserve|mme-gpu|torchserve_multi_model_endpoint.ipynb)

![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-2/inference|torchserve|mme-gpu|torchserve_multi_model_endpoint.ipynb)

![This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-3/inference|torchserve|mme-gpu|torchserve_multi_model_endpoint.ipynb)

![This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-central-1/inference|torchserve|mme-gpu|torchserve_multi_model_endpoint.ipynb)

![This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-north-1/inference|torchserve|mme-gpu|torchserve_multi_model_endpoint.ipynb)

![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-southeast-1/inference|torchserve|mme-gpu|torchserve_multi_model_endpoint.ipynb)

![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-southeast-2/inference|torchserve|mme-gpu|torchserve_multi_model_endpoint.ipynb)

![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-northeast-1/inference|torchserve|mme-gpu|torchserve_multi_model_endpoint.ipynb)

![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-northeast-2/inference|torchserve|mme-gpu|torchserve_multi_model_endpoint.ipynb)

![This ap-south-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-south-1/inference|torchserve|mme-gpu|torchserve_multi_model_endpoint.ipynb)