# Deploy Stable Diffusion on a SageMaker GPU Multi-Model Endpoint with Triton


In this notebook we will deploy multiple variations of Stable Diffusion on a SageMaker Multi-Model GPU Endpoint (MME GPU) powered by NVIDIA Triton Inference Server.
> ⚠ **Warning**: This notebook requires a minimum of an `ml.m5.large` instance to build the conda environment required for hosting the Stable Diffusion models. 

Skip to:
1. [Installs and imports](#installs)
2. [Download pretrained model](#modelartifact)
3. [Packaging a conda environment](#condaenv)
4. [Deploy to SageMaker Real-Time Endpoint](#deploy)
5. [Query Models](#query)
6. [Clean up](#cleanup)


### Part 1 - Installs and imports 

In [None]:
%pip install -U sagemaker pillow huggingface-hub conda-pack

In [None]:
import boto3
import sagemaker
from sagemaker import get_execution_role

import time
import json
from PIL import Image
import base64
from io import BytesIO
import numpy as np

from utils import download_model

from IPython.display import display

# variables
s3_client = boto3.client("s3")
ts = time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

# sagemaker variables
role = get_execution_role()
sm_client = boto3.client(service_name="sagemaker")
runtime_sm_client = boto3.client("sagemaker-runtime")
sagemaker_session = sagemaker.Session(boto_session=boto3.Session())
bucket = sagemaker_session.default_bucket()
prefix = "stable-diffusion-mme"

### Part 2 - Save pretrained model 

The `models` directory contains the inference code and the Triton configuration file for each of the Stable Diffusion models. In addition to these, we need to download the pretrained model weights and save them to their respective subdirectories within the `models` directory. Once the weights are downloaded, we can package the inference code and the model weights into a tarball and upload it to S3.

In [None]:
# Hugging Face model ID -> local directory for the downloaded weights
models_local_path = {
    "stabilityai/stable-diffusion-2-1-base": "models/sd_base/1/checkpoint",
    "stabilityai/stable-diffusion-2-depth": "models/sd_depth/1/checkpoint",
    "stabilityai/stable-diffusion-2-inpainting": "models/sd_inpaint/1/checkpoint",
    "stabilityai/stable-diffusion-x4-upscaler": "models/sd_upscale/1/checkpoint",
}

for model_name, model_local_path in models_local_path.items():
    download_model(model_name, model_local_path)
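
The `download_model` helper comes from the notebook's `utils` module, which is not shown here. As a rough idea of what such a helper might do, the sketch below assumes it wraps `huggingface_hub.snapshot_download`. The name `download_model_sketch` and the injectable `fetch` parameter are illustrative and are not the notebook's actual implementation.

```python
from pathlib import Path
from typing import Callable, Optional


def download_model_sketch(
    repo_id: str,
    local_dir: str,
    fetch: Optional[Callable[..., str]] = None,
) -> Path:
    """Download a Hugging Face repo snapshot into local_dir (hypothetical helper)."""
    if fetch is None:
        # Imported lazily so the sketch can be exercised without the package.
        from huggingface_hub import snapshot_download

        fetch = snapshot_download

    target = Path(local_dir)
    target.mkdir(parents=True, exist_ok=True)
    # Assumption: the weights should land directly in the checkpoint directory.
    fetch(repo_id=repo_id, local_dir=str(target))
    return target
```

Passing a stand-in `fetch` function makes it possible to check the directory handling without downloading several gigabytes of weights.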

### Part 3 - Packaging a conda environment, extending SageMaker Triton container 

When using the Triton Python backend (which our Stable Diffusion model will run on), you can include your own environment and dependencies. The recommended way to do this is to use [conda pack](https://conda.github.io/conda-pack/) to generate a conda environment archive in `tar.gz` format, and point to it in the `config.pbtxt` file of the models that should use it, adding the snippet: 

```
parameters: {
  key: "EXECUTION_ENV_PATH",
  value: {string_value: "path_to_your_env.tar.gz"}
}
```
You can use a different environment per model, or the same one for all models (read more on this [here](https://github.com/triton-inference-server/python_backend#creating-custom-execution-environments)). Since all of the models that we'll be deploying have the same environment requirements, we will create a single conda environment and use a Python backend model to copy that environment into a location where all models can access it.
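
To make the `EXECUTION_ENV_PATH` setting concrete, here is a minimal sketch that renders the parameter block and appends it to a `config.pbtxt` string only if the setting is not already present. The helper names and the example path are hypothetical; the notebook's own `config.pbtxt` files already contain the setting they need.

```python
def execution_env_stanza(env_path: str) -> str:
    """Render the EXECUTION_ENV_PATH parameter block for a config.pbtxt file."""
    return (
        "parameters: {\n"
        '  key: "EXECUTION_ENV_PATH",\n'
        f'  value: {{string_value: "{env_path}"}}\n'
        "}\n"
    )


def ensure_execution_env(config_text: str, env_path: str) -> str:
    """Append the stanza unless the config already sets EXECUTION_ENV_PATH."""
    if "EXECUTION_ENV_PATH" in config_text:
        return config_text
    return config_text.rstrip("\n") + "\n\n" + execution_env_stanza(env_path)


# Placeholder path for illustration only; use the location your models share.
print(ensure_execution_env('name: "sd_base"\nbackend: "python"\n', "/tmp/conda/sd_env.tar.gz"))
```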

> ⚠ **Warning**: The approach to creating a shared conda environment shown here is limited to single-instance deployments. If the endpoint auto-scales, there is no guarantee that a new instance will have the conda environment configured. Because the conda environment for hosting Stable Diffusion models is quite large, the recommended approach for production deployments is to create the shared environment by extending the Triton Inference Server image. 

Let's start by creating the conda environment with the necessary dependencies; running these cells will output a `sd_env.tar.gz` file.

In [None]:
%%writefile environment.yml
name: mme_env
dependencies:
  - python=3.8
  - pip
  - pip:
      - numpy
      - torch --extra-index-url https://download.pytorch.org/whl/cu118
      - accelerate
      - transformers
      - diffusers
      - xformers
      - conda-pack

Now we can create the environment from the YAML spec above.

🛈 It can take up to 5 minutes to create the conda environment. Make sure you are running this notebook on an `ml.m5.large` instance or larger.

In [None]:
!conda env create -f environment.yml

In [None]:
!conda pack -n mme_env -o models/setup_conda/sd_env.tar.gz
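
Before uploading, it can be worth confirming that the packed archive looks like a usable environment. The sketch below checks whether any member of a `tar.gz` archive ends with a given suffix. It assumes that a conda-pack archive built on Linux contains a `bin/python` entry, and the function name is illustrative.

```python
import tarfile


def packed_env_has(archive_path: str, member_suffix: str = "bin/python") -> bool:
    """Return True if any archive member path ends with member_suffix."""
    with tarfile.open(archive_path, "r:gz") as archive:
        return any(name.endswith(member_suffix) for name in archive.getnames())
```

For example, `packed_env_has("models/setup_conda/sd_env.tar.gz")` would scan the archive produced above. Listing a multi-gigabyte archive reads the whole file, so expect this to take a while.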

### Part 4 - Deploy endpoint 

Now, we get the correct URI for the SageMaker Triton container image. Check out all the available Deep Learning Container images that AWS maintains [here](https://github.com/aws/deep-learning-containers/blob/master/available_images.md). 

In [None]:
# account mapping for SageMaker Triton Image
account_id_map = {
 "us-east-1": "785573368785",
 "us-east-2": "007439368137",
 "us-west-1": "710691900526",
 "us-west-2": "301217895009",
 "eu-west-1": "802834080501",
 "eu-west-2": "205493899709",
 "eu-west-3": "254080097072",
 "eu-north-1": "601324751636",
 "eu-south-1": "966458181534",
 "eu-central-1": "746233611703",
 "ap-east-1": "110948597952",
 "ap-south-1": "763008648453",
 "ap-northeast-1": "941853720454",
 "ap-northeast-2": "151534178276",
 "ap-southeast-1": "324986816169",
 "ap-southeast-2": "355873309152",
 "cn-northwest-1": "474822919863",
 "cn-north-1": "472730292857",
 "sa-east-1": "756306329178",
 "ca-central-1": "464438896020",
 "me-south-1": "836785723513",
 "af-south-1": "774647643957",
}


region = boto3.Session().region_name
if region not in account_id_map:
    # raising a bare string is a TypeError, so raise a real exception
    raise ValueError(f"Unsupported region: {region}")

# China regions use a different DNS suffix for ECR
base = "amazonaws.com.cn" if region.startswith("cn-") else "amazonaws.com"
mme_triton_image_uri = (
    "{account_id}.dkr.ecr.{region}.{base}/sagemaker-tritonserver:22.12-py3".format(
        account_id=account_id_map[region], region=region, base=base
    )
)
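
The URI construction above can also be expressed as a small function, which makes the China-region special case easy to check in isolation. This is a sketch of the same string formatting; the function name and the `tag` parameter are illustrative.

```python
def triton_image_uri(region: str, account_id: str, tag: str = "22.12-py3") -> str:
    """Build the SageMaker Triton image URI for a region and ECR account ID."""
    # China regions use a different DNS suffix for ECR.
    base = "amazonaws.com.cn" if region.startswith("cn-") else "amazonaws.com"
    return f"{account_id}.dkr.ecr.{region}.{base}/sagemaker-tritonserver:{tag}"


print(triton_image_uri("us-west-2", "301217895009"))
```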

The next step is to package the model subdirectories and weights into individual tarballs and upload them to S3. This process can take about 10 to 15 minutes.

In [None]:
from pathlib import Path

model_root_path = Path("./models")
model_dirs = list(model_root_path.glob("*"))

In [None]:
model_upload_paths = {}
for model_path in model_dirs:
    model_name = model_path.name
    tar_name = model_path.name + ".tar.gz"
    # archive the model directory, upload it, then remove the local tarball
    !tar -C $model_root_path -czvf $tar_name $model_name
    model_upload_paths[model_name] = sagemaker_session.upload_data(
        path=tar_name, bucket=bucket, key_prefix=prefix
    )
    !rm $tar_name
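
Each uploaded archive name is what the query cells later in this notebook pass as `TargetModel`. A minimal sketch for listing those names from S3 follows. It takes any object with a boto3-style `list_objects_v2` method, and the function name is illustrative.

```python
from typing import List


def list_target_models(s3_client, bucket: str, prefix: str) -> List[str]:
    """Return archive names under s3://bucket/prefix/ usable as TargetModel."""
    root = prefix.rstrip("/") + "/"
    response = s3_client.list_objects_v2(Bucket=bucket, Prefix=root)
    # A single call returns at most 1000 keys, which is plenty for this demo.
    keys = [obj["Key"] for obj in response.get("Contents", [])]
    return sorted(key[len(root):] for key in keys if key.endswith(".tar.gz"))
```

With the variables defined earlier, `list_target_models(s3_client, bucket, prefix)` should list one archive per subdirectory of `models`.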

We are now ready to configure and deploy the multi-model endpoint.

In [None]:
model_data_url = f"s3://{bucket}/{prefix}/" # s3 location where models are stored
ts = time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

container = {
 "Image": mme_triton_image_uri,
 "ModelDataUrl": model_data_url,
 "Mode": "MultiModel",
}

In [None]:
sm_model_name = f"{prefix}-mdl-{ts}"

create_model_response = sm_client.create_model(
 ModelName=sm_model_name, ExecutionRoleArn=role, PrimaryContainer=container
)

print("Model Arn: " + create_model_response["ModelArn"])

Create a SageMaker endpoint configuration.

In [None]:
endpoint_config_name = f"{prefix}-epc-{ts}"
instance_type = "ml.g5.xlarge"

create_endpoint_config_response = sm_client.create_endpoint_config(
 EndpointConfigName=endpoint_config_name,
 ProductionVariants=[
 {
 "InstanceType": instance_type,
 "InitialVariantWeight": 1,
 "InitialInstanceCount": 1,
 "ModelName": sm_model_name,
 "VariantName": "AllTraffic",
 }
 ],
)

print("Endpoint Config Arn: " + create_endpoint_config_response["EndpointConfigArn"])

Create the endpoint, and wait for it to transition to `InService` state.

In [None]:
endpoint_name = f"{prefix}-ep-{ts}"

create_endpoint_response = sm_client.create_endpoint(
 EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name
)

print("Endpoint Arn: " + create_endpoint_response["EndpointArn"])

In [None]:
resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
status = resp["EndpointStatus"]
print("Status: " + status)

# poll every 30 seconds until the endpoint leaves the "Creating" state
while status == "Creating":
    time.sleep(30)
    resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
    status = resp["EndpointStatus"]
    print("Status: " + status)

print("Arn: " + resp["EndpointArn"])
print("Status: " + status)
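
The polling loop above stops as soon as the status is no longer `Creating`, whether or not the endpoint is healthy. A slightly more defensive variant is sketched below. It raises on any terminal status other than `InService` and gives up after a timeout. The function name, the `describe` callable, and the timeout values are illustrative. boto3 also provides an `endpoint_in_service` waiter on the SageMaker client that serves a similar purpose.

```python
import time
from typing import Callable


def wait_for_endpoint(
    describe: Callable[[], dict],
    poll_seconds: float = 30.0,
    timeout_seconds: float = 3600.0,
    sleep: Callable[[float], None] = time.sleep,
) -> str:
    """Poll until the endpoint leaves 'Creating'; raise on failure or timeout."""
    deadline = time.monotonic() + timeout_seconds
    while True:
        response = describe()
        status = response["EndpointStatus"]
        if status == "InService":
            return status
        if status != "Creating":
            reason = response.get("FailureReason", "no failure reason reported")
            raise RuntimeError(f"Endpoint ended in status {status}: {reason}")
        if time.monotonic() >= deadline:
            raise TimeoutError("Endpoint was still 'Creating' at the deadline")
        sleep(poll_seconds)
```

In this notebook it could be called as `wait_for_endpoint(lambda: sm_client.describe_endpoint(EndpointName=endpoint_name))`.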

### Part 5 - Query models 
The endpoint is now deployed and we can query the individual models.

Before invoking any of the Stable Diffusion models, we first invoke the `setup_conda` model, which copies the conda environment into a directory that is shared with all the other models. Refer to the [model.py](./models/setup_conda/1/model.py) file in the `models/setup_conda/1` directory for more details on the implementation.

In [None]:
# invoke the setup_conda model to create the shared conda environment

payload = {
 "inputs": [
 {
 "name": "TEXT",
 "shape": [1],
 "datatype": "BYTES",
 "data": ["hello"], # dummy data not used by the model
 }
 ]
}

response = runtime_sm_client.invoke_endpoint(
 EndpointName=endpoint_name,
 ContentType="application/octet-stream",
 Body=json.dumps(payload),
 TargetModel="setup_conda.tar.gz",
)
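
The actual copy logic lives in the `model.py` file referenced above and is not reproduced here. As a rough illustration of the idea only, the sketch below copies a packed environment archive into a shared directory the first time it is called and does nothing afterwards. The function name and the paths in the usage note are assumptions, not the notebook's implementation.

```python
import shutil
from pathlib import Path


def copy_env_once(archive: str, shared_dir: str) -> Path:
    """Copy a packed conda environment into shared_dir unless already present."""
    destination = Path(shared_dir) / Path(archive).name
    if not destination.exists():
        destination.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(archive, destination)
    return destination
```

A Triton Python backend model could call something like `copy_env_once("sd_env.tar.gz", "/tmp/conda")` when it loads, so that other models can point `EXECUTION_ENV_PATH` at the shared copy. Both paths here are placeholders.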

In [None]:
# helper functions to encode and decode images
def encode_image(image):
    """Serialize a PIL image to base64-encoded JPEG bytes."""
    buffer = BytesIO()
    image.save(buffer, format="JPEG")
    img_str = base64.b64encode(buffer.getvalue())

    return img_str


def decode_image(img):
    """Decode a base64 string back into a PIL image."""
    buff = BytesIO(base64.b64decode(img.encode("utf8")))
    image = Image.open(buff)
    return image

In [None]:
inputs = dict(
 prompt="Infinity pool on top of a high rise overlooking Central Park",
 negative_prompt="blur, signature, low detail, low quality",
 gen_args=json.dumps(dict(num_inference_steps=50, guidance_scale=8)),
)

payload = {
 "inputs": [
 {"name": name, "shape": [1, 1], "datatype": "BYTES", "data": [data]}
 for name, data in inputs.items()
 ]
}

response = runtime_sm_client.invoke_endpoint(
 EndpointName=endpoint_name,
 ContentType="application/octet-stream",
 Body=json.dumps(payload),
 TargetModel="sd_base.tar.gz",
)
output = json.loads(response["Body"].read().decode("utf8"))["outputs"]
original_image = decode_image(output[0]["data"][0])
original_image
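
The same payload structure is rebuilt for every request in the cells that follow. A small pair of helpers, sketched here with illustrative names, captures the pattern: one wraps named string inputs in the request structure, the other pulls the first data element out of a JSON response body.

```python
import json
from typing import Any, Dict


def build_payload(inputs: Dict[str, str]) -> Dict[str, Any]:
    """Wrap named string inputs in the request structure used by these cells."""
    return {
        "inputs": [
            {"name": name, "shape": [1, 1], "datatype": "BYTES", "data": [data]}
            for name, data in inputs.items()
        ]
    }


def first_output(response_body: bytes) -> Any:
    """Return the first data element of the first output in a JSON response."""
    outputs = json.loads(response_body.decode("utf8"))["outputs"]
    return outputs[0]["data"][0]
```

With these, a request body is `json.dumps(build_payload(inputs))` and the image string is `first_output(response["Body"].read())`, which can then be passed to `decode_image`.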

Let's take the output from the base model and modify it using the depth model.

We can use the same depth model to change the style of the original image into an oil painting, or to change the setting from New York City's Central Park to Yellowstone National Park, while preserving the orientation of the original image.

In [None]:
input_image = encode_image(original_image).decode("utf8")

inputs = dict(
    prompt="highly detailed oil painting of an infinity pool overlooking central park",
    image=input_image,
    gen_args=json.dumps(dict(num_inference_steps=50, strength=0.8)),
)


payload = {
 "inputs": [
 {"name": name, "shape": [1, 1], "datatype": "BYTES", "data": [data]}
 for name, data in inputs.items()
 ]
}

response = runtime_sm_client.invoke_endpoint(
 EndpointName=endpoint_name,
 ContentType="application/octet-stream",
 Body=json.dumps(payload),
 TargetModel="sd_depth.tar.gz",
)
output = json.loads(response["Body"].read().decode("utf8"))["outputs"]
oil_painting = decode_image(output[0]["data"][0])


inputs = dict(
 prompt="Infinity pool perched on a cliff overlooking Yellowstone National Park ",
 image=input_image,
 gen_args=json.dumps(dict(num_inference_steps=50, strength=0.8)),
)


payload = {
 "inputs": [
 {"name": name, "shape": [1, 1], "datatype": "BYTES", "data": [data]}
 for name, data in inputs.items()
 ]
}

response = runtime_sm_client.invoke_endpoint(
 EndpointName=endpoint_name,
 ContentType="application/octet-stream",
 Body=json.dumps(payload),
 TargetModel="sd_depth.tar.gz",
)
output = json.loads(response["Body"].read().decode("utf8"))["outputs"]
rocky_mountains = decode_image(output[0]["data"][0])


print("Original image")
display(original_image)

print("Oil painting")
display(oil_painting)

print("Yellowstone")
display(rocky_mountains)

Next, we use the inpainting model to regenerate the masked region of a sample image, guided by a prompt and a negative prompt.

In [None]:
source_image = Image.open("sample_images/bertrand-gabioud.png")

image = encode_image(source_image).decode("utf8")
mask_image = encode_image(Image.open("sample_images/bertrand-gabioud-mask.png")).decode("utf8")
inputs = dict(
 prompt="building, facade, paint, windows",
 image=image,
 mask_image=mask_image,
 negative_prompt="tree, obstruction, sky, clouds",
 gen_args=json.dumps(dict(num_inference_steps=50, guidance_scale=10)),
)


payload = {
 "inputs": [
 {"name": name, "shape": [1, 1], "datatype": "BYTES", "data": [data]}
 for name, data in inputs.items()
 ]
}

response = runtime_sm_client.invoke_endpoint(
 EndpointName=endpoint_name,
 ContentType="application/octet-stream",
 Body=json.dumps(payload),
 TargetModel="sd_inpaint.tar.gz",
)
output = json.loads(response["Body"].read().decode("utf8"))["outputs"]
print("source image")
display(source_image)

print("filled image")
display(decode_image(output[0]["data"][0]))

For the final example, we will downsize our original output image from 512x512 to 128x128. We will then use the upscaling model to upscale the image back to its original 512x512 resolution.

In [None]:
low_res_image = original_image.resize((128, 128))
inputs = dict(
 prompt="Infinity pool on top of a high rise overlooking Central Park",
 image=encode_image(low_res_image).decode("utf8"),
)

payload = {
 "inputs": [
 {"name": name, "shape": [1, 1], "datatype": "BYTES", "data": [data]}
 for name, data in inputs.items()
 ]
}

response = runtime_sm_client.invoke_endpoint(
 EndpointName=endpoint_name,
 ContentType="application/octet-stream",
 Body=json.dumps(payload),
 TargetModel="sd_upscale.tar.gz",
)
output = json.loads(response["Body"].read().decode("utf8"))["outputs"]
upscaled_image = decode_image(output[0]["data"][0])

print("Low res image")
display(low_res_image.resize((512, 512)))

print("Upscaled image")
display(upscaled_image)

### Part 6 - Clean up 

In [None]:
sm_client.delete_endpoint(EndpointName=endpoint_name)
sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
sm_client.delete_model(ModelName=sm_model_name)
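
If one of the delete calls above fails, for example because the endpoint was already removed, the remaining resources are left behind. The sketch below runs every cleanup step regardless of earlier failures and reports what went wrong. The function name and the step labels are illustrative.

```python
from typing import Callable, List, Tuple


def run_cleanup(steps: List[Tuple[str, Callable[[], object]]]) -> List[str]:
    """Run every cleanup step, returning messages for the steps that failed."""
    failures = []
    for label, action in steps:
        try:
            action()
        except Exception as error:  # keep going so later resources are still removed
            failures.append(f"{label}: {error}")
    return failures


# Usage with the names defined in this notebook:
# failures = run_cleanup([
#     ("endpoint", lambda: sm_client.delete_endpoint(EndpointName=endpoint_name)),
#     ("endpoint config", lambda: sm_client.delete_endpoint_config(
#         EndpointConfigName=endpoint_config_name)),
#     ("model", lambda: sm_client.delete_model(ModelName=sm_model_name)),
# ])
```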

In [None]:
# delete the downloaded model weights from their local paths
for model_name, model_local_path in models_local_path.items():
    !rm -rf $model_local_path
