{ "cells": [ { "cell_type": "markdown", "id": "b7b62d3c", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## SageMaker Stable diffusion Quick Kit \n", " [SageMaker Stable Diffusion Quick Kit](https://github.com/aws-samples/sagemaker-stablediffusion-quick-kit) is an asset to help our customers launch stable diffusion models services on Amazon Sagemaker or Amazon EKS.\n", " \n", " ![architecture](https://raw.githubusercontent.com/aws-samples/sagemaker-stablediffusion-quick-kit/main/images/architecture.png)\n", "\n", "#### Prerequisites\n", "1. Amazon Web Service account \n", "2. ml.g4dn.xlarge or ml.g5xlarge perfer to used\n", "\n", "\n", "#### Notebook Step\n", "1. Upgrage boto3, sagemaker python sdk \n", "2. Deploy AIGC inference service with SageMaker Endpoint service \n", " * config model parameter \n", " * config async inference\n", " * deploy SageMaker Endpoint\n", "3. Test inference\n", "4. SageMaker endpoint AutoScaling Config(option)\n", "5. Clear resource" ] }, { "cell_type": "markdown", "id": "7411103c", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "### 1. Upgrage boto3, sagemaker python sdk " ] }, { "cell_type": "code", "execution_count": null, "id": "712f4581", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "!pip install --upgrade boto3 sagemaker" ] }, { "cell_type": "code", "execution_count": null, "id": "44c2f4a6", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "import boto3\n", "import sagemaker\n", "account_id = boto3.client('sts').get_caller_identity().get('Account')\n", "region_name = boto3.session.Session().region_name\n", "\n", "sagemaker_session = sagemaker.Session()\n", "bucket = sagemaker_session.default_bucket()\n", "role = sagemaker.get_execution_role()\n", "\n", "print(f'execution role: {role}')\n", "print(f'default bucket: {bucket}')" ] }, { "cell_type": "markdown", "id": "4495eb00", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "### 2. Deploy AIGC inference service with SageMaker Endpoint service " ] }, { "cell_type": "markdown", "id": "e20a2a90", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "#### 2.1 setup model name and arguments \n", " * model_name: Huggingface diffusers models (not support single check point format)\n", " * model_args: diffuser StableDiffusionPipeline init arguments\n", " * framework_version: pytroch version\n", " * py_version: python: 3.8 \n", " * model_environment: inference contianer env " ] }, { "cell_type": "code", "execution_count": null, "id": "eae5b761", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "\n", "framework_version = '1.10'\n", "py_version = 'py38'\n", "\n", "model_environment = {\n", " 'SAGEMAKER_MODEL_SERVER_TIMEOUT':'600', \n", " 'SAGEMAKER_MODEL_SERVER_WORKERS': '1', \n", " 'model_name':'Linaqruf/anything-v3.0', #huggingface model name \n", " 's3_bucket': bucket\n", "}" ] }, { "cell_type": "markdown", "id": "154d4aeb", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "#### 2.2 Create fake dummy model_data file, and create PyTorchModel for SageMaker Endpoint." ] }, { "cell_type": "code", "execution_count": null, "id": "c7e17397", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "!touch dummy\n", "!tar czvf model.tar.gz dummy sagemaker-logo-small.png\n", "assets_dir = 's3://{0}/{1}/assets/'.format(bucket, 'stablediffusion')\n", "model_data = 's3://{0}/{1}/assets/model.tar.gz'.format(bucket, 'stablediffusion')\n", "!aws s3 cp model.tar.gz $assets_dir\n", "!rm -f dummy model.tar.gz" ] }, { "cell_type": "code", "execution_count": null, "id": "91b72318", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "from sagemaker.pytorch.model import PyTorchModel\n", "\n", "model = PyTorchModel(\n", " name = None,\n", " model_data = model_data,\n", " entry_point = 'inference.py',\n", " source_dir = \"./code/\",\n", " role = role,\n", " framework_version = framework_version, \n", " py_version = py_version,\n", " env = model_environment\n", ")" ] }, { "cell_type": "markdown", "id": "1f140c04", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "#### 2.3 Config async inference output , setup config instance_type and name\n" ] }, { "cell_type": "code", "execution_count": null, "id": "74b24df7", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "from sagemaker.async_inference import AsyncInferenceConfig\n", "import uuid\n", "\n", "endpoint_name = f'AIGC-Quick-Kit-{str(uuid.uuid4())}'\n", "instance_type = 'ml.g4dn.xlarge'\n", "instance_count = 1\n", "async_config = AsyncInferenceConfig(output_path='s3://{0}/{1}/asyncinvoke/out/'.format(bucket, 'stablediffusion'))\n", "\n", "print(f'endpoint_name: {endpoint_name}')" ] }, { "cell_type": "markdown", "id": "dc298014", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "#### 2.4 Deploy SageMaker Endpoint" ] }, { "cell_type": "code", "execution_count": null, "id": "cf86cd3e", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "from sagemaker.serializers import JSONSerializer\n", "from sagemaker.deserializers import JSONDeserializer\n", "\n", "\n", "async_predictor = model.deploy(\n", " endpoint_name = endpoint_name,\n", " instance_type = instance_type, \n", " initial_instance_count = instance_count,\n", " async_inference_config = async_config,\n", " serializer = JSONSerializer(),\n", " deserializer = JSONDeserializer()\n", ")" ] }, { "cell_type": "markdown", "id": "70b79c21", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "\n", "#### 2.5 Create async inference invoke help function \n", " * get_bucket_and_key, read s3 object\n", " * draw_image, download image from s3 and draw it in notebook\n", " * async_predict_fn \n" ] }, { "cell_type": "code", "execution_count": null, "id": "b9220ca7", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "import json\n", "import io\n", "from PIL import Image\n", "import traceback\n", "import time\n", "from sagemaker.async_inference.waiter_config import WaiterConfig\n", "\n", "\n", "s3_resource = boto3.resource('s3')\n", "\n", "def get_bucket_and_key(s3uri):\n", " pos = s3uri.find('/', 5)\n", " bucket = s3uri[5 : pos]\n", " key = s3uri[pos + 1 : ]\n", " return bucket, key\n", "\n", "def draw_image(response):\n", " try:\n", " bucket, key = get_bucket_and_key(response.output_path)\n", " obj = s3_resource.Object(bucket, key)\n", " body = obj.get()['Body'].read().decode('utf-8') \n", " predictions = json.loads(body)['result']\n", " print(predictions)\n", " for prediction in predictions:\n", " bucket, key = get_bucket_and_key(prediction)\n", " obj = s3_resource.Object(bucket, key)\n", " bytes = obj.get()['Body'].read()\n", " image = Image.open(io.BytesIO(bytes))\n", " image.show()\n", " except Exception as e:\n", " traceback.print_exc()\n", " print(e)\n", "\n", "\n", "def async_predict_fn(predictor,inputs):\n", " response = predictor.predict_async(inputs)\n", " \n", " print(f\"Response object: {response}\")\n", " print(f\"Response output path: {response.output_path}\")\n", " print(\"Start Polling to get response:\")\n", " \n", " start = time.time()\n", " config = WaiterConfig(\n", " max_attempts=100, # number of attempts\n", " delay=10 # time in seconds to wait between attempts\n", " )\n", "\n", " response.get_result(config)\n", " draw_image(response)\n", "\n", " print(f\"Time taken: {time.time() - start}s\")" ] }, { "cell_type": "markdown", "id": "0dd1556c", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "### 3. Testing\n", " 3.1 txt2img inference" ] }, { "cell_type": "code", "execution_count": null, "id": "eefbe9f2", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "#AIGC Quick Kit txt2img\n", "inputs_txt2img = {\n", " \"prompt\": \"a photo of an astronaut riding a horse on mars\",\n", " \"negative_prompt\":\"\",\n", " \"steps\":20,\n", " \"sampler\":\"euler_a\",\n", " \"seed\": 31252362,\n", " \"height\": 512, \n", " \"width\": 512,\n", " \"count\":2\n", "\n", "}\n", "\n", "\n", "async_predict_fn(async_predictor,inputs_txt2img)\n", "\n" ] }, { "cell_type": "markdown", "id": "3fe1aab9", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ " 3.2 img2img inference\n", " \n", " * origin image :![](https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg)" ] }, { "cell_type": "code", "execution_count": null, "id": "00f94fc0", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "#AIGC Quick Kit img2img\n", "\n", "inputs_img2img = {\n", " \"prompt\": \"A fantasy landscape, trending on artstation\",\n", " \"negative_prompt\":\"\",\n", " \"steps\":20,\n", " \"sampler\":\"euler_a\",\n", " \"seed\":43768,\n", " \"height\": 512, \n", " \"width\": 512,\n", " \"count\":2,\n", " \"input_image\":\"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg\"\n", "\n", "}\n", "\n", "async_predict_fn(async_predictor,inputs_img2img)" ] }, { "cell_type": "markdown", "id": "31adc892", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "### 4. SageMaker endpoint AutoScaling Config(Option)" ] }, { "cell_type": "code", "execution_count": null, "id": "98cba429", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# application-autoscaling client\n", "asg_client = boto3.client(\"application-autoscaling\")\n", "\n", "# This is the format in which application autoscaling references the endpoint\n", "resource_id = f\"endpoint/{async_predictor.endpoint_name}/variant/AllTraffic\"\n", "\n", "# Configure Autoscaling on asynchronous endpoint down to zero instances\n", "response = asg_client.register_scalable_target(\n", " ServiceNamespace=\"sagemaker\",\n", " ResourceId=resource_id,\n", " ScalableDimension=\"sagemaker:variant:DesiredInstanceCount\",\n", " MinCapacity=1,\n", " MaxCapacity=2,\n", ")\n", "\n", "response = asg_client.put_scaling_policy(\n", " PolicyName=f'Request-ScalingPolicy-{async_predictor.endpoint_name}',\n", " ServiceNamespace=\"sagemaker\",\n", " ResourceId=resource_id,\n", " ScalableDimension=\"sagemaker:variant:DesiredInstanceCount\",\n", " PolicyType=\"TargetTrackingScaling\",\n", " TargetTrackingScalingPolicyConfiguration={\n", " \"TargetValue\": 2.0,\n", " \"CustomizedMetricSpecification\": {\n", " \"MetricName\": \"ApproximateBacklogSizePerInstance\",\n", " \"Namespace\": \"AWS/SageMaker\",\n", " \"Dimensions\": [{\"Name\": \"EndpointName\", \"Value\": async_predictor.endpoint_name}],\n", " \"Statistic\": \"Average\",\n", " },\n", " \"ScaleInCooldown\": 600, # duration until scale in begins (down to zero)\n", " \"ScaleOutCooldown\": 300 # duration between scale out attempts\n", " },\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "7a67ee27", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "import time\n", "\n", "start = time.time()\n", "\n", "outputs=[]\n", "\n", "inputs_txt2img = {\n", " \"prompt\": \"a photo of an astronaut riding a horse on mars\",\n", " \"negative_prompt\":\"\",\n", " \"steps\":20,\n", " \"sampler\":\"euler_a\",\n", " \"seed\": 52362,\n", " \"height\": 512, \n", " \"width\": 512,\n", " \"count\":2\n", "\n", "}\n", "\n", "# send 10 requests\n", "for i in range(10):\n", " prediction = async_predictor.predict_async(inputs_txt2img)\n", " outputs.append(prediction)\n", "\n", "# iterate over list of output paths and get results\n", "results = []\n", "for output in outputs:\n", " response = output.get_result(WaiterConfig(max_attempts=600))\n", " results.append(response)\n", "\n", "print(f\"Time taken: {time.time() - start}s\")\n", "print(results)" ] }, { "cell_type": "markdown", "id": "fbc9da83-8e8c-4447-b82d-7b3405835d93", "metadata": {}, "source": [ "### draw result image" ] }, { "cell_type": "code", "execution_count": null, "id": "87db10a8-ef6d-459f-9691-1730324f9c5c", "metadata": {}, "outputs": [], "source": [ "for r in results:\n", " for item in r[\"result\"]:\n", " bucket, key = get_bucket_and_key(item)\n", " obj = s3_resource.Object(bucket, key)\n", " bytes = obj.get()['Body'].read()\n", " image = Image.open(io.BytesIO(bytes))\n", " image.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "2798ad59", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "response = asg_client.deregister_scalable_target(\n", " ServiceNamespace='sagemaker',\n", " ResourceId=resource_id,\n", " ScalableDimension='sagemaker:variant:DesiredInstanceCount'\n", ")\n" ] }, { "cell_type": "markdown", "id": "727f7363", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "### 5. Clear resource " ] }, { "cell_type": "code", "execution_count": null, "id": "64da13e8", "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "async_predictor.delete_endpoint()" ] }, { "cell_type": "markdown", "id": "ab93f85e", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "# " ] } ], "metadata": { "kernelspec": { "display_name": "conda_python3", "language": "python", "name": "conda_python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.8" } }, "nbformat": 4, "nbformat_minor": 5 }