{
"cells": [
{
"cell_type": "markdown",
"id": "dab554c6",
"metadata": {},
"source": [
"# Deploy Stable Diffusion on a SageMaker GPU Multi-Model Endpoint with LMI Containers and AiTemplate"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"\n",
"This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook. \n",
"\n",
"\n",
"\n",
"---"
]
},
{
"cell_type": "markdown",
"id": "5640df9b",
"metadata": {},
"source": [
"**In this notebook we will host Stable Diffusion SageMaker using LMI containers**\n",
"\n",
"In this notebook, we explore how to host a large language model on SageMaker using the [Large Model Inference](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints-large-model-inference.html) container that is optimized for hosting large models using DJLServing. DJLServing is a high-performance universal model serving solution powered by the Deep Java Library (DJL) that is programming language agnostic. To learn more about DJL and DJLServing, you can refer to our recent [blog post](https://aws.amazon.com/blogs/machine-learning/deploy-large-models-on-amazon-sagemaker-using-djlserving-and-deepspeed-model-parallel-inference/).\n",
"\n",
"Language models have recently exploded in both size and popularity. In 2018, BERT-large entered the scene and, with its 340M parameters and novel transformer architecture, set the standard on NLP task accuracy. Within just a few years, state-of-the-art NLP model size has grown by more than 500x with models such as OpenAI\u2019s 175 billion parameter GPT-3 and similarly sized open source Bloom 176B raising the bar on NLP accuracy. This increase in the number of parameters is driven by the simple and empirically-demonstrated positive relationship between model size and accuracy: more is better. With easy access from models zoos such as Hugging Face and improved accuracy in NLP tasks such as classification and text generation, practitioners are increasingly reaching for these large models. However, deploying them can be a challenge because of their size.\n",
"\n",
"\n",
"This notebook was tested on a `ml.g5.2xlarge` instance \n",
"\n",
"\n",
"Model license: By using this model, please review and agree to the https://huggingface.co/stabilityai/stable-diffusion-2/blob/main/LICENSE-MODEL"
]
},
{
"cell_type": "markdown",
"id": "fdbf35ff",
"metadata": {},
"source": [
"## Create a SageMaker Model for Deployment\n",
"As a first step, we'll import the relevant libraries and configure several global variables such as the hosting image that will be used nd the S3 location of our model artifacts"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "69df6cd4",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"%pip install -Uq sagemaker"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "39224b86-002c-4713-82b6-9df45ba282b0",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import sagemaker\n",
"from sagemaker.model import Model\n",
"from sagemaker import serializers, deserializers\n",
"from sagemaker import image_uris\n",
"import boto3\n",
"import os\n",
"import time\n",
"import json\n",
"import jinja2\n",
"from pathlib import Path"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "44c48876",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import boto3\n",
"import sagemaker\n",
"from sagemaker import get_execution_role\n",
"\n",
"\n",
"import time\n",
"from PIL import Image\n",
"import numpy as np\n",
"\n",
"# variables\n",
"s3_client = boto3.client(\"s3\")\n",
"ts = time.strftime(\"%Y-%m-%d-%H-%M-%S\", time.gmtime())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b62b7e6b-f21f-4941-b779-be1c06e7413f",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"role = sagemaker.get_execution_role() # execution role for the endpoint\n",
"sess = (\n",
" sagemaker.session.Session()\n",
") # sagemaker session for interacting with different AWS APIs\n",
"bucket = sess.default_bucket() # bucket to house artifacts\n",
"model_bucket = sess.default_bucket() # bucket to house artifacts\n",
"\n",
"region = sess._region_name\n",
"account_id = sess.account_id()\n",
"\n",
"s3_client = boto3.client(\"s3\")\n",
"sm_client = boto3.client(\"sagemaker\")\n",
"runtime_sm_client = boto3.client(\"sagemaker-runtime\")\n",
"\n",
"s3_code_prefix = \"stablediffusion/aitemplate/code_sd_g5\" # folder within bucket where code artifact will go\n",
"\n",
"s3_model_prefix = (\n",
" f\"s3://sagemaker-examples-files-prod-{region}/models/aitemplate_compiled/g5hw/\"\n",
")\n",
"\n",
"jinja_env = (\n",
" jinja2.Environment()\n",
") # jinja environment to generate model configuration templates\n",
"\n",
"print(s3_model_prefix)"
]
},
{
"cell_type": "markdown",
"id": "95d58e0e",
"metadata": {},
"source": [
"### Part 2 - Create the model.py file \n",
"\n",
"This file is the custom inference script for generating images. The model weights have been compiled for specific Hardware based on the below link\n",
"\n",
"https://github.com/facebookincubator/AITemplate/blob/main/examples/05_stable_diffusion/scripts/demo.py"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f0de01ba-0b5c-4c64-a97e-959ed31ab291",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"!mkdir -p code_sd\n",
"!mkdir -p code_sd/src"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "886773c8-4c4b-4381-91c8-9b6871ffe06c",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"%%writefile code_sd/src/__init__.py"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e8f25354-83d8-423c-a2d2-97c59309115f",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# we plug in the appropriate model location into our `serving.properties` file based on the region in which this notebook is running\n",
"template = jinja_env.from_string(Path(\"jinja_templates/serving.template\").open().read())\n",
"Path(\"code_sd/serving.properties\").open(\"w\").write(\n",
" template.render(s3url=s3_model_prefix)\n",
")\n",
"!pygmentize code_sd/serving.properties | cat -n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1ad8a70a-0b7c-4cfc-937f-4e815c5583ba",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"%%writefile code_sd/src/pipeline_stable_diffusion_ait.py\n",
"# Copyright (c) Meta Platforms, Inc. and affiliates.\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# http://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License.\n",
"#\n",
"import inspect\n",
"\n",
"import os\n",
"import warnings\n",
"from typing import List, Optional, Union\n",
"\n",
"import torch\n",
"from aitemplate.compiler import Model\n",
"\n",
"from diffusers import (\n",
" AutoencoderKL,\n",
" DDIMScheduler,\n",
" DPMSolverMultistepScheduler,\n",
" EulerAncestralDiscreteScheduler,\n",
" EulerDiscreteScheduler,\n",
" LMSDiscreteScheduler,\n",
" PNDMScheduler,\n",
" StableDiffusionPipeline,\n",
" UNet2DConditionModel,\n",
")\n",
"\n",
"from diffusers.pipelines.stable_diffusion import (\n",
" StableDiffusionPipelineOutput,\n",
" StableDiffusionSafetyChecker,\n",
")\n",
"\n",
"from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer\n",
"import logging\n",
"\n",
"\n",
"logger = logging.getLogger(__name__)\n",
"logger.setLevel(logging.DEBUG)\n",
"\n",
"\n",
"class StableDiffusionAITPipeline(StableDiffusionPipeline):\n",
" r\"\"\"\n",
" Pipeline for text-to-image generation using Stable Diffusion.\n",
" This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the\n",
" library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n",
" Args:\n",
" vae ([`AutoencoderKL`]):\n",
" Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n",
" text_encoder ([`CLIPTextModel`]):\n",
" Frozen text-encoder. Stable Diffusion uses the text portion of\n",
" [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically\n",
" the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\n",
" tokenizer (`CLIPTokenizer`):\n",
" Tokenizer of class\n",
" [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n",
" unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.\n",
" scheduler ([`SchedulerMixin`]):\n",
" A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of\n",
" [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n",
" safety_checker ([`StableDiffusionSafetyChecker`]):\n",
" Classification module that estimates whether generated images could be considered offsensive or harmful.\n",
" Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.\n",
" feature_extractor ([`CLIPFeatureExtractor`]):\n",
" Model that extracts features from generated images to be used as inputs for the `safety_checker`.\n",
" \"\"\"\n",
"\n",
" def __init__(\n",
" self,\n",
" vae: AutoencoderKL,\n",
" text_encoder: CLIPTextModel,\n",
" tokenizer: CLIPTokenizer,\n",
" unet: UNet2DConditionModel,\n",
" scheduler: Union[\n",
" DDIMScheduler,\n",
" PNDMScheduler,\n",
" LMSDiscreteScheduler,\n",
" EulerDiscreteScheduler,\n",
" EulerAncestralDiscreteScheduler,\n",
" DPMSolverMultistepScheduler,\n",
" ],\n",
" safety_checker: StableDiffusionSafetyChecker,\n",
" feature_extractor: CLIPFeatureExtractor,\n",
" requires_safety_checker: bool = True,\n",
" ):\n",
" super().__init__(\n",
" vae=vae,\n",
" text_encoder=text_encoder,\n",
" tokenizer=tokenizer,\n",
" unet=unet,\n",
" scheduler=scheduler,\n",
" safety_checker=safety_checker,\n",
" feature_extractor=feature_extractor,\n",
" requires_safety_checker=requires_safety_checker,\n",
" )\n",
"\n",
" workdir = os.getenv(\"AIT_MODEL_PATH\")\n",
" logger.info(f\"StableDiffusionAITPipeline::init:workdir={workdir}\")\n",
"\n",
" self.clip_ait_exe = self.init_ait_module(\n",
" model_name=\"CLIPTextModel\", workdir=workdir\n",
" )\n",
" self.unet_ait_exe = self.init_ait_module(\n",
" model_name=\"UNet2DConditionModel\", workdir=workdir\n",
" )\n",
" self.vae_ait_exe = self.init_ait_module(\n",
" model_name=\"AutoencoderKL\", workdir=workdir\n",
" )\n",
"\n",
" def init_ait_module(\n",
" self,\n",
" model_name,\n",
" workdir,\n",
" ):\n",
" mod = Model(os.path.join(workdir, model_name, \"test.so\"))\n",
" return mod\n",
"\n",
" def unet_inference(self, latent_model_input, timesteps, encoder_hidden_states):\n",
" exe_module = self.unet_ait_exe\n",
" timesteps_pt = timesteps.expand(latent_model_input.shape[0])\n",
" inputs = {\n",
" \"input0\": latent_model_input.permute((0, 2, 3, 1))\n",
" .contiguous()\n",
" .cuda()\n",
" .half(),\n",
" \"input1\": timesteps_pt.cuda().half(),\n",
" \"input2\": encoder_hidden_states.cuda().half(),\n",
" }\n",
" ys = []\n",
" num_outputs = len(exe_module.get_output_name_to_index_map())\n",
" for i in range(num_outputs):\n",
" shape = exe_module.get_output_maximum_shape(i)\n",
" ys.append(torch.empty(shape).cuda().half())\n",
" exe_module.run_with_tensors(inputs, ys, graph_mode=False)\n",
" noise_pred = ys[0].permute((0, 3, 1, 2)).float()\n",
" return noise_pred\n",
"\n",
" def clip_inference(self, input_ids, seqlen=64):\n",
" exe_module = self.clip_ait_exe\n",
" bs = input_ids.shape[0]\n",
" position_ids = torch.arange(seqlen).expand((bs, -1)).cuda()\n",
" inputs = {\n",
" \"input0\": input_ids,\n",
" \"input1\": position_ids,\n",
" }\n",
" ys = []\n",
" num_outputs = len(exe_module.get_output_name_to_index_map())\n",
" for i in range(num_outputs):\n",
" shape = exe_module.get_output_maximum_shape(i)\n",
" ys.append(torch.empty(shape).cuda().half())\n",
" exe_module.run_with_tensors(inputs, ys, graph_mode=False)\n",
" return ys[0].float()\n",
"\n",
" def vae_inference(self, vae_input):\n",
" exe_module = self.vae_ait_exe\n",
" inputs = [torch.permute(vae_input, (0, 2, 3, 1)).contiguous().cuda().half()]\n",
" ys = []\n",
" num_outputs = len(exe_module.get_output_name_to_index_map())\n",
" for i in range(num_outputs):\n",
" shape = exe_module.get_output_maximum_shape(i)\n",
" ys.append(torch.empty(shape).cuda().half())\n",
" exe_module.run_with_tensors(inputs, ys, graph_mode=False)\n",
" vae_out = ys[0].permute((0, 3, 1, 2)).float()\n",
" return vae_out\n",
"\n",
" @torch.no_grad()\n",
" def __call__(\n",
" self,\n",
" prompt: Union[str, List[str]],\n",
" height: Optional[int] = 512,\n",
" width: Optional[int] = 512,\n",
" num_inference_steps: Optional[int] = 50,\n",
" guidance_scale: Optional[float] = 7.5,\n",
" negative_prompt: Optional[Union[str, List[str]]] = None,\n",
" eta: Optional[float] = 0.0,\n",
" generator: Optional[torch.Generator] = None,\n",
" latents: Optional[torch.FloatTensor] = None,\n",
" output_type: Optional[str] = \"pil\",\n",
" return_dict: bool = True,\n",
" **kwargs,\n",
" ):\n",
" r\"\"\"\n",
" Function invoked when calling the pipeline for generation.\n",
" Args:\n",
" prompt (`str` or `List[str]`):\n",
" The prompt or prompts to guide the image generation.\n",
" height (`int`, *optional*, defaults to 512):\n",
" The height in pixels of the generated image.\n",
" width (`int`, *optional*, defaults to 512):\n",
" The width in pixels of the generated image.\n",
" num_inference_steps (`int`, *optional*, defaults to 50):\n",
" The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n",
" expense of slower inference.\n",
" guidance_scale (`float`, *optional*, defaults to 7.5):\n",
" Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).\n",
" `guidance_scale` is defined as `w` of equation 2. of [Imagen\n",
" Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >\n",
" 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,\n",
" usually at the expense of lower image quality.\n",
" negative_prompt (`str` or `List[str]`, *optional*):\n",
" The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored\n",
" if `guidance_scale` is less than `1`).\n",
" eta (`float`, *optional*, defaults to 0.0):\n",
" Corresponds to parameter eta (\u03b7) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to\n",
" [`schedulers.DDIMScheduler`], will be ignored for others.\n",
" generator (`torch.Generator`, *optional*):\n",
" A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation\n",
" deterministic.\n",
" latents (`torch.FloatTensor`, *optional*):\n",
" Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n",
" generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n",
" tensor will ge generated by sampling using the supplied random `generator`.\n",
" output_type (`str`, *optional*, defaults to `\"pil\"`):\n",
" The output format of the generate image. Choose between\n",
" [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n",
" return_dict (`bool`, *optional*, defaults to `True`):\n",
" Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n",
" plain tuple.\n",
" Returns:\n",
" [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n",
" [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.\n",
" When returning a tuple, the first element is a list with the generated images, and the second element is a\n",
" list of `bool`s denoting whether the corresponding generated image likely represents \"not-safe-for-work\"\n",
" (nsfw) content, according to the `safety_checker`.\n",
" \"\"\"\n",
"\n",
" if \"torch_device\" in kwargs:\n",
" device = kwargs.pop(\"torch_device\")\n",
" warnings.warn(\n",
" \"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0.\"\n",
" \" Consider using `pipe.to(torch_device)` instead.\"\n",
" )\n",
"\n",
" # Set device as before (to be removed in 0.3.0)\n",
" if device is None:\n",
" device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
" self.to(device)\n",
"\n",
" if isinstance(prompt, str):\n",
" batch_size = 1\n",
" elif isinstance(prompt, list):\n",
" batch_size = len(prompt)\n",
" else:\n",
" raise ValueError(\n",
" f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\"\n",
" )\n",
"\n",
" if height % 8 != 0 or width % 8 != 0:\n",
" raise ValueError(\n",
" f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\"\n",
" )\n",
"\n",
" # get prompt text embeddings\n",
" text_input = self.tokenizer(\n",
" prompt,\n",
" padding=\"max_length\",\n",
" max_length=64, # self.tokenizer.model_max_length,\n",
" truncation=True,\n",
" return_tensors=\"pt\",\n",
" )\n",
" text_embeddings = self.clip_inference(text_input.input_ids.to(self.device))\n",
"\n",
" # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n",
" # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`\n",
" # corresponds to doing no classifier free guidance.\n",
" do_classifier_free_guidance = guidance_scale > 1.0\n",
" # get unconditional embeddings for classifier free guidance\n",
" if do_classifier_free_guidance:\n",
" uncond_tokens: List[str]\n",
" max_length = text_input.input_ids.shape[-1]\n",
" if negative_prompt is None:\n",
" uncond_tokens = [\"\"] * batch_size\n",
" elif type(prompt) is not type(negative_prompt):\n",
" raise TypeError(\n",
" f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n",
" f\" {type(prompt)}.\"\n",
" )\n",
" elif isinstance(negative_prompt, str):\n",
" uncond_tokens = [negative_prompt]\n",
" elif batch_size != len(negative_prompt):\n",
" raise ValueError(\n",
" f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n",
" f\" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches\"\n",
" \" the batch size of `prompt`.\"\n",
" )\n",
" else:\n",
" uncond_tokens = negative_prompt\n",
" uncond_input = self.tokenizer(\n",
" uncond_tokens,\n",
" padding=\"max_length\",\n",
" max_length=max_length,\n",
" truncation=True,\n",
" return_tensors=\"pt\",\n",
" )\n",
" uncond_embeddings = self.clip_inference(\n",
" uncond_input.input_ids.to(self.device)\n",
" )\n",
"\n",
" # For classifier free guidance, we need to do two forward passes.\n",
" # Here we concatenate the unconditional and text embeddings into a single batch\n",
" # to avoid doing two forward passes\n",
" text_embeddings = torch.cat([uncond_embeddings, text_embeddings])\n",
"\n",
" # get the initial random noise unless the user supplied it\n",
"\n",
" # Unlike in other pipelines, latents need to be generated in the target device\n",
" # for 1-to-1 results reproducibility with the CompVis implementation.\n",
" # However this currently doesn't work in `mps`.\n",
" latents_device = \"cpu\" if self.device.type == \"mps\" else self.device\n",
" latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)\n",
" if latents is None:\n",
" latents = torch.randn(\n",
" latents_shape,\n",
" generator=generator,\n",
" device=latents_device,\n",
" )\n",
" else:\n",
" if latents.shape != latents_shape:\n",
" raise ValueError(\n",
" f\"Unexpected latents shape, got {latents.shape}, expected {latents_shape}\"\n",
" )\n",
" latents = latents.to(self.device)\n",
"\n",
" # set timesteps\n",
" accepts_offset = \"offset\" in set(\n",
" inspect.signature(self.scheduler.set_timesteps).parameters.keys()\n",
" )\n",
" extra_set_kwargs = {}\n",
" if accepts_offset:\n",
" extra_set_kwargs[\"offset\"] = 1\n",
"\n",
" self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)\n",
"\n",
" latents = latents * self.scheduler.init_noise_sigma\n",
"\n",
" # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n",
" # eta (\u03b7) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n",
" # eta corresponds to \u03b7 in DDIM paper: https://arxiv.org/abs/2010.02502\n",
" # and should be between [0, 1]\n",
" accepts_eta = \"eta\" in set(\n",
" inspect.signature(self.scheduler.step).parameters.keys()\n",
" )\n",
" extra_step_kwargs = {}\n",
" if accepts_eta:\n",
" extra_step_kwargs[\"eta\"] = eta\n",
" # check if the scheduler accepts generator\n",
" accepts_generator = \"generator\" in set(\n",
" inspect.signature(self.scheduler.step).parameters.keys()\n",
" )\n",
" if accepts_generator:\n",
" extra_step_kwargs[\"generator\"] = generator\n",
"\n",
" for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):\n",
" # expand the latents if we are doing classifier free guidance\n",
" latent_model_input = (\n",
" torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n",
" )\n",
" latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n",
"\n",
" if isinstance(self.scheduler, LMSDiscreteScheduler):\n",
" sigma = self.scheduler.sigmas[i]\n",
" # the model input needs to be scaled to match the continuous ODE formulation in K-LMS\n",
" latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)\n",
"\n",
" # predict the noise residual\n",
" noise_pred = self.unet_inference(\n",
" latent_model_input, t, encoder_hidden_states=text_embeddings\n",
" )\n",
"\n",
" # perform guidance\n",
" if do_classifier_free_guidance:\n",
" noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n",
" noise_pred = noise_pred_uncond + guidance_scale * (\n",
" noise_pred_text - noise_pred_uncond\n",
" )\n",
"\n",
" # compute the previous noisy sample x_t -> x_t-1\n",
" if isinstance(self.scheduler, LMSDiscreteScheduler):\n",
" latents = self.scheduler.step(\n",
" noise_pred, i, latents, **extra_step_kwargs\n",
" ).prev_sample\n",
" else:\n",
" latents = self.scheduler.step(\n",
" noise_pred, t, latents, **extra_step_kwargs\n",
" ).prev_sample\n",
"\n",
" # scale and decode the image latents with vae\n",
" latents = 1 / 0.18215 * latents\n",
" image = self.vae_inference(latents)\n",
"\n",
" image = (image / 2 + 0.5).clamp(0, 1)\n",
" image = image.cpu().permute(0, 2, 3, 1).numpy()\n",
"\n",
" # run safety checker\n",
" if self.safety_checker is not None:\n",
" safety_checker_input = self.feature_extractor(\n",
" self.numpy_to_pil(image), return_tensors=\"pt\"\n",
" ).to(self.device)\n",
" image, has_nsfw_concept = self.safety_checker(\n",
" images=image, clip_input=safety_checker_input.pixel_values\n",
" )\n",
" else:\n",
" has_nsfw_concept = None\n",
"\n",
" if output_type == \"pil\":\n",
" image = self.numpy_to_pil(image)\n",
"\n",
" if not return_dict:\n",
" return (image, has_nsfw_concept)\n",
"\n",
" return StableDiffusionPipelineOutput(\n",
" images=image, nsfw_content_detected=has_nsfw_concept\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "82382b39-fd91-4f71-8248-79bd3067d7fe",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"%%writefile code_sd/model.py\n",
"from djl_python import Input, Output\n",
"import deepspeed\n",
"import torch\n",
"import logging\n",
"import math\n",
"import os\n",
"import sys\n",
"from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer\n",
"import logging\n",
"\n",
"from io import BytesIO\n",
"\n",
"import torch\n",
"import aitemplate\n",
"\n",
"# from torch.utils.dlpack import to_dlpack, from_dlpack\n",
"\n",
"# from aitemplate.testing.benchmark_pt import benchmark_torch_function\n",
"# from aitemplate.utils.import_path import import_parent\n",
"from diffusers import DPMSolverMultistepScheduler\n",
"from diffusers import (\n",
" AutoencoderKL,\n",
" DDIMScheduler,\n",
" DPMSolverMultistepScheduler,\n",
" EulerAncestralDiscreteScheduler,\n",
" EulerDiscreteScheduler,\n",
" LMSDiscreteScheduler,\n",
" PNDMScheduler,\n",
" StableDiffusionPipeline,\n",
" UNet2DConditionModel,\n",
")\n",
"\n",
"import logging\n",
"\n",
"\n",
"logger = logging.getLogger(__name__)\n",
"logger.setLevel(logging.DEBUG)\n",
"\n",
"\n",
"model = None\n",
"\n",
"\n",
"def load_model(properties):\n",
" device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
" logger.info(f\"Model:init:args={properties}:\")\n",
" logger.info(f\"Model device={device}:\")\n",
"\n",
" tensor_parallel = properties[\"tensor_parallel_degree\"]\n",
" model_dir = properties[\"model_dir\"]\n",
" if \"model_id\" in properties:\n",
" model_dir = properties[\"model_id\"]\n",
" logging.info(f\"Loading model in {model_dir}: was set in model_id:\")\n",
" sys.path.append(model_dir)\n",
"\n",
" for f in os.listdir(model_dir):\n",
" logger.info(f\" model_dir={model_dir}:: file={f}::\")\n",
"\n",
" # from src import pipeline_stable_diffusion_ait\n",
" from src.pipeline_stable_diffusion_ait import StableDiffusionAITPipeline\n",
"\n",
" os.environ[\n",
" \"AIT_MODEL_PATH\"\n",
" ] = f\"{model_dir}/\" # -- this is needed to load the AIT model in the CLIPText etc\n",
" # - check the notes for why this needs to be done\n",
" pipe = StableDiffusionAITPipeline.from_pretrained(\n",
" f\"{model_dir}/diffusers-pipeline/stabilityai/stable-diffusion-v2\",\n",
" scheduler=EulerDiscreteScheduler.from_pretrained( # scheduler/scheduler_config.json\n",
" f\"{model_dir}/diffusers-pipeline/stabilityai/stable-diffusion-v2\",\n",
" subfolder=\"scheduler\",\n",
" ),\n",
" revision=\"fp16\",\n",
" torch_dtype=torch.float16,\n",
" )\n",
" pipe.to(device)\n",
"\n",
" logger.info(f\" Pipe Model created successfully:\")\n",
"\n",
" # pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)\n",
"\n",
" logger.info(f\" Pipe Scheduler config: created successfully:\")\n",
" return pipe\n",
"\n",
"\n",
"def run_inference(model, inputs):\n",
" logger.info(f\"Prediction:request_list={inputs}:\")\n",
"\n",
" input_text = inputs[\"prompt\"]\n",
" with torch.autocast(\"cuda\"):\n",
" image = model(input_text, 512, 512).images[0]\n",
"\n",
" # image.save(\"example_ait.png\")\n",
" buf = BytesIO()\n",
" image.save(buf, format=\"PNG\")\n",
" byte_img = buf.getvalue()\n",
" return byte_img\n",
" # return Output().add(byte_img).add_property(\"content-type\", \"image/png\")\n",
"\n",
"\n",
"def handle(inputs: Input):\n",
" global model\n",
" if not model:\n",
" model = load_model(inputs.get_properties())\n",
"\n",
" if inputs.is_empty():\n",
" # Model server makes an empty call to warmup the model on startup\n",
" return None\n",
" data = inputs.get_as_json()\n",
"\n",
" input_sentences = data[\"inputs\"]\n",
"\n",
" byte_img = run_inference(model, input_sentences)\n",
" # return Output().add_as_json(result)\n",
" return Output().add(byte_img).add_property(\"content-type\", \"image/png\")"
]
},
{
"cell_type": "markdown",
"id": "dcec9a66-f43a-4bc6-9c38-8b695215101b",
"metadata": {},
"source": [
"There are a few options specified here. Lets go through them in turn
\n",
"1. `engine` - specifies the engine that will be used for this workload. In this case we'll be hosting a model using the [DJL Python Engine](https://github.com/deepjavalibrary/djl-serving/tree/master/engines/python)\n",
"2. `option.entryPoint` - specifies the entrypoint code that will be used to host the model. djl_python.huggingface refers to the `huggingface.py` module from [djl_python repo](https://github.com/deepjavalibrary/djl-serving/tree/master/engines/python/setup/djl_python). \n",
"3. `option.s3url` - specifies the location of the model files. Alternativelly an `option.model_id` option can be used instead to specifiy a model from Hugging Face Hub (e.g. `EleutherAI/gpt-j-6B`) and the model will be automatically downloaded from the Hub. The s3url approach is recommended as it allows you to host the model artifact within your own environment and enables faster deployments by utilizing optimized approach within the DJL inference container to transfer the model from S3 into the hosting instance \n",
"4. `option.task` - This is specific to the `huggingface.py` inference handler and specifies for which task this model will be used\n",
"5. `option.device_map` - Enables layer-wise model partitioning through [Hugging Face Accelerate](https://huggingface.co/docs/accelerate/usage_guides/big_modeling#designing-a-device-map). With `option.device_map=auto`, Accelerate will determine where to put each **layer** to maximize the use of your fastest devices (GPUs) and offload the rest on the CPU, or even the hard drive if you don\u2019t have enough GPU RAM (or CPU RAM). Even if the model is split across several devices, it will run as you would normally expect.\n",
"6. `option.load_in_8bit` - Quantizes the model weights to int8 thereby greatly reducing the memory footprint of the model from the initial FP32. See this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration) from Hugging Face for additional information \n",
"\n",
"For more information on the available options, please refer to the [SageMaker Large Model Inference Documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints-large-model-configuration.html)\n",
"\n",
"This diagram shows howor how we can shard models using Model Parallelism, we can use Tensor Parallelisim as well. Like within Hugging Face Transformers to enable Large Language Model hosting these are exposed through the `device_map` and `load_in_8bit` parameters which enable sharding and shrinking of the model. The sharding approach taken here is layer wise as individual model layers are placed onto different GPU devices and data flows sequentially from the input to the final output layer as illustated below
\n",
"\n",
"\n",
"Even though in this example the model will be running on a single GPU and will not be sharded, this parameter would automatically apply sharding as we scale to larger models on multi-GPU instances."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f9ccd880-b69a-4075-9716-7292121b9135",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"%%writefile code_sd/requirements.txt\n",
"boto3\n",
"awscli\n",
"# git+https://github.com/facebookincubator/AITemplate@0.2"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3a45a759-5bca-4eb8-858f-c2e139deac1e",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"!rm -rf `find -type d -name .ipynb_checkpoints`\n",
"!rm model.tar.gz\n",
"!tar czvf model.tar.gz code_sd"
]
},
{
"cell_type": "markdown",
"id": "9b61c912-d196-4e42-93f8-80aaa386f502",
"metadata": {},
"source": [
"### Upload the Tar file to S3 for Creation of End points"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3ea15e83-f1e6-458e-873c-8ff1f01b310e",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"s3_code_artifact = sess.upload_data(\"model.tar.gz\", bucket, s3_code_prefix)\n",
"print(f\"S3 Code or Model tar ball uploaded to --- > {s3_code_artifact}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c3d52b02-92bd-48f3-9bce-32448d81ff0a",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"print(f\"S3 Model Prefix where the model files are -- > {s3_model_prefix}\")\n",
"print(f\"S3 Model Bucket is -- > {model_bucket}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dd528462-266e-46c1-bf65-3715d6f4d381",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# inference_image_uri = f\"{account_id}.dkr.ecr.{region}.amazonaws.com/djl-ds:latest\"\n",
"inference_image_uri = f\"763104351884.dkr.ecr.{region}.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.0-cu117\"\n",
"print(f\"Image going to be used is ---- > {inference_image_uri}\")"
]
},
{
"cell_type": "markdown",
"id": "9404c2a0-1195-4ad0-8ba3-0aee4c552480",
"metadata": {},
"source": [
"### Creating end point in SageMaker\n",
"