{ "cells": [ { "cell_type": "markdown", "id": "872c9a95-4bed-40b5-a76a-2cf5982c1153", "metadata": {}, "source": [ "# Serve OpenAssistant Open-Assistant SFT-1 12B Model on Amazon SageMaker using LMI (Large Model Inference) DJL-based container\n", "**Recommended kernel(s):** This notebook can be run with any Amazon SageMaker Studio kernel.\n", "\n", "This notebook focuses on deploying the [`OpenAssistant/oasst-sft-1-pythia-12b`](https://huggingface.co/OpenAssistant/oasst-sft-1-pythia-12b) HuggingFace model to a SageMaker Endpoint for a text generation task. In this example, you will use the SageMaker-managed [LMI (Large Model Inference)](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints-large-model-dlc.html) Docker image as inference image. LMI images features a [DJL serving](https://github.com/deepjavalibrary/djl-serving) stack powered by the [Deep Java Library](https://djl.ai/).\n", "\n", "You will successively deploy the `OpenAssistant/oasst-sft-1-pythia-12b` model twice on a `ml.g5.12xlarge` GPU instance (4 devices), once using the DeepSpeed inference handler, once using the HuggingFace Accelerate inference handler. This will allow you to compare the latency and the quality of the text generated by these two solutions.\n", "\n", "**Notices:**\n", "* Make sure that the `ml.g5.12xlarge` instance type is available in your AWS Region.\n", "* Make sure that the value of your \"ml.g5.12xlarge for endpoint usage\" Amazon SageMaker service quota allows you to deploy one Endpoint using this instance type.\n", "\n", "This notebook leverages the [`sagemaker` Python SDK](https://sagemaker.readthedocs.io/en/stable/index.html) to abstract away the management of as many resources and configuration as we can, hence demonstrating that the deployment of LLMs to SageMaker can be performed with great simplicity and minimal amount of code.\n", "\n", "### License agreement\n", "* This model and the dataset it has been trained on are both under the [Apache 2.0](https://huggingface.co/models?license=license:apache-2.0) license.\n", "* This notebook is a sample notebook and not intended for production use." ] }, { "cell_type": "markdown", "id": "fdaa8856-71c9-4333-b734-bd28b0071e08", "metadata": { "tags": [] }, "source": [ "### Execution environment setup\n", "This notebook requires the following third-party Python dependencies:\n", "* AWS [`boto3`](https://boto3.amazonaws.com/v1/documentation/api/latest/index.html#)\n", "* AWS [`sagemaker`](https://sagemaker.readthedocs.io/en/stable/index.html), DJL support requires versions greater than 2.136.0 \n", "* HuggingFace [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/index)\n", "\n", "Let's install or upgrade these dependencies using the following command:" ] }, { "cell_type": "code", "execution_count": 2, "id": "869f0d19-e197-43ff-9b6a-9fb31d84816a", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
{ "cell_type": "markdown", "id": "03ddf083-868f-4775-b2cb-2ff34dd5f8ca", "metadata": {}, "source": [ "### Imports & global variable assignment" ] }, { "cell_type": "code", "execution_count": 3, "id": "1697d3bb-040b-452e-a79e-9b17c717cca7", "metadata": { "tags": [] }, "outputs": [], "source": [ "import os\n", "from pathlib import Path\n", "import shutil\n", "from typing import Any, Dict, List\n", "\n", "import boto3\n", "import huggingface_hub\n", "import sagemaker" ] }, { "cell_type": "code", "execution_count": 4, "id": "8a9db6a2-2510-4844-b73f-fe050f52436b", "metadata": { "tags": [] }, "outputs": [], "source": [ "SM_DEFAULT_EXECUTION_ROLE_ARN = sagemaker.get_execution_role()\n", "SM_SESSION = sagemaker.session.Session()\n", "SM_ARTIFACT_BUCKET_NAME = SM_SESSION.default_bucket()\n", "\n", "# boto_region_name is the public accessor for the session's AWS Region\n", "REGION_NAME = SM_SESSION.boto_region_name\n", "S3_CLIENT = boto3.client(\"s3\", region_name=REGION_NAME)" ] }, { "cell_type": "code", "execution_count": 5, "id": "61b34d1c-cf7a-4f7e-8091-215c52c3aba0", "metadata": { "tags": [] }, "outputs": [], "source": [ "HOME_DIR = os.environ[\"HOME\"]\n", "\n", "# HuggingFace local model storage\n", "HF_LOCAL_CACHE_DIR = Path(HOME_DIR) / \".cache\" / \"huggingface\" / \"hub\"\n", "HF_LOCAL_DOWNLOAD_DIR = Path.cwd() / \"model_repo\"\n", "HF_LOCAL_DOWNLOAD_DIR.mkdir(exist_ok=True)\n", "\n", "# Inference code local storage\n", "SOURCE_DIR = Path.cwd() / \"code\"\n", "SOURCE_DIR.mkdir(exist_ok=True)\n", "\n", "# Selected HuggingFace model\n", "HF_HUB_MODEL_NAME = \"OpenAssistant/oasst-sft-1-pythia-12b\"\n", "\n", "# HuggingFace remote model storage (Amazon S3)\n", "HF_MODEL_KEY_PREFIX = f\"hf-large-model-djl/{HF_HUB_MODEL_NAME}\"" ] }, { "cell_type": "markdown",
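"id": "2b6d4e8a-9c1f-4a3b-8d5e-0f7a2c4b6d91", "metadata": {}, "source": [ "As a quick sanity check, you can preview the S3 location the model artifacts will be uploaded to, derived from the variables above:" ] }, { "cell_type": "code", "execution_count": null, "id": "9e0a1b2c-3d4e-4f5a-8b6c-7d8e9f0a1b2c", "metadata": { "tags": [] }, "outputs": [], "source": [ "# Destination of the uncompressed model artifacts on Amazon S3\n", "print(f\"s3://{SM_ARTIFACT_BUCKET_NAME}/{HF_MODEL_KEY_PREFIX}\")" ] }, { "cell_type": "markdown",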
"id": "91c7ba34-d9e0-4919-a6de-771f7d04d33b", "metadata": {}, "source": [ "### Storage utility functions" ] }, { "cell_type": "code", "execution_count": 6, "id": "5f4ad83d-62b5-4dfc-b800-cac6c61368ca", "metadata": { "tags": [] }, "outputs": [], "source": [ "def list_s3_objects(bucket: str, key_prefix: str) -> List[Dict[str, Any]]:\n", " paginator = S3_CLIENT.get_paginator(\"list_objects\")\n", " operation_parameters = {\"Bucket\": bucket, \"Prefix\": key_prefix}\n", " page_iterator = paginator.paginate(**operation_parameters)\n", " return [obj for page in page_iterator for obj in page[\"Contents\"]]\n", "\n", "\n", "def delete_s3_objects(bucket: str, keys: str) -> None:\n", " S3_CLIENT.delete_objects(Bucket=bucket, Delete={\"Objects\": [{\"Key\": key} for key in keys]})\n", "\n", "\n", "def get_local_model_cache_dir(hf_model_name: str) -> str:\n", " for dir_name in os.listdir(HF_LOCAL_CACHE_DIR):\n", " if dir_name.endswith(hf_model_name.replace(\"/\", \"--\")):\n", " break\n", " else:\n", " raise ValueError(f\"Could not find HF local cache directory for model {hf_model_name}\")\n", " return HF_LOCAL_CACHE_DIR / dir_name" ] }, { "cell_type": "markdown", "id": "59bbea82-c7b8-4ebd-bb9a-bc443c009f0e", "metadata": {}, "source": [ "### Inference utility functions\n", "Prompting the model requires marking the beginning and the end of the prompt with [special and model-specific tokens](https://huggingface.co/OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5#prompting). The following inference helper functions are used for all deployments." ] }, { "cell_type": "code", "execution_count": 7, "id": "538f2c6c-f90b-45fb-a4fb-8ea31f7914a3", "metadata": { "tags": [] }, "outputs": [], "source": [ "# Model-specific tokens\n", "PROMPT_BOS_TOKEN = \"<|prompter|>\"\n", "PROMPT_EOS_TOKEN = \"<|endoftext|><|assistant|>\"\n", "\n", "\n", "def decorate_prompt(prompt: str) -> str:\n", " return f\"{PROMPT_BOS_TOKEN}{prompt}{PROMPT_EOS_TOKEN}\"" ] }, { "cell_type": "markdown", "id": "66276d81-05b6-4c06-9570-1e1047a68930", "metadata": { "tags": [] }, "source": [ "## 1. Model upload to Amazon S3\n", "Models served by a LMI container can be downloaded to the container in different ways:\n", "* Like all the SageMaker Inference containers, having the container to download the model from Amazon S3 as a single `model.tar.gz` file. In the case of LLMs, this approach is discouraged since downloading and decompression times can become unreasonably high.\n", "* Having the container to download the model directly from the HuggingFace Hub for you. This option may involve high download times too.\n", "* Having the container to download the uncompressed model from Amazon S3 with maximal throughput by using the [`s5cmd`](https://github.com/peak/s5cmd) utility. This option is specific to LMI containers and is the recommended one. It requires however, that the model has been previously uploaded to a S3 Bucket. \n", "\n", "In this section, you will:\n", "1. Download the model from the HuggingFace Hub to your local host,\n", "2. Upload the downloaded model to a S3 Bucket. This notebook uses the SageMaker's default regional Bucket. Feel free to upload the model to the Bucket of your choice by modifying the `SM_ARTIFACT_BUCKET_NAME` global variable accordingly.\n", "\n", "Each operation takes a few minutes." 
] }, { "cell_type": "code", "execution_count": 8, "id": "d3b12738-ce43-4a02-bbf8-a3c4e591a762", "metadata": { "tags": [] }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "5643edc79f7740bb838f557d13986222", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Fetching 9 files: 0%| | 0/9 [00:00