{ "cells": [ { "cell_type": "markdown", "id": "de7e5075", "metadata": {}, "source": [ "# Text to Image (Stable Diffusion)" ] }, { "cell_type": "markdown", "id": "f152b43c", "metadata": {}, "source": [ "![alt text](image-assets/StabilityAi_Logo-Coloured_On_Black-12.png \"stability.ai\")\n", "\n", "---\n", "Welcome to Amazon [SageMaker JumpStart](https://docs.aws.amazon.com/sagemaker/latest/dg/studio-jumpstart.html)! You can use JumpStart to solve many Machine Learning tasks through one-click in SageMaker Studio, or through [SageMaker JumpStart API](https://sagemaker.readthedocs.io/en/stable/overview.html#use-prebuilt-models-with-sagemaker-jumpstart). \n", "\n", "In this demo notebook, we demonstrate how to use the JumpStart API for Text-to-Image. Text-to-Image is the task of generating realistic images given any text input. Here, we show how to use state-of-the-art pre-trained Stable Diffusion models for generating image from text.\n", "\n", "---" ] }, { "cell_type": "markdown", "id": "1e084fd6", "metadata": {}, "source": [ "1. [Set Up](#1.-Set-Up)\n", "3. [Retrieve JumpStart Artifacts & Deploy an Endpoint](#2.-Retrieve-JumpStart-Artifacts-&-Deploy-an-Endpoint)\n", "4. [Query endpoint and parse response](#3.-Query-endpoint-and-parse-response)\n", "5. [Advanced features](#4.-Advanced-features)\n", "6. [Clean up the endpoint](#5.-Clean-up-the-endpoint)" ] }, { "cell_type": "markdown", "id": "653f7f5e", "metadata": {}, "source": [ "Note: This notebook was tested on ml.t3.medium instance in Amazon SageMaker Studio with Python 3 (Data Science) kernel and in Amazon SageMaker Notebook instance with conda_python3 kernel.\n", "\n", "Please see steps in [Onboard to Amazon SageMaker Domain Using Quick setup](https://docs.aws.amazon.com/sagemaker/latest/dg/onboard-quick-start.html) to make use of SageMaker Studio." ] }, { "cell_type": "markdown", "id": "6e38978f", "metadata": {}, "source": [ "### 1. Set Up" ] }, { "cell_type": "markdown", "id": "44a2a557", "metadata": {}, "source": [ "---\n", "Before executing the notebook, there are some initial steps required for set up. This notebook requires latest version of sagemaker and ipywidgets\n", "\n", "---" ] }, { "cell_type": "code", "execution_count": null, "id": "14e70228", "metadata": { "tags": [] }, "outputs": [], "source": [ "!pip install sagemaker ipywidgets --upgrade --quiet" ] }, { "cell_type": "markdown", "id": "da123874", "metadata": {}, "source": [ "#### Permissions and environment variables\n", "\n", "---\n", "To host on Amazon SageMaker, we need to set up and authenticate the use of AWS services. Here, we use the execution role associated with the current notebook as the AWS account role with SageMaker access. \n", "\n", "---" ] }, { "cell_type": "code", "execution_count": null, "id": "70dd913a", "metadata": { "tags": [] }, "outputs": [], "source": [ "import sagemaker, boto3, json\n", "from sagemaker import get_execution_role\n", "\n", "aws_role = get_execution_role()\n", "aws_region = boto3.Session().region_name\n", "sess = sagemaker.Session()" ] }, { "cell_type": "markdown", "id": "6aa16954", "metadata": {}, "source": [ "### 2. Retrieve JumpStart Artifacts & Deploy an Endpoint\n", "\n", "***\n", "\n", "Using JumpStart, we can perform inference on the pre-trained model, even without fine-tuning it first on a new dataset. We start by retrieving the `deploy_image_uri`, `deploy_source_uri`, and `model_uri` for the pre-trained model. 
To host the pre-trained model, we create an instance of [`sagemaker.model.Model`](https://sagemaker.readthedocs.io/en/stable/api/inference/model.html) and deploy it. This may take a few minutes.\n", "\n", "***" ] }, { "cell_type": "code", "execution_count": null, "id": "d62576bf", "metadata": { "tags": [] }, "outputs": [], "source": [ "from sagemaker import image_uris, model_uris, script_uris, hyperparameters\n", "from sagemaker.model import Model\n", "from sagemaker.predictor import Predictor\n", "from sagemaker.utils import name_from_base\n", "\n", "# model_version=\"*\" fetches the latest version of the model\n", "model_id, model_version = \"model-txt2img-stabilityai-stable-diffusion-v1-4\", \"*\"\n", "\n", "endpoint_name = name_from_base(f\"jumpstart-console-infer-{model_id}\")\n", "\n", "inference_instance_type = \"ml.p3.2xlarge\"\n", "\n", "# Retrieve the inference docker container uri. This is the base container image for the default model above.\n", "deploy_image_uri = image_uris.retrieve(\n", " region=None,\n", " framework=None, # automatically inferred from model_id\n", " image_scope=\"inference\",\n", " model_id=model_id,\n", " model_version=model_version,\n", " instance_type=inference_instance_type,\n", ")\n", "\n", "# Retrieve the inference script uri. This includes all dependencies and scripts for model loading, inference handling etc.\n", "deploy_source_uri = script_uris.retrieve(\n", " model_id=model_id, model_version=model_version, script_scope=\"inference\"\n", ")\n", "\n", "\n", "# Retrieve the model uri. This includes the pre-trained Stable Diffusion model and parameters.\n", "model_uri = model_uris.retrieve(\n", " model_id=model_id, model_version=model_version, model_scope=\"inference\"\n", ")\n", "\n", "\n", "# Create the SageMaker model instance\n", "model = Model(\n", " image_uri=deploy_image_uri,\n", " source_dir=deploy_source_uri,\n", " model_data=model_uri,\n", " entry_point=\"inference.py\", # entry point file in source_dir and present in deploy_source_uri\n", " role=aws_role,\n", " predictor_cls=Predictor,\n", " name=endpoint_name,\n", ")\n", "\n", "# Deploy the Model. Note that we need to pass the Predictor class when we deploy the model through the Model class\n", "# so that we can run inference through the SageMaker API.\n", "model_predictor = model.deploy(\n", " initial_instance_count=1,\n", " instance_type=inference_instance_type,\n", " predictor_cls=Predictor,\n", " endpoint_name=endpoint_name,\n", ")" ] }, { "cell_type": "markdown", "id": "f5a5e8c7", "metadata": {}, "source": [ "### 3. Query endpoint and parse response\n", "\n", "---\n", "Input to the endpoint is any string of text, serialized to JSON and encoded in `utf-8` format. 
Output of the endpoint is a `json` object containing the generated image and the prompt.\n", "\n", "---" ] }, { "cell_type": "code", "execution_count": null, "id": "637fe75d", "metadata": { "tags": [] }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", "\n", "def query(model_predictor, text):\n", " \"\"\"Query the model predictor.\"\"\"\n", "\n", " encoded_text = json.dumps(text).encode(\"utf-8\")\n", "\n", " query_response = model_predictor.predict(\n", " encoded_text,\n", " {\n", " \"ContentType\": \"application/x-text\",\n", " \"Accept\": \"application/json\",\n", " },\n", " )\n", " return query_response\n", "\n", "\n", "def parse_response(query_response):\n", " \"\"\"Parse response and return the generated image and the prompt.\"\"\"\n", "\n", " response_dict = json.loads(query_response)\n", " return response_dict[\"generated_image\"], response_dict[\"prompt\"]\n", "\n", "\n", "def display_img_and_prompt(img, prmpt):\n", " \"\"\"Display the generated image with its prompt.\"\"\"\n", " plt.figure(figsize=(12, 12))\n", " plt.imshow(np.array(img))\n", " plt.axis(\"off\")\n", " plt.title(prmpt)\n", " plt.show()" ] }, { "cell_type": "markdown", "id": "089f1a07", "metadata": {}, "source": [ "---\n", "Below, we provide some example input text. You can put in any text and the model generates an image corresponding to that text.\n", "\n", "---" ] }, { "cell_type": "code", "execution_count": null, "id": "0faa30c1", "metadata": { "pycharm": { "is_executing": true }, "tags": [] }, "outputs": [], "source": [ "text = \"cottage in impressionist style\"\n", "query_response = query(model_predictor, text)\n", "img, prmpt = parse_response(query_response)\n", "display_img_and_prompt(img, prmpt)" ] }, { "cell_type": "markdown", "id": "1edc0e96", "metadata": { "pycharm": { "is_executing": true } }, "source": [ "### 4. Advanced features\n", "\n", "***\n", "This model also supports many advanced parameters while performing inference. They include:\n", "\n", "* **prompt**: prompt to guide the image generation. Must be specified and can be a string or a list of strings.\n", "* **width**: width of the generated image. If specified, it must be a positive integer divisible by 8.\n", "* **height**: height of the generated image. If specified, it must be a positive integer divisible by 8.\n", "* **num_inference_steps**: Number of denoising steps during image generation. More steps lead to a higher quality image. If specified, it must be a positive integer.\n", "* **guidance_scale**: A higher guidance scale results in an image more closely related to the prompt, at the expense of image quality. If specified, it must be a float. guidance_scale<=1 is ignored.\n", "* **negative_prompt**: guide the image generation away from this prompt. If specified, it must be a string or a list of strings and is used with guidance_scale. If guidance_scale is disabled, this is also disabled. Moreover, if prompt is a list of strings, then negative_prompt must also be a list of strings. \n", "* **num_images_per_prompt**: number of images returned per prompt. If specified, it must be a positive integer. \n", "* **seed**: Fix the randomized state for reproducibility. 
If specified, it must be an integer.\n", "***" ] }, { "cell_type": "code", "execution_count": null, "id": "1e4eb3c8", "metadata": { "pycharm": { "is_executing": true }, "tags": [] }, "outputs": [], "source": [ "import json\n", "\n", "payload = {\n", " \"prompt\": \"astronaut on a horse\",\n", " \"width\": 400,\n", " \"height\": 400,\n", " \"num_images_per_prompt\": 2,\n", " \"num_inference_steps\": 50,\n", " \"guidance_scale\": 7.5,\n", "}\n", "\n", "\n", "def query_endpoint_with_json_payload(model_predictor, payload):\n", " \"\"\"Query the model predictor with json payload.\"\"\"\n", "\n", " encoded_payload = json.dumps(payload).encode(\"utf-8\")\n", "\n", " query_response = model_predictor.predict(\n", " encoded_payload,\n", " {\n", " \"ContentType\": \"application/json\",\n", " \"Accept\": \"application/json\",\n", " },\n", " )\n", " return query_response\n", "\n", "\n", "def parse_response_multiple_images(query_response):\n", " \"\"\"Parse response and return generated image and the prompt\"\"\"\n", "\n", " response_dict = json.loads(query_response)\n", " return response_dict[\"generated_images\"], response_dict[\"prompt\"]\n", "\n", "\n", "query_response = query_endpoint_with_json_payload(model_predictor, payload)\n", "generated_images, prompt = parse_response_multiple_images(query_response)\n", "\n", "for img in generated_images:\n", " display_img_and_prompt(img, prompt)" ] }, { "cell_type": "markdown", "id": "870be7c0", "metadata": {}, "source": [ "### 5. Clean up the endpoint" ] }, { "cell_type": "code", "execution_count": null, "id": "fe5599c9", "metadata": {}, "outputs": [], "source": [ "# Delete the SageMaker endpoint\n", "model_predictor.delete_model()\n", "model_predictor.delete_endpoint()" ] } ], "metadata": { "instance_type": "ml.t3.medium", "kernelspec": { "display_name": "conda_python3", "language": "python", "name": "conda_python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.12" }, "pycharm": { "stem_cell": { "cell_type": "raw", "metadata": { "collapsed": false }, "source": [] } } }, "nbformat": 4, "nbformat_minor": 5 }