# Deploy Stable Diffusion on SageMaker with Triton Business Logic Scripting (BLS)

---

In this notebook, we take most of NVIDIA's [example](https://github.com/triton-inference-server/server/tree/main/docs/examples/stable_diffusion) for hosting Stable Diffusion on Triton Inference Server and adapt it to SageMaker.

[Business Logic Scripting (BLS)](https://github.com/triton-inference-server/python_backend#business-logic-scripting) is a Triton Inference Server feature that allows you to create complex inference logic, where loops, conditionals, data-dependent control flow and other custom logic need to be intertwined with model execution. From within a Python script that runs on Triton's [Python backend](https://github.com/triton-inference-server/python_backend), you can run some of the required inference steps yourself (light processing, or even ML models that are not a good fit for framework-specific backends), and also call other models hosted independently in the same server. This lets you optimize the execution performance of individual model components (using TensorRT, for example), while orchestrating the end-to-end inference flow through a comfortable Python interface.
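
The control flow that BLS enables can be sketched without Triton itself. In the minimal sketch below, `call_model` is a hypothetical stand-in for what `triton_python_backend_utils.InferenceRequest(...).exec()` does inside a real Python backend model, and the tensor names (`embeddings`, `latents`, `image`) are assumptions chosen only for illustration. It is not the `model.py` used later in this notebook.

```python
from typing import Any, Callable, Dict, List

# Hypothetical stand-in for a BLS call. Inside Triton's Python backend this
# role is played by triton_python_backend_utils.InferenceRequest(...).exec().
ModelCall = Callable[[str, Dict[str, Any]], Dict[str, Any]]


def run_pipeline(
    prompt: str,
    call_model: ModelCall,
    local_step: Callable[[Any, Any, int], Any],
    num_steps: int = 3,
) -> Any:
    """Mix calls to separately hosted models with local Python logic."""
    # 1. Offload text encoding to a separately hosted model.
    embeddings = call_model("text_encoder", {"prompt": prompt})["embeddings"]

    # 2. Run an iterative step locally, in plain Python.
    latents = None
    for step in range(num_steps):
        latents = local_step(latents, embeddings, step)

    # 3. Offload decoding to another separately hosted model.
    return call_model("vae", {"latents": latents})["image"]


# Fake components, used only to show the order of the calls.
calls: List[str] = []


def fake_call_model(name: str, inputs: Dict[str, Any]) -> Dict[str, Any]:
    calls.append(name)
    if name == "text_encoder":
        return {"embeddings": len(inputs["prompt"])}
    return {"image": f"decoded:{inputs['latents']}"}


def fake_local_step(latents: Any, embeddings: Any, step: int) -> Any:
    calls.append(f"step{step}")
    return (latents or 0) + embeddings


result = run_pipeline("cat", fake_call_model, fake_local_step, num_steps=2)
print(calls)
```

Passing the model call in as a function keeps the orchestration logic testable outside the server, which can be handy before building the full model repository.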

<div class="alert alert-warning">
<b>Warning</b>: You should run this notebook on a SageMaker Notebook Instance with access to the same GPU as the instance you will deploy your model to (g4dn is the one configured by default in this example). There are model optimization steps contained in this notebook that are GPU architecture-dependent.
    ⬇⬇⬇⬇⬇ change in the next cell if required
</div>

------
------

In [None]:
# change this cell if you are running this notebook in a different instance type
notebook_instance_type = 'ml.g4dn.xlarge'

### Part 1 - Installs and imports

In [None]:
!pip install nvidia-pyindex
!pip install tritonclient[http]
!pip install -U sagemaker ipywidgets numpy pillow

In [None]:
import boto3
import sagemaker
from sagemaker import get_execution_role

import tritonclient.http as httpclient
from tritonclient.utils import np_to_triton_dtype
import time
from PIL import Image
import numpy as np

# variables
s3_client = boto3.client("s3")
ts = time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

# sagemaker variables
role = get_execution_role()
sm_client = boto3.client(service_name="sagemaker")
runtime_sm_client = boto3.client("sagemaker-runtime")
sagemaker_session = sagemaker.Session(boto_session=boto3.Session())
bucket = sagemaker_session.default_bucket()

### Part 2 - Packaging a conda environment

When using the Triton Python backend (which Business Logic Scripts run on), you can include your own environment and dependencies. The recommended way to do this is to use [conda pack](https://conda.github.io/conda-pack/) to generate a conda environment archive in `tar.gz` format, include it in your model repository, and point to it in the `config.pbtxt` file of python models that should use it, adding the snippet: 

```
parameters: {
  key: "EXECUTION_ENV_PATH",
  value: {string_value: "$$TRITON_MODEL_DIRECTORY/your_env.tar.gz"}
}

```
Let's create this file and save it to the pipeline model repo, which is our business logic "model".

In [None]:
!bash conda_dependencies.sh

In [None]:
!cp sd_env.tar.gz model_repository/pipeline
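
Before moving on, it can help to confirm that the archive named in `config.pbtxt` is actually present in the model directory. The sketch below is one possible check, not part of the original example: `find_execution_env` is a hypothetical helper, and it assumes the `EXECUTION_ENV_PATH` entry is written in the same shape as the snippet above. The demo runs against a temporary directory so that it does not depend on the notebook's files.

```python
import re
import tempfile
from pathlib import Path
from typing import Optional


def find_execution_env(model_dir: Path) -> Optional[Path]:
    """Return the archive path named by EXECUTION_ENV_PATH, or None if unset."""
    config_text = (model_dir / "config.pbtxt").read_text()
    match = re.search(
        r'key:\s*"EXECUTION_ENV_PATH".*?string_value:\s*"([^"]+)"',
        config_text,
        flags=re.DOTALL,
    )
    if match is None:
        return None
    # Triton expands $$TRITON_MODEL_DIRECTORY to the model's own directory.
    relative = match.group(1).replace("$$TRITON_MODEL_DIRECTORY/", "")
    return model_dir / relative


# Demo on a throwaway directory that mimics model_repository/pipeline.
with tempfile.TemporaryDirectory() as tmp:
    demo_dir = Path(tmp)
    (demo_dir / "config.pbtxt").write_text(
        'parameters: {\n'
        '  key: "EXECUTION_ENV_PATH",\n'
        '  value: {string_value: "$$TRITON_MODEL_DIRECTORY/sd_env.tar.gz"}\n'
        '}\n'
    )
    (demo_dir / "sd_env.tar.gz").touch()
    env_path = find_execution_env(demo_dir)
    env_found = env_path is not None and env_path.exists()

print(env_found)
```

In this notebook the equivalent call would be `find_execution_env(Path("model_repository/pipeline"))`.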

---
---

### Part 3 - Model artifact creation

One of the components of Stable Diffusion is a Variational Autoencoder (VAE); only the decoder block is used for inference. NVIDIA's example shows how to use [TensorRT](https://developer.nvidia.com/tensorrt) (an ML framework to accelerate models for inference) to compile and optimize this model, which helps decrease the end-to-end latency for each request. To make sure we use a TensorRT version and dependencies that are compatible with the ones in our Triton container, we compile the model using the corresponding version of NVIDIA's PyTorch container image.

The `export.sh` file also saves the text encoder (another one of Stable Diffusion's components) in ONNX format. 

In [None]:
!docker run -it --gpus all -v ${PWD}:/mount nvcr.io/nvidia/pytorch:22.10-py3 /bin/bash /mount/export.sh --verbose | tee conversion.txt

Note that the files must be named `model.plan` and `model.onnx`, the default names that Triton's TensorRT and ONNX Runtime backends look for at startup.

In [None]:
# Place the models in the right model repositories
! mv vae.plan model_repository/vae/1/model.plan
! mv encoder.onnx model_repository/text_encoder/1/model.onnx
! rm vae.onnx
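
A quick way to confirm that every artifact ended up where Triton expects it is to check the repository layout from Python. This is a minimal sketch: `missing_artifacts` is a hypothetical helper, and the expected file names are assumptions based on the three models in this example (`vae`, `text_encoder`, `pipeline`), each with a single version directory `1`. The demo uses a temporary directory with one artifact deliberately left out.

```python
import tempfile
from pathlib import Path
from typing import Dict, List

# Artifact expected in version directory "1" of each model in this example.
EXPECTED_ARTIFACTS: Dict[str, str] = {
    "vae": "model.plan",
    "text_encoder": "model.onnx",
    "pipeline": "model.py",
}


def missing_artifacts(
    repo: Path, expected: Dict[str, str] = EXPECTED_ARTIFACTS
) -> List[str]:
    """List expected artifact paths (relative to repo) that do not exist."""
    missing = []
    for model_name, file_name in expected.items():
        relative = Path(model_name) / "1" / file_name
        if not (repo / relative).is_file():
            missing.append(relative.as_posix())
    return missing


# Demo: build a partial repository and report what is absent.
with tempfile.TemporaryDirectory() as tmp:
    demo_repo = Path(tmp)
    for present in ("vae/1/model.plan", "pipeline/1/model.py"):
        target = demo_repo / present
        target.parent.mkdir(parents=True)
        target.touch()
    demo_missing = missing_artifacts(demo_repo)

print(demo_missing)
```

For the real repository, `missing_artifacts(Path("model_repository"))` should return an empty list once the cell above has run.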

Let's take a look at our logic script. If you're not familiar with the script structure required by Triton's Python backend, check out the documentation [here](https://github.com/triton-inference-server/python_backend#usage).

Notice that some of the required steps run in the Python script itself, while others are offloaded to models deployed on native backends (TensorRT and ONNX) using `triton_python_backend_utils.InferenceRequest()`.
<div class="alert alert-info">
💡 You might notice that some model components are downloaded on initialization. If you prefer to include these in your model deployment artifact, you can save them beforehand under the <code>model_repository/pipeline</code> directory in this example and access their path using the <code>args['model_repository']</code> that Triton passes to the <code>initialize()</code> method. An example of retrieving the path for a saved model artifact: <code>f"{args['model_repository']}/my_saved_model_dir/model.pt"</code>
</div>

In [None]:
!pygmentize model_repository/pipeline/1/model.py

----
----

### Part 4 - Local testing of Triton model repository

Now you can test the model repository and validate it is working (all models load and BLS works). Let's run the Triton docker container locally and invoke the script to check this.

In [None]:
repo_name = "model_repository"

We are running the Triton container in detached mode with the `-d` flag so that it runs in the background.

In [None]:
!docker run --gpus=all -d --shm-size=4G --rm -p8000:8000 -p8001:8001 -p8002:8002 -v$(pwd)/$repo_name:/model_repository nvcr.io/nvidia/tritonserver:22.10-py3 tritonserver --model-repository=/model_repository --exit-on-error=false
time.sleep(90)

In [None]:
CONTAINER_ID=!docker container ls -q
FIRST_CONTAINER_ID = CONTAINER_ID[0]

In [None]:
!echo $FIRST_CONTAINER_ID

In [None]:
!docker logs $FIRST_CONTAINER_ID

<div class="alert alert-warning">
<b>Warning</b>: Rerun the cell above to check the container logs until you have verified that Triton has loaded all models successfully; otherwise, inference requests will fail.
</div>
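
As an alternative to re-reading the logs by hand, readiness can be polled over HTTP. The sketch below assumes the server exposes the standard `/v2/models/<name>/ready` route of the KServe v2 protocol that Triton implements, on the `localhost:8000` port mapped above. The helper names, timeout, and polling interval are illustrative choices, and the demo substitutes a fake readiness check so that it runs without a server.

```python
import time
import urllib.error
import urllib.request
from typing import Callable, Dict, Iterable


def ready_url(base_url: str, model_name: str) -> str:
    """Build the per-model readiness URL."""
    return f"{base_url}/v2/models/{model_name}/ready"


def http_ready(url: str) -> bool:
    """Return True if the URL answers with HTTP 200, False otherwise."""
    try:
        with urllib.request.urlopen(url, timeout=5) as response:
            return response.status == 200
    except (urllib.error.URLError, OSError):
        return False


def wait_for_models(
    models: Iterable[str],
    is_ready: Callable[[str], bool] = http_ready,
    base_url: str = "http://localhost:8000",
    timeout_s: float = 300.0,
    interval_s: float = 5.0,
    sleep: Callable[[float], None] = time.sleep,
) -> bool:
    """Poll until every model reports ready, or give up after timeout_s."""
    deadline = time.monotonic() + timeout_s
    pending = list(models)
    while pending:
        pending = [m for m in pending if not is_ready(ready_url(base_url, m))]
        if not pending:
            return True
        if time.monotonic() >= deadline:
            return False
        sleep(interval_s)
    return True


# Demo with a fake check that reports ready on the second attempt.
attempts: Dict[str, int] = {}


def fake_ready(url: str) -> bool:
    attempts[url] = attempts.get(url, 0) + 1
    return attempts[url] >= 2


all_ready = wait_for_models(
    ["text_encoder", "vae", "pipeline"], is_ready=fake_ready, sleep=lambda s: None
)
print(all_ready)
```

Against the local container, `wait_for_models(["text_encoder", "vae", "pipeline"])` would use the real HTTP check.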

#### Now we will invoke the script locally

We will use Triton's HTTP client and its utility functions to send a request to `localhost:8000`, where the server is listening. We are sending text as binary data for input and receiving an array that we decode with numpy as output. Check out the code in `model_repository/pipeline/1/model.py` to understand how the input data is decoded and the output data returned, and check out more Triton Python backend [docs](https://github.com/triton-inference-server/python_backend) and [examples](https://github.com/triton-inference-server/python_backend/tree/main/examples) to understand how to handle other data types.
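
To make "text as binary data" more concrete: to the best of our understanding of Triton's binary tensor format, each element of a string (BYTES) tensor travels as a 4-byte little-endian length followed by the raw bytes. The client library handles this for you; the sketch below only illustrates that layout with the standard library, and both function names are hypothetical.

```python
import struct
from typing import List


def serialize_bytes_elements(strings: List[str]) -> bytes:
    """Pack each string as <4-byte little-endian length><UTF-8 bytes>."""
    chunks = []
    for text in strings:
        raw = text.encode("utf-8")
        chunks.append(struct.pack("<I", len(raw)) + raw)
    return b"".join(chunks)


def deserialize_bytes_elements(payload: bytes) -> List[str]:
    """Reverse of serialize_bytes_elements."""
    items, offset = [], 0
    while offset < len(payload):
        (length,) = struct.unpack_from("<I", payload, offset)
        offset += 4
        items.append(payload[offset : offset + length].decode("utf-8"))
        offset += length
    return items


wire = serialize_bytes_elements(["Smiling person"])
print(wire[:4], deserialize_bytes_elements(wire))
```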

In [None]:
client = httpclient.InferenceServerClient(url="localhost:8000")

prompt = "Pikachu in a detective trench coat, photorealistic, nikon"
text_obj = np.array([prompt], dtype="object").reshape((-1, 1))

input_text = httpclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))

input_text.set_data_from_numpy(text_obj)

output_img = httpclient.InferRequestedOutput("generated_image")

start = time.time()
query_response = client.infer(model_name="pipeline", inputs=[input_text], outputs=[output_img])
print(f"took {time.time()-start} seconds")

image = query_response.as_numpy("generated_image")
im = Image.fromarray(np.squeeze(image))
im.save("generated_image.jpg")

Let's stop the container that is running locally so we don't take up notebook resources.

In [None]:
!docker kill $FIRST_CONTAINER_ID

----
----
### Part 5 - Deploy to SageMaker Real-Time Endpoint

We first package our Triton model repository in a `tar.gz` file.

In [None]:
!rm -rf `find -type d -name .ipynb_checkpoints`

In [None]:
model_file_name = "stable-diff-bls.tar.gz"
prefix = "stable-diffusion-bls"
!tar -C model_repository/ -czf $model_file_name .
model_data_url = sagemaker_session.upload_data(path=model_file_name, key_prefix=prefix)

Check the contents of the `tar.gz` file and make sure that all model folders sit at the root of the archive.

In [None]:
!tar -tf $model_file_name
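
The same check can be automated with the standard `tarfile` module. This is a minimal sketch with a hypothetical helper name; it builds a small in-memory archive for the demo, using the `./`-prefixed member names that `tar -C model_repository/ ... .` typically produces.

```python
import io
import tarfile
from typing import Set


def top_level_entries(archive: tarfile.TarFile) -> Set[str]:
    """Return the first path component of every member, ignoring './'."""
    names = set()
    for member in archive.getmembers():
        parts = [p for p in member.name.split("/") if p not in ("", ".")]
        if parts:
            names.add(parts[0])
    return names


# Demo: an in-memory archive shaped like the packaged model repository.
buffer = io.BytesIO()
with tarfile.open(fileobj=buffer, mode="w:gz") as demo:
    for name in (
        "./vae/1/model.plan",
        "./text_encoder/1/model.onnx",
        "./pipeline/1/model.py",
    ):
        info = tarfile.TarInfo(name=name)
        info.size = 0
        demo.addfile(info, io.BytesIO(b""))

buffer.seek(0)
with tarfile.open(fileobj=buffer, mode="r:gz") as demo:
    demo_roots = top_level_entries(demo)

print(sorted(demo_roots))
```

For the real archive, open it with `tarfile.open(model_file_name)` and compare the result with the expected model names. A single unexpected root such as `model_repository` would indicate that the folders were wrapped in an extra directory.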

Get the correct URI for the Triton SageMaker container image. Check out all the available Deep Learning Container images that AWS maintains [here](https://github.com/aws/deep-learning-containers/blob/master/available_images.md). 

In [None]:
# account mapping for SageMaker Triton Image
account_id_map = {
    "us-east-1": "785573368785",
    "us-east-2": "007439368137",
    "us-west-1": "710691900526",
    "us-west-2": "301217895009",
    "eu-west-1": "802834080501",
    "eu-west-2": "205493899709",
    "eu-west-3": "254080097072",
    "eu-north-1": "601324751636",
    "eu-south-1": "966458181534",
    "eu-central-1": "746233611703",
    "ap-east-1": "110948597952",
    "ap-south-1": "763008648453",
    "ap-northeast-1": "941853720454",
    "ap-northeast-2": "151534178276",
    "ap-southeast-1": "324986816169",
    "ap-southeast-2": "355873309152",
    "cn-northwest-1": "474822919863",
    "cn-north-1": "472730292857",
    "sa-east-1": "756306329178",
    "ca-central-1": "464438896020",
    "me-south-1": "836785723513",
    "af-south-1": "774647643957",
}

region = boto3.Session().region_name
if region not in account_id_map:
    # Raising a plain string is a TypeError in Python 3; raise a real exception.
    raise ValueError(f"Unsupported region: {region}")

base = "amazonaws.com.cn" if region.startswith("cn-") else "amazonaws.com"
mme_triton_image_uri = (
    "{account_id}.dkr.ecr.{region}.{base}/sagemaker-tritonserver:22.10-py3".format(
        account_id=account_id_map[region], region=region, base=base
    )
)

Create a SageMaker Model definition.
<div class="alert alert-info">
💡 The next two cells are very important. To make sure that the two model components (text encoder and VAE) called by the BLS script are loaded at endpoint startup, before the pipeline is ever called, we use the "SAGEMAKER_TRITON_LOG_INFO" environment variable. This variable is meant to be a boolean that controls whether Triton emits verbose logs, but because its value is appended to the end of the Triton Server launch command that runs at startup, we can append <code>--load-model=model-name</code> flags after the boolean to preload both models.
</div>

In [None]:
preload_model_argument = "false --load-model=text_encoder --load-model=vae"
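
If the list of models to preload changes, the same string can be assembled programmatically. The helper below is a hypothetical convenience, not part of the original example; it simply reproduces the format used in the cell above.

```python
from typing import Iterable


def build_log_info_value(models: Iterable[str], verbose: bool = False) -> str:
    """Return '<bool> --load-model=a --load-model=b ...' for the given models."""
    flags = " ".join(f"--load-model={name}" for name in models)
    value = "true" if verbose else "false"
    return f"{value} {flags}".strip()


preload_value = build_log_info_value(["text_encoder", "vae"])
print(preload_value)
```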

In [None]:
ts = time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

container = {
    "Image": mme_triton_image_uri,
    "ModelDataUrl": model_data_url,
    "Environment": {
        "SAGEMAKER_TRITON_DEFAULT_MODEL_NAME": "pipeline",
        "SAGEMAKER_TRITON_LOG_INFO": preload_model_argument,
    },
}

In [None]:
sm_model_name = f"{prefix}-mdl-{ts}"

create_model_response = sm_client.create_model(
    ModelName=sm_model_name, ExecutionRoleArn=role, PrimaryContainer=container
)

print("Model Arn: " + create_model_response["ModelArn"])

Create a SageMaker endpoint configuration.

In [None]:
endpoint_config_name = f"{prefix}-epc-{ts}"

create_endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        {
            "InstanceType": notebook_instance_type,
            "InitialVariantWeight": 1,
            "InitialInstanceCount": 1,
            "ModelName": sm_model_name,
            "VariantName": "AllTraffic",
        }
    ],
)

print("Endpoint Config Arn: " + create_endpoint_config_response["EndpointConfigArn"])

Create the endpoint, and wait for it to be up and running.

In [None]:
endpoint_name = f"{prefix}-ep-{ts}"

create_endpoint_response = sm_client.create_endpoint(
    EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name
)

print("Endpoint Arn: " + create_endpoint_response["EndpointArn"])

In [None]:
resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
status = resp["EndpointStatus"]
print("Status: " + status)

while status == "Creating":
    time.sleep(60)
    resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
    status = resp["EndpointStatus"]
    print("Status: " + status)

print("Arn: " + resp["EndpointArn"])
print("Status: " + status)
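
The loop above exits as soon as the status is no longer `Creating`, whether the endpoint came up or failed. A slightly more defensive variant is sketched below: it stops on `InService`, raises on `Failed` with the `FailureReason` reported by `describe_endpoint`, and gives up after a bounded number of polls. The function name and default limits are assumptions, and the describe call is injected so that the demo can run with fake responses.

```python
import time
from typing import Any, Callable, Dict


def wait_for_endpoint(
    describe: Callable[[], Dict[str, Any]],
    poll_seconds: float = 60.0,
    max_polls: int = 60,
    sleep: Callable[[float], None] = time.sleep,
) -> Dict[str, Any]:
    """Poll describe() until the endpoint is InService; raise if it fails."""
    for _ in range(max_polls):
        resp = describe()
        status = resp["EndpointStatus"]
        if status == "InService":
            return resp
        if status == "Failed":
            reason = resp.get("FailureReason", "unknown reason")
            raise RuntimeError(f"Endpoint creation failed: {reason}")
        sleep(poll_seconds)
    raise TimeoutError(f"Endpoint not InService after {max_polls} polls")


# Demo with canned responses instead of a real SageMaker client.
fake_responses = iter(
    [
        {"EndpointStatus": "Creating"},
        {"EndpointStatus": "InService", "EndpointArn": "arn:example"},
    ]
)
final = wait_for_endpoint(lambda: next(fake_responses), sleep=lambda s: None)
print(final["EndpointStatus"])
```

In this notebook the call would look like `wait_for_endpoint(lambda: sm_client.describe_endpoint(EndpointName=endpoint_name))`.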

#### Invoke model 

In [None]:
prompt = "Smiling person"
inputs = []
outputs = []

text_obj = np.array([prompt], dtype="object").reshape((-1, 1))

inputs.append(httpclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)))
inputs[0].set_data_from_numpy(text_obj)


outputs.append(httpclient.InferRequestedOutput("generated_image"))

Since we are using the SageMaker Runtime client to send an HTTP request to the endpoint now, we use Triton's `generate_request_body` method to create the right [request format](https://github.com/triton-inference-server/server/tree/main/docs/protocol) for us.

In [None]:
request_body, header_length = httpclient.InferenceServerClient.generate_request_body(
    inputs, outputs=outputs
)

print(request_body)

We are sending our request in binary format for lower inference latency.

With the binary+json format, we have to specify the length of the request metadata (the JSON header) so that Triton can correctly parse the binary payload. On a standalone Triton server this is done with an `Inference-Header-Content-Length` header; because SageMaker does not allow custom headers, the length is passed in a custom Content-Type instead.
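
The layout can be illustrated with the standard library alone: the first `json-header-size` bytes of the body are JSON metadata, and everything after them is raw tensor data. The helper names below are hypothetical, and the demo body is hand-built rather than produced by `generate_request_body`.

```python
import json
from typing import Any, Dict, Tuple

CONTENT_TYPE_PREFIX = (
    "application/vnd.sagemaker-triton.binary+json;json-header-size="
)


def build_content_type(header_length: int) -> str:
    """Content-Type that carries the JSON header length."""
    return f"{CONTENT_TYPE_PREFIX}{header_length}"


def parse_header_length(content_type: str) -> int:
    """Extract the JSON header length from a binary+json Content-Type."""
    if not content_type.startswith(CONTENT_TYPE_PREFIX):
        raise ValueError(f"Unexpected content type: {content_type}")
    return int(content_type[len(CONTENT_TYPE_PREFIX) :])


def split_body(body: bytes, header_length: int) -> Tuple[Dict[str, Any], bytes]:
    """Split a body into its JSON header and the trailing binary payload."""
    header = json.loads(body[:header_length].decode("utf-8"))
    return header, body[header_length:]


# Demo: a hand-built body with a small JSON header and 6 bytes of payload.
demo_header = json.dumps({"inputs": [{"name": "prompt"}]}).encode("utf-8")
demo_body = demo_header + b"\x02\x00\x00\x00hi"
demo_content_type = build_content_type(len(demo_header))

parsed_header, binary_part = split_body(
    demo_body, parse_header_length(demo_content_type)
)
print(parsed_header, binary_part)
```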

In [None]:
response = runtime_sm_client.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType="application/vnd.sagemaker-triton.binary+json;json-header-size={}".format(
        header_length
    ),
    Body=request_body,
)

In [None]:
header_length_prefix = "application/vnd.sagemaker-triton.binary+json;json-header-size="
header_length_str = response["ContentType"][len(header_length_prefix) :]

# Read response body
result = httpclient.InferenceServerClient.parse_response_body(
    response["Body"].read(), header_length=int(header_length_str)
)

image_array = result.as_numpy("generated_image")
image = Image.fromarray(np.squeeze(image_array))

In [None]:
image

----
----
### Part 6 - Clean up

In [None]:
sm_client.delete_endpoint(EndpointName=endpoint_name)
sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
sm_client.delete_model(ModelName=sm_model_name)

