# Optimized Stable Diffusion Deployments

---

This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook.

![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-west-2/inference|generativeai|llm-workshop|lab2-stable-diffusion|option2-aitemplate|sd_txt2img.ipynb)

---


This notebook shows how to deploy a version of Stable Diffusion optimized with [AITemplate](https://github.com/facebookincubator/AITemplate/tree/main/examples/05_stable_diffusion), which delivers a 2x performance gain over a standard version without sacrificing the quality of the generated images.

Additionally, this notebook demonstrates how to deploy an endpoint with pagination capabilities, which allows the API caller to display intermediate de-noising steps and reduces the initial latency to the sub-second range. This enhances the end-user experience by providing more immediate results and showing a smooth animation of the end-to-end image generation process. However, this comes at the additional compute cost of decoding intermediate latent outputs.

In [None]:
%pip install -Uq sagemaker

In [None]:
import sagemaker
from sagemaker.model import Model
from sagemaker import serializers, deserializers
from sagemaker import image_uris
import boto3
import os
import time
import json
from pathlib import Path

from io import BytesIO
from PIL import Image
import base64
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy

import matplotlib.pyplot as plt
from IPython import display
from IPython.display import clear_output
from IPython.core.display import HTML

%matplotlib inline

In [None]:
role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
bucket = sess.default_bucket()  # bucket to house artifacts
model_bucket = sess.default_bucket()  # bucket to house artifacts
s3_code_prefix = "stable-diffusion-2/code"  # folder within bucket where code artifact will go
s3_model_prefix = "stable-diffusion-2/model"  # folder where model checkpoint will go

region = sess.boto_region_name  # public attribute, used instead of the private _region_name
account_id = sess.account_id()

In [None]:
inference_image_uri = (
    f"763104351884.dkr.ecr.{region}.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.3-cu117"
)
print(f"Image going to be used is ---- > {inference_image_uri}")

We have provided compiled AITemplate weights for the `ml.g5` class of instances. You can compile these on your own using the instructions [here](https://github.com/facebookincubator/AITemplate/tree/main/examples/05_stable_diffusion).

In [None]:
ait_compiled_weight_uri = (
    f"s3://sagemaker-example-files-prod-{region}/models/aitemplate_compiled/g5hw/"
)
print(f"Compiled weights to be used is  ---- > {ait_compiled_weight_uri}")

In [None]:
def deploy_model(image_uri, model_data, role, endpoint_name, instance_type, env, sagemaker_session):
    """Helper function to create the SageMaker Endpoint resources and return a predictor"""

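    # pass the session so the model is created with the same session the predictor uses
    model = Model(
        image_uri=image_uri,
        model_data=model_data,
        role=role,
        env=env,
        sagemaker_session=sagemaker_session,
    )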

    model.deploy(initial_instance_count=1, instance_type=instance_type, endpoint_name=endpoint_name)

    predictor = sagemaker.Predictor(
        endpoint_name=endpoint_name,
        sagemaker_session=sagemaker_session,
        serializer=serializers.JSONSerializer(),
        deserializer=deserializers.JSONDeserializer(),
    )

    return predictor

## Deploy Model
In this section we will package the model configuration and inference code and deploy it to a SageMaker Endpoint. The following are the steps to deploy the endpoint:
1. Update the `serving.properties` configuration file with the location of the compiled model artifacts. More information on the supported configurations can be found [here](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints-large-model-configuration.html).
2. Package the inference code along with the configuration file into a `model.tar.gz`
3. Upload the `model.tar.gz` to an S3 bucket
4. Deploy the model using the `deploy_model` helper function
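
Step 1 can also be done in plain Python rather than with a shell command. The sketch below is a minimal illustration, assuming `serving.properties` consists of simple `key=value` lines. `set_property` is a hypothetical helper name that belongs neither to this notebook nor to the serving container, and both the file contents and the S3 URI are placeholders.

```python
import re


def set_property(text: str, key: str, value: str) -> str:
    """Replace the value of `key` in key=value text, or append the pair if it is missing."""
    pattern = re.compile(rf"^{re.escape(key)}=.*$", flags=re.MULTILINE)
    line = f"{key}={value}"
    if pattern.search(text):
        # a function replacement keeps any backslashes in `value` literal
        return pattern.sub(lambda _match: line, text)
    separator = "" if not text or text.endswith("\n") else "\n"
    return f"{text}{separator}{line}\n"


# placeholder contents, not the notebook's actual configuration file
example = "engine=Python\noption.s3url=PLACEHOLDER\n"
updated = set_property(example, "option.s3url", "s3://example-bucket/compiled-weights/")
print(updated)
```

Reading the file, passing its text through such a helper, and writing it back would have the same effect as the `sed` command used in the next cell.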

In [None]:
!sed -i 's@option.s3url=.*@option.s3url={ait_compiled_weight_uri}@g' model/serving.properties

In [None]:
!pygmentize model/serving.properties | cat -n

The inference code is contained within the [model.py](model/model.py) file in the `model` source directory. We use an environment variable, `PAGINATION`, to indicate whether to use the standard pipeline, which only returns the final image, or a pagination-based pipeline, which returns intermediate results of each de-noising step. The code for each pipeline is contained within its own Python module:
- [pipeline_stable_diffusion_ait.py](model/pipeline_stable_diffusion_ait.py) - Code for the standard pipeline
- [pipeline_stable_diffusion_pagination_ait.py](model/pipeline_stable_diffusion_pagination_ait.py) - Code for the paginated pipeline
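
The contents of `model.py` are not reproduced in this notebook, so the following is only a sketch of how an environment-variable switch such as `PAGINATION` is commonly read in Python. The helper name `pagination_enabled` and the set of accepted truthy strings are assumptions made for illustration, not the notebook's actual implementation.

```python
import os


def pagination_enabled(environ=None) -> bool:
    """Return True when the PAGINATION variable holds a truthy string."""
    environ = os.environ if environ is None else environ
    # environment variables are always strings, so the value is compared explicitly
    return environ.get("PAGINATION", "false").strip().lower() in {"true", "1", "yes"}


print(pagination_enabled({"PAGINATION": "true"}))
print(pagination_enabled({"PAGINATION": "false"}))
```

Comparing the string explicitly matters because `bool("false")` evaluates to `True` in Python, so a naive truthiness check would enable the paginated pipeline by accident.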

The pipelines require AITemplate to be installed in the inference container. As of 4/2023 AITemplate is not available from PyPI and must be built from source following the instructions in the [git repo](https://github.com/facebookincubator/AITemplate). For convenience, we've included a pre-compiled Python wheel, `model/ait/aitemplate-0.3.dev0-py3-none-any.whl`, that is installed when the endpoint is launched.

In [None]:
!tar czvf sd_model.tar.gz model/

In [None]:
sd_s3_code_artifact = sess.upload_data("sd_model.tar.gz", bucket, s3_code_prefix)
print(f"S3 Code or Model tar ball uploaded to --- > {sd_s3_code_artifact}")

The inference code supports both paginated and non-paginated responses, controlled by the `PAGINATION` environment variable.

Here we deploy the endpoint without pagination by setting the environment variable to `false`.

In [None]:
sd_endpoint_name = sagemaker.utils.name_from_base("stable-diffusion")
sd_predictor = deploy_model(
    image_uri=inference_image_uri,
    model_data=sd_s3_code_artifact,
    role=role,
    env={"PAGINATION": "false"},
    endpoint_name=sd_endpoint_name,
    instance_type="ml.g5.xlarge",
    sagemaker_session=sess,
)

In [None]:
def invoke_endpoint(predictor, payload):
    "helper function to invoke endpoint"
    result = predictor.predict(payload)
    return result

In [None]:
def decode_image(img):
    "decodes the base64 encoded image that is returned by the endpoint"
    buff = BytesIO(base64.b64decode(img.encode("utf8")))
    image = Image.open(buff)
    return image

In [None]:
prompt = """60s cartoon style photo of a Panda bear wearing underground clothes in far off galaxy warner brothers, trending pixiv fanbox, acrylic palette knife, 8k, vibrant colors, devinart, trending on artstation, low details, smooth 
"""
negative_prompt = "ugly, tiling, blurred, watermark, grainy, signature, cut off, draft, amateur, multiple,  text, poor, low, basic, worst, unprofessional"
payload = {
    "parameters": {
        "num_inference_steps": 50,
        "guidance_scale": 9,
        "negative_prompt": negative_prompt,
        "num_images_per_prompt": 1,
        # the next 2 parameters will only be utilized by the pagination enabled endpoint
        "starting_step": 0,
        "num_interim_images": 5,
    },
    "prompt": prompt,
}

In [None]:
t1 = time.perf_counter()
response = invoke_endpoint(sd_predictor, payload)
response_time = time.perf_counter() - t1
decode_image(response["images"][0])  # display the first generated image

In [None]:
print(f"Response returned in {response_time:.2f}s")

In [None]:
sd_predictor.delete_endpoint()

## Enable Pagination
To enable pagination of intermediate results, we set the `PAGINATION` environment variable to `true` and redeploy the endpoint. Rather than a single image, the paginated endpoint returns three values in its response:
1. Batch of intermediate images encoded as base64 encoded JPEGs
2. A [safetensor](https://github.com/huggingface/safetensors) value for the last latent in the generation pipeline encoded as base64
3. The last step number in the generation pipeline

Items 2 and 3 enable the pagination. By providing a latent tensor and the step number, we can bypass the completed steps and pick up the image generation from the last completed step. After receiving the initial batch of intermediate images, we invoke the endpoint again, this time providing the latent input and the step number. This process repeats until the specified number of de-noising steps is completed. It allows the next batch of images to be fetched concurrently while intermediate frames are displayed to the user.
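
The hand-off between two consecutive requests can be sketched as follows. The field names `step` and `latents` in the response, and `starting_step` and `latents` in the request parameters, mirror the ones used by the code later in this notebook. `next_page_payload` is a hypothetical helper name, and the response dictionary is dummy data rather than real endpoint output.

```python
from copy import deepcopy


def next_page_payload(payload: dict, result: dict) -> dict:
    """Build the request for the next batch from the previous response."""
    nxt = deepcopy(payload)  # leave the caller's payload untouched
    nxt["parameters"]["starting_step"] = result["step"]  # resume after the last completed step
    nxt["parameters"]["latents"] = result["latents"]  # encoded latent returned by the endpoint
    return nxt


first_request = {"prompt": "a panda", "parameters": {"num_inference_steps": 50, "starting_step": 0}}
# dummy response with placeholder strings instead of real base64 data
dummy_response = {"images": ["<base64 jpeg>"], "latents": "<base64 safetensor>", "step": 5}

second_request = next_page_payload(first_request, dummy_response)
print(second_request["parameters"]["starting_step"])
```

Copying the payload before modifying it keeps the original request reusable, for example to rerun the same prompt from step 0.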

In [None]:
sd_endpoint_name = sagemaker.utils.name_from_base("stable-diffusion")
sd_predictor = deploy_model(
    image_uri=inference_image_uri,
    model_data=sd_s3_code_artifact,
    role=role,
    env={"PAGINATION": "true"},
    endpoint_name=sd_endpoint_name,
    instance_type="ml.g5.xlarge",
    sagemaker_session=sess,
)

The function below encapsulates the process for querying the pagination endpoint. It provides a Python iterator that we can loop over to display the intermediate images. A background thread is used to fetch subsequent batches of images to simulate the concurrency aspect.

In [None]:
def run_paginated_inference(predictor, initial_payload):
    "creates an iterator for intermediate images"

    payload = deepcopy(initial_payload)  # make a deep copy to not mutate the initial payload
    num_inference_steps = payload["parameters"]["num_inference_steps"]
    payload["parameters"]["starting_step"] = 0

    # use a single background thread to fetch the next set of images while the current batch
    # is displayed; the executor wraps the whole loop because leaving its `with` block waits
    # for pending requests, which would otherwise block before any image is yielded
    with ThreadPoolExecutor(max_workers=1) as executor:
        future = executor.submit(invoke_endpoint, predictor, payload)
        while future is not None:
            result = future.result()  # blocks until the pending batch has arrived
            images = result["images"]
            steps_completed = result["step"]

            if steps_completed < num_inference_steps:
                # resume from the last completed step using the returned latent
                payload["parameters"]["starting_step"] = steps_completed
                payload["parameters"]["latents"] = result["latents"]
                future = executor.submit(invoke_endpoint, predictor, payload)
            else:
                future = None  # all de-noising steps are done, nothing left to fetch

            for img in images:
                yield img

In [None]:
images_it = run_paginated_inference(sd_predictor, payload)  # create the image iterator
t1 = time.perf_counter()
for n, img in enumerate(images_it):
    if n == 0:
        initial_response_time = time.perf_counter() - t1

    # the endpoint returns base64 encoded JPEGs, so embed them with the matching MIME type
    html = f"""<div>
      <img src="data:image/jpeg;base64,{img}" />
    </div>"""
    display.display(HTML(html))
    clear_output(wait=True)  # comment this line to see individual outputs

In [None]:
print(f"first batch delivered in {initial_response_time:.2f} seconds")

We can see from above that the first batch was delivered in under one second. This provides a more immediate response to the user at the expense of the additional compute cost of decoding intermediate images. It also doubles the time to generate the final image.
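
The trade-off between time to first image and total generation time can be measured for any iterator of images. The sketch below is an illustration only: `measure_stream` is a hypothetical helper, and `fake_stream` stands in for the paginated endpoint with made-up delays, so the numbers it prints say nothing about real endpoint latency.

```python
import time


def measure_stream(iterator):
    """Return (seconds until the first item, total seconds, item count) for an iterator."""
    start = time.perf_counter()
    first_item_latency = None
    count = 0
    for _ in iterator:
        if first_item_latency is None:
            first_item_latency = time.perf_counter() - start
        count += 1
    total = time.perf_counter() - start
    return first_item_latency, total, count


def fake_stream(num_batches=3, images_per_batch=2, delay=0.01):
    """Stand-in for a paginated endpoint: each batch arrives after `delay` seconds."""
    for batch in range(num_batches):
        time.sleep(delay)
        for i in range(images_per_batch):
            yield f"image-{batch}-{i}"


first, total, count = measure_stream(fake_stream())
print(f"first image after {first:.3f}s, all {count} images after {total:.3f}s")
```

Passing `run_paginated_inference(sd_predictor, payload)` to the same helper would report both figures for the deployed endpoint while it is still running.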

In [None]:
sd_predictor.delete_endpoint()

## Conclusion
In this notebook we saw how to deploy an AITemplate-optimized Stable Diffusion model, which offers a 2x performance increase without sacrificing the quality of the generated image. We also saw how to provide a better user experience by returning intermediate results, which gives a faster initial response time and a look into the image generation process.

## Notebook CI Test Results

This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.


![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-east-1/inference|generativeai|llm-workshop|lab2-stable-diffusion|option2-aitemplate|sd_txt2img.ipynb)

![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-east-2/inference|generativeai|llm-workshop|lab2-stable-diffusion|option2-aitemplate|sd_txt2img.ipynb)

![This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-west-1/inference|generativeai|llm-workshop|lab2-stable-diffusion|option2-aitemplate|sd_txt2img.ipynb)

![This ca-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ca-central-1/inference|generativeai|llm-workshop|lab2-stable-diffusion|option2-aitemplate|sd_txt2img.ipynb)

![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/sa-east-1/inference|generativeai|llm-workshop|lab2-stable-diffusion|option2-aitemplate|sd_txt2img.ipynb)

![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-1/inference|generativeai|llm-workshop|lab2-stable-diffusion|option2-aitemplate|sd_txt2img.ipynb)

![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-2/inference|generativeai|llm-workshop|lab2-stable-diffusion|option2-aitemplate|sd_txt2img.ipynb)

![This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-3/inference|generativeai|llm-workshop|lab2-stable-diffusion|option2-aitemplate|sd_txt2img.ipynb)

![This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-central-1/inference|generativeai|llm-workshop|lab2-stable-diffusion|option2-aitemplate|sd_txt2img.ipynb)

![This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-north-1/inference|generativeai|llm-workshop|lab2-stable-diffusion|option2-aitemplate|sd_txt2img.ipynb)

![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-southeast-1/inference|generativeai|llm-workshop|lab2-stable-diffusion|option2-aitemplate|sd_txt2img.ipynb)

![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-southeast-2/inference|generativeai|llm-workshop|lab2-stable-diffusion|option2-aitemplate|sd_txt2img.ipynb)

![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-northeast-1/inference|generativeai|llm-workshop|lab2-stable-diffusion|option2-aitemplate|sd_txt2img.ipynb)

![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-northeast-2/inference|generativeai|llm-workshop|lab2-stable-diffusion|option2-aitemplate|sd_txt2img.ipynb)

![This ap-south-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-south-1/inference|generativeai|llm-workshop|lab2-stable-diffusion|option2-aitemplate|sd_txt2img.ipynb)