{ "cells": [ { "cell_type": "markdown", "id": "32271133", "metadata": {}, "source": [ "# Deploy ControlNet models to SageMaker endpoint\n", "\n", "To deploy the pretrained model to an SageMaker endpoint, we need to prepare the model artifacts and upload it to s3. When creating the endpoint, the model data will be downloaded during run time and stored in the path **\"/opt/ml/model\"**.\n", "\n", "
\n", "Note: This notebook was tested with the `conda_python3` kernel on an Amazon SageMaker notebook instance.\n", "
\n", "\n", "\n", "### Step 1: Prepare model artifacts" ] }, { "cell_type": "code", "execution_count": 12, "id": "c1fdd115", "metadata": {}, "outputs": [], "source": [ "import tarfile\n", "from pathlib import Path\n", "import tqdm\n", "from time import gmtime, strftime\n", "import datetime\n", "import boto3\n", "import json\n", "\n", "import sagemaker\n", "from sagemaker import get_execution_role\n", "from sagemaker.s3 import S3Uploader, s3_path_join\n", "\n", "from sagemaker import Model\n", "from sagemaker.serializers import JSONSerializer\n", "from sagemaker.deserializers import JSONDeserializer\n", "\n", "session = sagemaker.Session()\n", "bucket = session.default_bucket()\n", "prefix = \"sagemaker/generative-ai-controlnet\"\n", "role = get_execution_role()\n", "\n", "account = session.boto_session.client(\"sts\").get_caller_identity()[\"Account\"]\n", "region = session.boto_session.region_name" ] }, { "cell_type": "code", "execution_count": 2, "id": "03592281", "metadata": {}, "outputs": [], "source": [ "def create_tar(tarfile_name: str, local_path: Path):\n", " \"\"\"\n", " Create a tar.gz archive with the content of `local_path`.\n", " \"\"\"\n", " file_list = [k for k in local_path.glob(\"**/*.*\") if f\"{k.relative_to(local_path)}\"[0] != \".\"]\n", " pbar = tqdm.tqdm(file_list, unit=\"files\")\n", " with tarfile.open(tarfile_name, mode=\"w:gz\") as archive:\n", " for k in pbar:\n", " pbar.set_description(f\"{k}\")\n", " archive.add(k, arcname=f\"{k.relative_to(local_path)}\")\n", " tar_size = Path(tarfile_name).stat().st_size / 10**6\n", " return " ] }, { "cell_type": "markdown", "id": "f4561ec4", "metadata": {}, "source": [ "To create the tarball file, the below code might take more than 5mins due to the model size is quite big. Please be patient." ] }, { "cell_type": "code", "execution_count": null, "id": "abe68445", "metadata": {}, "outputs": [], "source": [ "%%time\n", "tar_file = \"model.tar.gz\"\n", "tar_size = create_tar(tar_file, Path(\"../models\"))\n", "print(f\"Created {tar_file}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "1eb0b5dc", "metadata": {}, "outputs": [], "source": [ "model_data_path = s3_path_join(\"s3://\", bucket, prefix + \"/models\")\n", "print(f\"Uploading Models to {model_data_path}\")\n", "model_uri = S3Uploader.upload(\"model.tar.gz\", model_data_path)\n", "print(f\"Uploaded roberta model to {model_uri}\")" ] }, { "cell_type": "code", "execution_count": 9, "id": "d97ce0e6", "metadata": {}, "outputs": [], "source": [ "now = f\"{datetime.datetime.now():%Y-%m-%d-%H-%M-%S}\"\n", "model_name = f\"sagemaker-gai-demo-{now}\"\n", "model = Model(\n", " image_uri=\"\", #replace with your own model_uri\n", " model_data=model_uri,\n", " name=model_name, \n", " role=role,\n", ")\n", "print(model_name)" ] }, { "cell_type": "code", "execution_count": 10, "id": "cb8467fb", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "-----------!" ] } ], "source": [ "endpoint_name = model_name\n", "\n", "model.deploy(\n", " initial_instance_count=1,\n", " instance_type='ml.g5.xlarge',\n", " serializer=JSONSerializer(),\n", " deserializer=JSONDeserializer(),\n", " endpoint_name=endpoint_name\n", ")\n", "%store endpoint_name" ] }, { "cell_type": "code", "execution_count": 13, "id": "ef08b2ea", "metadata": {}, "outputs": [], "source": [ "file_name = \"../data.json\"\n", "with open(file_name, 'r') as file:\n", " json_data = json.load(file)\n", " \n", "sm_runtime = boto3.client(\"sagemaker-runtime\")" ] }, { "cell_type": "code", "execution_count": 14, "id": "6b32e8f3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 44.6 ms, sys: 23.6 ms, total: 68.3 ms\n", "Wall time: 15.4 s\n" ] } ], "source": [ "%%time\n", "response = sm_runtime.invoke_endpoint(\n", " EndpointName=endpoint_name,\n", " Body=json.dumps(json_data),\n", " ContentType=\"application/json\",\n", ")\n", "data = response[\"Body\"].read().decode('utf-8')" ] }, { "cell_type": "code", "execution_count": 15, "id": "282cd528", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n" ] } ], "source": [ "output = json.loads(data)\n", "detected_map = output[\"detected_map\"]\n", "print(type(detected_map))\n", "image = output[\"image\"]\n", "print(type(image))" ] }, { "cell_type": "code", "execution_count": null, "id": "9461fc3b", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "conda_python3", "language": "python", "name": "conda_python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.8" } }, "nbformat": 4, "nbformat_minor": 5 }