{ "cells": [ { "cell_type": "markdown", "id": "1b2aebda", "metadata": {}, "source": [ "# BART Large model deployment on Amazon SageMaker Multi-model endpoints (MME) with GPU \n", "\n", "\n", "\n", "Amazon SageMaker multi-model endpoints (MME) provide a scalable and cost-effective way to deploy a large number of deep learning models. Previously, customers had limited options for deploying hundreds of deep learning models that need accelerated compute with GPUs. Now customers can deploy thousands of deep learning models behind one SageMaker endpoint. MME runs multiple models on a GPU, shares GPU instances behind an endpoint across multiple models, and dynamically loads and unloads models based on incoming traffic. This lets customers significantly reduce cost and achieve the best price performance.\n", "\n", "\n", "\n", "
💡 Note \n", "This notebook was tested with the `conda_python3` kernel on an Amazon SageMaker notebook instance of type `g5.xlarge`.\n", "
" ] }, { "cell_type": "markdown", "id": "96038ada", "metadata": {}, "source": [ "In this notebook, we walk through how to use NVIDIA Triton Inference Server with the GPU feature of Amazon SageMaker MME to deploy a **BART** NLP model for **Translation**. " ] }, { "cell_type": "markdown", "id": "07d31e8a", "metadata": {}, "source": [ "## Installs\n", "\n", "Install the dependencies required to package the model and run inference using Triton server, and update the SageMaker SDK, boto3, and awscli." ] }, { "cell_type": "code", "execution_count": 1, "id": "367c8d77", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Looking in indexes: https://pypi.org/simple, https://pip.repos.neuron.amazonaws.com, https://pypi.ngc.nvidia.com, https://download.pytorch.org/whl/cu116\n", "Requirement already satisfied: torch in /home/ec2-user/anaconda3/envs/amazonei_pytorch_latest_p37/lib/python3.7/site-packages (1.5.1)\n", "Requirement already satisfied: future in /home/ec2-user/anaconda3/envs/amazonei_pytorch_latest_p37/lib/python3.7/site-packages (from torch) (0.18.2)\n", "Requirement already satisfied: numpy in /home/ec2-user/anaconda3/envs/amazonei_pytorch_latest_p37/lib/python3.7/site-packages (from torch) (1.21.6)\n" ] } ], "source": [ "!pip install -qU pip awscli boto3 sagemaker\n", "!pip install nvidia-pyindex --quiet\n", "!pip install tritonclient[http] --quiet\n", "!pip install transformers[sentencepiece] --quiet\n", "!pip install torch --extra-index-url https://download.pytorch.org/whl/cu116" ] }, { "cell_type": "markdown", "id": "f29c9de0", "metadata": {}, "source": [ "## Imports and variables" ] }, { "cell_type": "code", "execution_count": 2, "id": "2386d882", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "301217895009.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tritonserver:22.12-py3\n" ] } ], "source": [ "import boto3, json, sagemaker, time\n", 
"from sagemaker import get_execution_role\n", "import numpy as np\n", "import os\n", "\n", "# Disable HuggingFace tokenizers parallelism to avoid fork-related warnings\n", "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", "\n", "# sagemaker variables\n", "role = get_execution_role()\n", "sm_client = boto3.client(service_name=\"sagemaker\")\n", "runtime_sm_client = boto3.client(\"sagemaker-runtime\")\n", "sagemaker_session = sagemaker.Session(boto_session=boto3.Session())\n", "s3_client = boto3.client(\"s3\")\n", "bucket = sagemaker_session.default_bucket()\n", "prefix = \"bart\"\n", "\n", "# account mapping for SageMaker MME Triton Image\n", "account_id_map = {\n", "    \"us-east-1\": \"785573368785\",\n", "    \"us-east-2\": \"007439368137\",\n", "    \"us-west-1\": \"710691900526\",\n", "    \"us-west-2\": \"301217895009\",\n", "    \"eu-west-1\": \"802834080501\",\n", "    \"eu-west-2\": \"205493899709\",\n", "    \"eu-west-3\": \"254080097072\",\n", "    \"eu-north-1\": \"601324751636\",\n", "    \"eu-south-1\": \"966458181534\",\n", "    \"eu-central-1\": \"746233611703\",\n", "    \"ap-east-1\": \"110948597952\",\n", "    \"ap-south-1\": \"763008648453\",\n", "    \"ap-northeast-1\": \"941853720454\",\n", "    \"ap-northeast-2\": \"151534178276\",\n", "    \"ap-southeast-1\": \"324986816169\",\n", "    \"ap-southeast-2\": \"355873309152\",\n", "    \"cn-northwest-1\": \"474822919863\",\n", "    \"cn-north-1\": \"472730292857\",\n", "    \"sa-east-1\": \"756306329178\",\n", "    \"ca-central-1\": \"464438896020\",\n", "    \"me-south-1\": \"836785723513\",\n", "    \"af-south-1\": \"774647643957\",\n", "}\n", "\n", "region = boto3.Session().region_name\n", "if region not in account_id_map:\n", "    # Raise a real exception; raising a bare string is a TypeError in Python 3\n", "    raise ValueError(f\"Unsupported region: {region}\")\n", "\n", "base = \"amazonaws.com.cn\" if region.startswith(\"cn-\") else \"amazonaws.com\"\n", "triton_image_uri = (\n", "    \"{account_id}.dkr.ecr.{region}.{base}/sagemaker-tritonserver:22.12-py3\".format(\n", "        account_id=account_id_map[region], region=region, base=base\n", "    )\n", ")\n", "print(triton_image_uri)" ] }, { "cell_type": "markdown", "id": 
"ce64b95c", "metadata": {}, "source": [ "## Workflow Overview\n", "\n", "This section presents an overview of the main steps for preparing a BART PyTorch model (served using the Python backend) for Triton Inference Server.\n", "### 1. Generate Model Artifacts\n", "\n" ] }, { "cell_type": "markdown", "id": "a0f444f2", "metadata": {}, "source": [ "#### BART PyTorch Model\n", "\n", "In the case of the BART HuggingFace PyTorch model, since we are serving it using Triton's [python backend](https://github.com/triton-inference-server/python_backend#usage), we have a Python script [model.py](./workspace/model.py) that implements all the logic to initialize the BART model and execute inference for the translation task." ] }, { "cell_type": "markdown", "id": "eb3dd938", "metadata": {}, "source": [ "### 2. Build Model Repository\n", "\n", "Using Triton on SageMaker requires us to first set up a [model repository](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_repository.md) folder containing the models we want to serve. For each model, we need to create a model directory containing the model artifact and define a `config.pbtxt` file that specifies the [model configuration](https://github.com/triton-inference-server/server/blob/main/docs/model_configuration.md) which Triton uses to load and serve the model. 
\n", "\n" ] }, { "cell_type": "markdown", "id": "a39c4592", "metadata": {}, "source": [ "#### BART Python Backend Model\n", "\n", "Model repository structure for the BART model:\n", "\n", "```\n", "bart_pytorch\n", "├── 1\n", "│   └── model.py\n", "└── config.pbtxt\n", "```\n" ] }, { "cell_type": "markdown", "id": "93792c33", "metadata": {}, "source": [ "Next we set up the BART PyTorch Python Backend Model in the model repository:" ] }, { "cell_type": "code", "execution_count": 13, "id": "36131eb3-a172-4e26-99fd-4b917ec0507d", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/home/ec2-user/SageMaker\n" ] } ], "source": [ "!pwd" ] }, { "cell_type": "code", "execution_count": 14, "id": "4b6d6272", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "mkdir: cannot create directory ‘model_repository/bart_pytorch’: Permission denied\n", "cp: cannot stat ‘workspace/model.py’: No such file or directory\n" ] } ], "source": [ "!mkdir -p model_repository/bart_pytorch/1\n", "!cp workspace/model.py model_repository/bart_pytorch/1/" ] }, { "cell_type": "markdown", "id": "a0a12da8", "metadata": {}, "source": [ "##### Create Conda Environment for Dependencies\n", "\n", "Serving the HuggingFace BART PyTorch model using Triton's Python backend requires PyTorch and HuggingFace transformers as dependencies.\n", "\n", "We follow the instructions from the [Triton documentation for packaging dependencies](https://github.com/triton-inference-server/python_backend#2-packaging-the-conda-environment) to be used in the Python backend as a conda environment tar file. Running the bash script [create_hf_env.sh](./workspace/create_hf_env.sh) creates the conda environment containing PyTorch and HuggingFace transformers and packages it as a tar file, which we then move into the `bart_pytorch` model directory. This can take a few minutes." 
] }, { "cell_type": "code", "execution_count": 13, "id": "e203209f", "metadata": { "scrolled": true, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Collecting package metadata (current_repodata.json): done\n", "Solving environment: done\n", 
"Successfully installed MarkupSafe-2.1.2 cmake-3.26.0 filelock-3.10.0 jinja2-3.1.2 lit-15.0.7 mpmath-1.3.0 networkx-3.0 nvidia-cublas-cu11-11.10.3.66 nvidia-cuda-cupti-cu11-11.7.101 nvidia-cuda-nvrtc-cu11-11.7.99 nvidia-cuda-runtime-cu11-11.7.99 nvidia-cudnn-cu11-8.5.0.96 nvidia-cufft-cu11-10.9.0.58 nvidia-curand-cu11-10.2.10.91 nvidia-cusolver-cu11-11.4.0.1 nvidia-cusparse-cu11-11.7.4.91 nvidia-nccl-cu11-2.14.3 nvidia-nvtx-cu11-11.7.91 sympy-1.11.1 torch-2.0.0 triton-2.0.0 typing-extensions-4.5.0\n", 
"Successfully installed certifi-2022.12.7 charset-normalizer-3.1.0 huggingface-hub-0.13.2 idna-3.4 numpy-1.24.2 packaging-23.0 protobuf-3.20.2 pyyaml-6.0 regex-2022.10.31 requests-2.28.2 sentencepiece-0.1.97 tokenizers-0.13.2 tqdm-4.65.0 transformers-4.27.1 urllib3-1.26.15\n", 
"Successfully installed conda-pack-0.6.0\n", "Collecting packages...\n", "Packing environment at '/home/ec2-user/anaconda3/envs/hf_env' to 'hf_env.tar.gz'\n", "[########################################] | 100% Completed | 2min 54.3s\n" ] } ], "source": [ "!bash workspace/create_hf_env.sh\n", "!mv hf_env.tar.gz model_repository/bart_pytorch/" ] }, { "cell_type": "markdown", "id": "3363b6dc", "metadata": {}, "source": [ "After creating the tar file from the conda environment and placing it in the model folder, you need to tell the Python backend to use that environment for your model. We do this by including the lines below in the model's `config.pbtxt` file:\n", "\n", "```\n", "parameters: {\n", "  key: \"EXECUTION_ENV_PATH\",\n", "  value: {string_value: \"$$TRITON_MODEL_DIRECTORY/hf_env.tar.gz\"}\n", "}\n", "```\n", "Here, `$$TRITON_MODEL_DIRECTORY` provides the environment path relative to the model folder in the model repository and is resolved to `$pwd/model_repository/bart_pytorch`. Finally, `hf_env.tar.gz` is the name we gave to our conda env file." 
] }, { "cell_type": "markdown", "id": "de6f3953", "metadata": {}, "source": [ "Now we are ready to define the config file for the BART PyTorch model served through Triton's Python backend:" ] }, { "cell_type": "code", "execution_count": 14, "id": "43dcaf95", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Overwriting model_repository/bart_pytorch/config.pbtxt\n" ] } ], "source": [ "%%writefile model_repository/bart_pytorch/config.pbtxt\n", "name: \"bart_pytorch\"\n", "backend: \"python\"\n", "max_batch_size: 8\n", "input [\n", "  {\n", "    name: \"input_ids\"\n", "    data_type: TYPE_INT32\n", "    dims: [ -1 ]\n", "  },\n", "  {\n", "    name: \"attention_mask\"\n", "    data_type: TYPE_INT32\n", "    dims: [ -1 ]\n", "  }\n", "]\n", "output [\n", "  {\n", "    name: \"output\"\n", "    data_type: TYPE_FP32\n", "    dims: [ -1, -1 ]\n", "  }\n", "]\n", "instance_group {\n", "  count: 1\n", "  kind: KIND_GPU\n", "}\n", "dynamic_batching {\n", "}\n", "parameters: {\n", "  key: \"EXECUTION_ENV_PATH\",\n", "  value: {string_value: \"$$TRITON_MODEL_DIRECTORY/hf_env.tar.gz\"}\n", "}" ] }, { "cell_type": "markdown", "id": "c3b7670a", "metadata": {}, "source": [ "### 3. Package models and upload to S3\n", "\n", "Next, we package our model as a `*.tar.gz` file for uploading to S3. " ] }, { "cell_type": "code", "execution_count": 86, "id": "2cb2e47e-13e4-4daa-bb7b-28b52f40e992", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/home/ec2-user/SageMaker\n" ] } ], "source": [ "! 
pwd" ] }, { "cell_type": "code", "execution_count": 100, "id": "5b6bb061", "metadata": { "tags": [] }, "outputs": [], "source": [ "!tar -C BART-Triton-PyTorch/model_repository/ -czf BART-Triton-PyTorch/bart_pytorch_7.tar.gz bart_pytorch\n", "model_uri_bart_pytorch = sagemaker_session.upload_data(path=\"BART-Triton-PyTorch/bart_pytorch_7.tar.gz\", key_prefix=prefix)" ] }, { "cell_type": "markdown", "id": "2b690fe2", "metadata": {}, "source": [ "### 4. Create SageMaker Endpoint\n", "\n", "Now that we have uploaded the model artifacts to S3, we can create a SageMaker endpoint." ] }, { "cell_type": "markdown", "id": "0075cbb6", "metadata": {}, "source": [ "#### Define the serving container\n", "In the container definition, set `ModelDataUrl` to the S3 prefix that contains all the models the SageMaker multi-model endpoint will load and serve predictions from. Set `Mode` to `MultiModel` so that SageMaker creates the endpoint with MME container specifications. We use a container image that supports deploying multi-model endpoints with GPU." ] }, { "cell_type": "code", "execution_count": 71, "id": "8681e13c", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "s3://sagemaker-us-west-2-757967535041/bart/bart_pytorch.tar.gz\n", "s3://sagemaker-us-west-2-757967535041/bart/\n", "{'Image': '301217895009.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tritonserver:22.12-py3', 'ModelDataUrl': 's3://sagemaker-us-west-2-757967535041/bart/', 'Mode': 'MultiModel'}\n" ] } ], "source": [ "print(model_uri_bart_pytorch)\n", "model_data_url = f\"s3://{bucket}/{prefix}/\"\n", "print(model_data_url)\n", "\n", "container = {\n", "    \"Image\": triton_image_uri,\n", "    \"ModelDataUrl\": model_data_url,\n", "    \"Mode\": \"MultiModel\",\n", "}\n", "print(container)" ] }, { "cell_type": "markdown", "id": "a6c879df", "metadata": {}, "source": [ "#### Create a multi-model object" ] }, { "cell_type": "markdown", "id": 
"3a396ecb", "metadata": {}, "source": [ "Once the image and data location are set, we create the model using `create_model`, specifying the `ModelName` and the container definition." ] }, { "cell_type": "code", "execution_count": 59, "id": "9a22f650", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "bart-2023-03-17-21-19-34\n" ] } ], "source": [ "ts = time.strftime(\"%Y-%m-%d-%H-%M-%S\", time.gmtime())\n", "sm_model_name = f\"{prefix}-{ts}\"\n", "print(sm_model_name)" ] }, { "cell_type": "code", "execution_count": 60, "id": "dfa67ff7-42cb-4c78-b8ec-0c28d790edef", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model Arn: arn:aws:sagemaker:us-west-2:757967535041:model/bart-2023-03-17-21-19-34\n" ] } ], "source": [ "create_model_response = sm_client.create_model(\n", "    ModelName=sm_model_name, ExecutionRoleArn=role, PrimaryContainer=container\n", ")\n", "\n", "print(\"Model Arn: \" + create_model_response[\"ModelArn\"])" ] }, { "cell_type": "markdown", "id": "a4e335d8", "metadata": {}, "source": [ "#### Define configuration for the multi-model endpoint" ] }, { "cell_type": "markdown", "id": "30d79536", "metadata": {}, "source": [ "Using the model above, we create an [endpoint configuration](https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateEndpointConfig.html) where we can specify the type and number of instances we want in the endpoint. Here we are deploying to an `ml.g5.2xlarge` NVIDIA GPU instance." 
] }, { "cell_type": "code", "execution_count": 61, "id": "a9d4b510", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Endpoint Config Arn: arn:aws:sagemaker:us-west-2:757967535041:endpoint-config/bart-epc-2023-03-17-21-19-34\n" ] } ], "source": [ "endpoint_config_name = f\"{prefix}-epc-{ts}\"\n", "\n", "create_endpoint_config_response = sm_client.create_endpoint_config(\n", " EndpointConfigName=endpoint_config_name,\n", " ProductionVariants=[\n", " {\n", " \"InstanceType\": \"ml.g5.2xlarge\",\n", " \"InitialVariantWeight\": 1,\n", " \"InitialInstanceCount\": 1,\n", " \"ModelName\": sm_model_name,\n", " \"VariantName\": \"AllTraffic\",\n", " }\n", " ],\n", ")\n", "\n", "print(\"Endpoint Config Arn: \" + create_endpoint_config_response[\"EndpointConfigArn\"])\n" ] }, { "cell_type": "markdown", "id": "3a23980c", "metadata": {}, "source": [ "#### Create SageMaker Endpoint" ] }, { "cell_type": "markdown", "id": "52f050c5", "metadata": {}, "source": [ "Using the endpoint configuration above, we create a new SageMaker endpoint and wait for the deployment to finish. The status changes to **InService** once the deployment succeeds." 
] }, { "cell_type": "code", "execution_count": 62, "id": "9b1b19a5", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Endpoint Arn: arn:aws:sagemaker:us-west-2:757967535041:endpoint/bart-ep-2023-03-17-21-19-34\n", "endpointname: bart-ep-2023-03-17-21-19-34\n" ] } ], "source": [ "endpoint_name = f\"{prefix}-ep-{ts}\"\n", "\n", "create_endpoint_response = sm_client.create_endpoint(\n", " EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name\n", ")\n", "\n", "print(\"Endpoint Arn: \" + create_endpoint_response[\"EndpointArn\"])\n", "print(\"endpointname: \" + endpoint_name)" ] }, { "cell_type": "code", "execution_count": 63, "id": "e3bba0d5", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Status: Creating\n", "Status: Creating\n", "Status: Creating\n", "Status: Creating\n", "Status: Creating\n", "Status: Creating\n", "Status: InService\n", "Arn: arn:aws:sagemaker:us-west-2:757967535041:endpoint/bart-ep-2023-03-17-21-19-34\n", "Status: InService\n" ] } ], "source": [ "resp = sm_client.describe_endpoint(EndpointName=endpoint_name)\n", "status = resp[\"EndpointStatus\"]\n", "print(\"Status: \" + status)\n", "\n", "while status == \"Creating\":\n", " time.sleep(60)\n", " resp = sm_client.describe_endpoint(EndpointName=endpoint_name)\n", " status = resp[\"EndpointStatus\"]\n", " print(\"Status: \" + status)\n", "\n", "print(\"Arn: \" + resp[\"EndpointArn\"])\n", "print(\"Status: \" + status)" ] }, { "cell_type": "markdown", "id": "b21403ea", "metadata": {}, "source": [ "### 5. Run Inference" ] }, { "cell_type": "markdown", "id": "6be758ed", "metadata": {}, "source": [ "Once the endpoint is running, we can send sample raw data for inference, using JSON as the payload format. 
For the inference request format, Triton uses the KFServing community standard [inference protocols](https://github.com/triton-inference-server/server/blob/main/docs/protocol/README.md)." ] }, { "cell_type": "markdown", "id": "a408192b", "metadata": {}, "source": [ "#### Add utility methods for preparing the JSON request payload\n", "\n" ] }, { "cell_type": "markdown", "id": "59d65e6f", "metadata": {}, "source": [ "We'll use the following utility methods to convert our inference request for BART models into a JSON payload." ] }, { "cell_type": "code", "execution_count": 6, "id": "de500c8b-cc11-4993-a4bb-23cd2178caf7", "metadata": { "tags": [] }, "outputs": [], "source": [ "# Helper functions\n", "import tritonclient.http as httpclient\n", "from transformers import BartTokenizer, BartModel\n", "from tritonclient.utils import *\n", "\n", "def get_tokenizer(model_name):\n", " tokenizer = BartTokenizer.from_pretrained(model_name)\n", " return tokenizer\n", "\n", "# def tokenize_text(model_name, text):\n", "# tokenizer = get_tokenizer(model_name)\n", "# tokenized_text = tokenizer(text, padding=True, return_tensors=\"pt\")\n", "# #tokenized_text = tokenizer(text)\n", "# return tokenized_text\n", "\n", "\n", "def tokenize_text(model_name, text):\n", " tokenizer = get_tokenizer(model_name)\n", " tokenized_text = tokenizer(text, padding=True, return_tensors=\"np\")\n", " return tokenized_text.input_ids, tokenized_text.attention_mask\n", "\n", "#V1\n", "# def get_text_payload(model_name, text):\n", " \n", "# inputs = []\n", "# outputs = []\n", " \n", "# inputs = tokenize_text(model_name, text)\n", " \n", "# text_obj = np.array(inputs[\"input_ids\"],dtype=np.int32).reshape(1,-1)\n", "# print(text_obj.shape)\n", "# input_text = httpclient.InferInput(\"input_ids\", text_obj.shape, np_to_triton_dtype(text_obj.dtype))\n", "# input_text.set_data_from_numpy(text_obj)\n", "# print(input_text)\n", "\n", "# attention_mask_obj = np.array(inputs[\"attention_mask\"], 
dtype=np.int32).reshape(1,-1)\n", "# print(attention_mask_obj.shape)\n", "# attention_mask = httpclient.InferInput(\"attention_mask\", attention_mask_obj.shape, np_to_triton_dtype(attention_mask_obj.dtype))\n", "# attention_mask.set_data_from_numpy(attention_mask_obj)\n", "# print(attention_mask)\n", " \n", "# inputs=[input_text, attention_mask]\n", "# return inputs\n", "\n", "#v2 \n", "\n", "def get_text_payload(model_name, text):\n", " input_ids, attention_mask = tokenize_text(model_name, text)\n", " payload = {}\n", " payload[\"inputs\"] = []\n", " payload[\"inputs\"].append({\"name\": \"input_ids\", \"shape\": input_ids.shape, \"datatype\": \"INT32\", \"data\": input_ids.tolist()})\n", " payload[\"inputs\"].append({\"name\": \"attention_mask\", \"shape\": attention_mask.shape, \"datatype\": \"INT32\", \"data\": attention_mask.tolist()})\n", " \n", " return payload\n", "\n", "# text_input = \"Hello, my dog is cute\"\n", "# bart_payload = get_text_payload('facebook/bart-large', text_input)\n", "\n", "# print(\"bart_payload is\", bart_payload)\n", "# print(\" payload type is\", type(bart_payload))\n", " " ] }, { "cell_type": "code", "execution_count": 4, "id": "edbcd3b1-9e22-4335-a741-c927b5dcc5d2", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "bart-ep-2023-03-17-21-19-34\n", "1\n" ] } ], "source": [ "endpoint_name = \"bart-ep-2023-03-17-21-19-34\"\n", "print(endpoint_name)\n", "\n", "sm_client.describe_endpoint(EndpointName=endpoint_name)\n", "\n", "texts = [\"Hello, my dog is cute\"]\n", "batch_size = len(texts)\n", "print(batch_size)" ] }, { "cell_type": "code", "execution_count": 7, "id": "53c03d21-7056-41ea-9f86-a00d055190f1", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'inputs': [{'name': 'input_ids', 'shape': (1, 8), 'datatype': 'INT32', 'data': [[0, 31414, 6, 127, 2335, 16, 11962, 2]]}, {'name': 'attention_mask', 'shape': (1, 8), 'datatype': 'INT32', 
'data': [[1, 1, 1, 1, 1, 1, 1, 1]]}]}\n" ] } ], "source": [ "bart_payload = get_text_payload('facebook/bart-large', texts)\n", "print(bart_payload)" ] }, { "cell_type": "code", "execution_count": 14, "id": "b1beedea-d6d3-4fc2-b0e4-809421e74d5d", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 1 µs, sys: 0 ns, total: 1 µs\n", "Wall time: 3.81 µs\n", "{'ResponseMetadata': {'RequestId': '690c8bd4-5ccf-461e-a85b-b051de3845ea', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amzn-requestid': '690c8bd4-5ccf-461e-a85b-b051de3845ea', 'x-amzn-invoked-production-variant': 'AllTraffic', 'date': 'Fri, 17 Mar 2023 23:00:06 GMT', 'content-type': 'application/json', 'content-length': '162666'}, 'RetryAttempts': 0}, 'ContentType': 'application/json', 'InvokedProductionVariant': 'AllTraffic', 'Body': }\n" ] } ], "source": [ "%%time\n", "\n", "response = runtime_sm_client.invoke_endpoint(\n", " EndpointName=endpoint_name,\n", " ContentType=\"application/octet-stream\",\n", " Body=json.dumps(bart_payload),\n", " TargetModel=\"bart_pytorch_7.tar.gz\",\n", ")\n", "print(response)" ] }, { "cell_type": "code", "execution_count": 9, "id": "a9d358c0-4ad9-469d-a13b-c1ecc0bd92af", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "response_body = json.loads(response[\"Body\"].read().decode(\"utf8\"))\n", "print(type(response_body))" ] }, { "cell_type": "code", "execution_count": 10, "id": "c20cb5bc-415f-4e17-b9b9-e3dc81f06af6", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "{'model_name': 'ccdfac5647468461e480c1c55bba575f',\n", " 'model_version': '1',\n", " 'outputs': [{'name': 'output',\n", " 'datatype': 'FP32',\n", " 'shape': [1, 8, 1024],\n", " 'data': [0.5512231588363647,\n", " 0.838931143283844,\n", " -1.4706687927246094,\n", " 0.36186015605926514,\n", " -0.16138087213039398,\n", " -0.7445544600486755,\n", " -0.5286368131637573,\n", 
-0.9802274107933044,\n", " -1.5751498937606812,\n", " -0.1964043527841568,\n", " -0.6759628057479858,\n", " -1.1872944831848145,\n", " -0.9507772922515869,\n", " -0.4348897933959961,\n", " 0.2833757698535919,\n", " 0.3985210955142975,\n", " 0.5549306869506836,\n", " 0.05500718951225281,\n", " 0.1372915804386139,\n", " 0.20988517999649048,\n", " 4.532204627990723,\n", " -0.1279691457748413,\n", " 1.0849722623825073,\n", " -0.8920677304267883,\n", " -0.7142594456672668,\n", " 0.24808241426944733,\n", " -1.0593713521957397,\n", " -0.30108046531677246,\n", " 0.920939564704895,\n", " -0.11048457026481628,\n", " 0.3940015137195587,\n", " -0.6587729454040527,\n", " 0.9290228486061096,\n", " -0.04224908724427223,\n", " 0.42511940002441406,\n", " 0.22687682509422302,\n", " 0.2993665337562561,\n", " -1.4071598052978516,\n", " -1.0770210027694702,\n", " 1.0872652530670166,\n", " 0.5092843770980835,\n", " -0.6179194450378418,\n", " 0.5619251132011414,\n", " 0.569453775882721,\n", " -0.06784189492464066,\n", " 0.47741615772247314,\n", " -0.3446944057941437,\n", " 0.1075652539730072,\n", " -0.9379867315292358,\n", " 0.36186182498931885,\n", " 0.6265888214111328,\n", " 1.0445619821548462,\n", " -0.4787224233150482,\n", " 1.868227243423462,\n", " 1.0418145656585693,\n", " 0.6279901266098022,\n", " -0.9217061996459961,\n", " 1.2448848485946655,\n", " -0.714983344078064,\n", " 0.37466636300086975,\n", " -1.2213468551635742,\n", " -0.16280172765254974,\n", " 0.11499473452568054,\n", " 0.546456515789032,\n", " -0.10707467049360275,\n", " 0.07572869211435318,\n", " 0.42962026596069336,\n", " 0.10682890564203262,\n", " -1.208717942237854,\n", " 0.06795813143253326,\n", " -0.6115782260894775,\n", " 0.029875831678509712,\n", " -0.918994665145874,\n", " 0.33691564202308655,\n", " -0.6102118492126465,\n", " -1.5839966535568237,\n", " 0.8207632899284363,\n", " 1.1736699342727661,\n", " 0.5362817049026489,\n", " -0.05371475964784622,\n", " -0.6143125891685486,\n", " 0.49336665868759155,\n", " 
-0.6164402961730957,\n", " -0.5056082010269165,\n", " 0.7168957591056824,\n", " -0.8497525453567505,\n", " -0.2922948896884918,\n", " -0.8880093097686768,\n", " -0.002786717377603054,\n", " -0.05842962488532066,\n", " -1.2668131589889526,\n", " 0.17325131595134735,\n", " 1.3435301780700684,\n", " -0.34510329365730286,\n", " -0.5093274712562561,\n", " -0.9258503317832947,\n", " -0.45970603823661804,\n", " -0.6208155751228333,\n", " 1.3905047178268433,\n", " -0.14656469225883484,\n", " 0.20581774413585663,\n", " -0.05151316896080971,\n", " -0.16479533910751343,\n", " -0.5165272951126099,\n", " 0.46869802474975586,\n", " -0.884571373462677,\n", " -1.0196866989135742,\n", " -0.6487631797790527,\n", " -1.176162838935852,\n", " -0.6676568984985352,\n", " -0.8026613593101501,\n", " -0.22213353216648102,\n", " -0.9725930094718933,\n", " -1.33159601688385,\n", " 0.8548040390014648,\n", " -0.417388379573822,\n", " 0.016033776104450226,\n", " 0.12094976752996445,\n", " -0.9251769781112671,\n", " 0.15751433372497559,\n", " -0.7519498467445374,\n", " 0.09862756729125977,\n", " -0.4342701733112335,\n", " -1.1131422519683838,\n", " -0.6777624487876892,\n", " 0.4365592300891876,\n", " 0.2902430593967438,\n", " 0.8829609751701355,\n", " -1.5955195426940918,\n", " 0.09910336136817932,\n", " -0.18636171519756317,\n", " 0.8200868964195251,\n", " -4.091695308685303,\n", " 0.15947668254375458,\n", " 0.5788902640342712,\n", " -0.8765649795532227,\n", " -1.0144442319869995,\n", " -0.8207207322120667,\n", " 0.8372873067855835,\n", " -0.11059992760419846,\n", " -0.8043730854988098,\n", " 1.2638401985168457,\n", " -0.5774246454238892,\n", " -0.7697798013687134,\n", " 0.9235530495643616,\n", " 0.8054290413856506,\n", " 0.7480983734130859,\n", " -1.0924921035766602,\n", " -0.28845539689064026,\n", " 0.052280474454164505,\n", " 1.410509705543518,\n", " 1.0244117975234985,\n", " -0.2429094910621643,\n", " -0.42933133244514465,\n", " -1.8364536762237549,\n", " 0.004646077286452055,\n", " 
-0.1919497847557068,\n", " 0.8231711983680725,\n", " 0.09167972207069397,\n", " 1.2056597471237183,\n", " -0.7042230367660522,\n", " 0.19678810238838196,\n", " -1.0294041633605957,\n", " -0.3922756612300873,\n", " 1.3482087850570679,\n", " 0.04741837456822395,\n", " 0.5546521544456482,\n", " -0.0914994478225708,\n", " -0.7455064654350281,\n", " 0.03886833041906357,\n", " -0.2963472008705139,\n", " -1.0420867204666138,\n", " 0.05049479380249977,\n", " -0.07658731937408447,\n", " -0.17923863232135773,\n", " 0.4100443720817566,\n", " 0.677544891834259,\n", " 0.8397355079650879,\n", " 0.44486624002456665,\n", " 0.5043777823448181,\n", " -0.026700112968683243,\n", " -0.1588955521583557,\n", " -0.7928339242935181,\n", " -0.4780857563018799,\n", " 0.22858747839927673,\n", " -0.9639642238616943,\n", " 0.9826228022575378,\n", " 0.03179466351866722,\n", " 0.09886857867240906,\n", " -0.9862097501754761,\n", " 1.842560887336731,\n", " 0.45596835017204285,\n", " 1.3179372549057007,\n", " -1.7004516124725342,\n", " 0.6658682823181152,\n", " -0.17569716274738312,\n", " 0.05157418176531792,\n", " -0.3303687870502472,\n", " -0.2748239040374756,\n", " 0.029320525005459785,\n", " -0.17050664126873016,\n", " -0.8119512796401978,\n", " 0.1673360913991928,\n", " 0.09376450628042221,\n", " 0.2548041045665741,\n", " -0.025199836120009422,\n", " 1.3933025598526,\n", " -0.6395426988601685,\n", " -0.1061791405081749,\n", " -0.8862525820732117,\n", " 0.21441549062728882,\n", " -0.508711576461792,\n", " -0.19608967006206512,\n", " -0.9350410103797913,\n", " -0.25857239961624146,\n", " 0.20775391161441803,\n", " -0.6442614793777466,\n", " 1.1550129652023315,\n", " 0.061559878289699554,\n", " 1.5953881740570068,\n", " 0.71770179271698,\n", " 0.08726084232330322,\n", " -0.12460458278656006,\n", " -0.24504892528057098,\n", " -0.35453924536705017,\n", " -0.1654551774263382,\n", " -0.5172156095504761,\n", " -0.3406892716884613,\n", " -0.48230382800102234,\n", " 0.2531716823577881,\n", " 
0.8236452341079712,\n", " 0.1264435350894928,\n", " 6.972285747528076,\n", " -0.4887571334838867,\n", " -0.3724238872528076,\n", " -0.15926115214824677,\n", " -0.8014334440231323,\n", " -0.7104360461235046,\n", " 0.69819575548172,\n", " -0.3236633539199829,\n", " -0.5053926706314087,\n", " -0.34379109740257263,\n", " -0.0028279402758926153,\n", " -0.5872572660446167,\n", " -0.3620743453502655,\n", " -1.0122774839401245,\n", " -0.5986558794975281,\n", " -0.5688014030456543,\n", " -0.6173795461654663,\n", " 0.5369263291358948,\n", " -0.6389930844306946,\n", " 0.3107251524925232,\n", " 0.33545544743537903,\n", " 0.41631919145584106,\n", " 0.039079658687114716,\n", " -0.3096712827682495,\n", " 0.4673421084880829,\n", " -0.3853227496147156,\n", " 0.6171102523803711,\n", " 0.5906476974487305,\n", " -0.7101634740829468,\n", " 0.9718310832977295,\n", " -0.036854252219200134,\n", " -0.47095969319343567,\n", " -0.07165536284446716,\n", " 0.1401498019695282,\n", " 0.5893962979316711,\n", " -0.8627554774284363,\n", " -0.35556304454803467,\n", " -0.1149238869547844,\n", " -0.44805067777633667,\n", " -1.0083986520767212,\n", " -1.478947401046753,\n", " 0.5835590362548828,\n", " -0.5158892869949341,\n", " -0.2609165906906128,\n", " -0.018535232171416283,\n", " -0.16808541119098663,\n", " -0.45145538449287415,\n", " 0.25417906045913696,\n", " 0.7631840109825134,\n", " 0.42941388487815857,\n", " 0.019418425858020782,\n", " -0.6942884922027588,\n", " -0.37361860275268555,\n", " 0.5396735668182373,\n", " -0.7488811612129211,\n", " -0.25243079662323,\n", " 0.18885239958763123,\n", " 0.5567657351493835,\n", " 0.210652157664299,\n", " -0.19232046604156494,\n", " -0.32969900965690613,\n", " -0.9545758366584778,\n", " 0.8496737480163574,\n", " 0.4456127882003784,\n", " -0.2512013018131256,\n", " 0.4754548966884613,\n", " -0.10384448617696762,\n", " -0.18178121745586395,\n", " 1.2048557996749878,\n", " -0.14131280779838562,\n", " -0.389619916677475,\n", " -0.4067803919315338,\n", " 
-0.2718769907951355,\n", " 0.900600016117096,\n", " -0.23106031119823456,\n", " 0.6260459423065186,\n", " -0.6941325664520264,\n", " -0.1837538629770279,\n", " -0.8757758140563965,\n", " -0.7358526587486267,\n", " 0.020555894821882248,\n", " -0.16650523245334625,\n", " 0.21734192967414856,\n", " -0.3569191098213196,\n", " 1.0339720249176025,\n", " 0.3702871799468994,\n", " 0.6024506688117981,\n", " -0.18002048134803772,\n", " -0.12898455560207367,\n", " -0.3414952754974365,\n", " 0.15872502326965332,\n", " -0.3732900619506836,\n", " 0.7711796164512634,\n", " -0.6028750538825989,\n", " 0.6459299921989441,\n", " 0.8036919832229614,\n", " -0.26192593574523926,\n", " 0.0019996652845293283,\n", " 0.2527709901332855,\n", " 0.014281121082603931,\n", " -0.4867848753929138,\n", " -0.2653316259384155,\n", " -0.4558074176311493,\n", " -0.37890052795410156,\n", " -0.6719330549240112,\n", " -1.9447945356369019,\n", " -0.19239665567874908,\n", " -0.7747593522071838,\n", " -0.7002098560333252,\n", " -0.4992590546607971,\n", " 0.07531926780939102,\n", " -1.9907246828079224,\n", " -0.14587625861167908,\n", " -0.014557763002812862,\n", " 0.025693826377391815,\n", " -0.5343532562255859,\n", " -0.993387758731842,\n", " 0.19623781740665436,\n", " -1.228711724281311,\n", " -0.7568314671516418,\n", " 0.22933663427829742,\n", " 0.19189424812793732,\n", " 0.2715136706829071,\n", " 0.48874419927597046,\n", " 0.19881778955459595,\n", " 0.03700733184814453,\n", " -1.0321294069290161,\n", " -0.4323597848415375,\n", " -0.026666544377803802,\n", " 0.11166931688785553,\n", " -0.1199464499950409,\n", " 0.48179200291633606,\n", " 0.035837117582559586,\n", " -0.1323581337928772,\n", " 0.9668102264404297,\n", " -0.23024417459964752,\n", " -0.18630985915660858,\n", " -1.4463317394256592,\n", " 0.35784995555877686,\n", " 0.46143391728401184,\n", " 0.43917006254196167,\n", " -0.31483787298202515,\n", " -0.19685734808444977,\n", " 0.48835471272468567,\n", " -0.7141571044921875,\n", " 
-0.3362944424152374,\n", " -0.08587032556533813,\n", " -0.18535959720611572,\n", " -8.615423202514648,\n", " -0.17529499530792236,\n", " -0.9045871496200562,\n", " -0.1934724599123001,\n", " 0.36809250712394714,\n", " -0.396505206823349,\n", " 0.8016916513442993,\n", " 0.06793776899576187,\n", " -0.6020064949989319,\n", " 0.6324535608291626,\n", " 1.283007264137268,\n", " -0.1582910120487213,\n", " 0.0036173320841044188,\n", " 0.4031180739402771,\n", " -0.1651155799627304,\n", " 0.3803315758705139,\n", " 0.2089451551437378,\n", " -0.7576873302459717,\n", " 0.042371977120637894,\n", " -0.10666054487228394,\n", " -0.5437188744544983,\n", " -0.08425711840391159,\n", " -0.21393167972564697,\n", " 0.11725594103336334,\n", " 0.11253275722265244,\n", " 0.5823022723197937,\n", " -1.0295616388320923,\n", " -1.3030933141708374,\n", " 0.6189414858818054,\n", " 0.16979698836803436,\n", " -0.1496264785528183,\n", " 0.3635156452655792,\n", " 0.3524884283542633,\n", " 0.32144322991371155,\n", " -0.041253820061683655,\n", " -0.900516927242279,\n", " 0.20737408101558685,\n", " 1.7921885251998901,\n", " 0.5112811326980591,\n", " -0.38624605536460876,\n", " -0.5306797027587891,\n", " -0.3516419529914856,\n", " 0.04458726942539215,\n", " -0.7270622849464417,\n", " 1.2166049480438232,\n", " 0.9357081055641174,\n", " -0.3513287901878357,\n", " -0.8089878559112549,\n", " 0.28889426589012146,\n", " -0.6153973340988159,\n", " -0.6835390329360962,\n", " 0.8324691653251648,\n", " 0.09520367532968521,\n", " 0.6589880585670471,\n", " -0.6048147082328796,\n", " -0.06995150446891785,\n", " -0.2670648992061615,\n", " 0.1258706897497177,\n", " -0.19142137467861176,\n", " 0.46072497963905334,\n", " -0.16879430413246155,\n", " 0.08958382904529572,\n", " 0.16187810897827148,\n", " 0.2220761477947235,\n", " 0.8587404489517212,\n", " 0.36324241757392883,\n", " 1.4122816324234009,\n", " 0.573714017868042,\n", " 0.4746886193752289,\n", " -0.9183347821235657,\n", " -0.8210547566413879,\n", " 
0.1305522471666336,\n", " 0.6481757760047913,\n", " -0.060427192598581314,\n", " 1.0081449747085571,\n", " -0.14773046970367432,\n", " -0.031227707862854004,\n", " 0.026498062536120415,\n", " 0.1277473270893097,\n", " 0.5265172123908997,\n", " -0.7398378849029541,\n", " 0.17436881363391876,\n", " -0.4923505485057831,\n", " -1.17972993850708,\n", " 0.021447140723466873,\n", " 0.7504041790962219,\n", " -0.6090976595878601,\n", " 0.06697836518287659,\n", " 1.2524203062057495,\n", " -0.8106833696365356,\n", " -0.5760090947151184,\n", " -0.7946833968162537,\n", " 0.48524028062820435,\n", " -1.4787696599960327,\n", " -0.7142453193664551,\n", " -0.024569617584347725,\n", " 0.8599358797073364,\n", " -0.5616448521614075,\n", " -0.06197698786854744,\n", " 0.17722056806087494,\n", " -0.016535663977265358,\n", " -0.47884270548820496,\n", " -1.480549931526184,\n", " -1.1657813787460327,\n", " -0.45706549286842346,\n", " 0.6471889615058899,\n", " 0.38626059889793396,\n", " 0.3486173450946808,\n", " 0.9241748452186584,\n", " 0.8687440752983093,\n", " 0.9630197286605835,\n", " -0.7648880481719971,\n", " -0.27462220191955566,\n", " 0.012058877386152744,\n", " -0.58051598072052,\n", " 0.3752920925617218,\n", " 0.30081912875175476,\n", " -0.13619989156723022,\n", " -0.6178871989250183,\n", " -0.49357545375823975,\n", " 0.6037430167198181,\n", " -0.3269115388393402,\n", " -0.6070903539657593,\n", " 0.10798647254705429,\n", " 0.8116581439971924,\n", " 1.574215292930603,\n", " -0.42597460746765137,\n", " 0.004487248603254557,\n", " -0.6902196407318115,\n", " 2.1378681659698486,\n", " -0.12691238522529602,\n", " -0.9636858701705933,\n", " -0.2983678877353668,\n", " 0.1772928535938263,\n", " 0.0756492093205452,\n", " 0.26133373379707336,\n", " -0.4140305817127228,\n", " -0.22501759231090546,\n", " -0.0977892130613327,\n", " 0.11535117030143738,\n", " -0.2738235294818878,\n", " 0.7809329032897949,\n", " -1.091817021369934,\n", " -0.08913439512252808,\n", " -0.9400854706764221,\n", " 
-0.17109379172325134,\n", " 0.11545951664447784,\n", " 0.39872676134109497,\n", " -0.1231875866651535,\n", " -0.3932957053184509,\n", " -0.9756698608398438,\n", " -0.5209663510322571,\n", " 0.032111573964357376,\n", " 0.5972258448600769,\n", " -0.8796707391738892,\n", " -0.06512387841939926,\n", " -0.249892458319664,\n", " -1.0364198684692383,\n", " -0.14316511154174805,\n", " -0.14378884434700012,\n", " 0.7785264253616333,\n", " -1.371031403541565,\n", " 0.3373274803161621,\n", " -0.8353373408317566,\n", " -0.03164541348814964,\n", " -0.6169877052307129,\n", " -0.3890879452228546,\n", " 0.8953436017036438,\n", " -0.21534018218517303,\n", " 0.1442190408706665,\n", " 0.11236201971769333,\n", " 0.5319151878356934,\n", " -0.2954384684562683,\n", " 0.23129108548164368,\n", " -0.5865090489387512,\n", " 0.7397150993347168,\n", " -0.20796160399913788,\n", " 0.1412460058927536,\n", " -0.3924434185028076,\n", " -0.22321505844593048,\n", " 0.6155570149421692,\n", " -0.40110334753990173,\n", " -0.10925525426864624,\n", " -0.32230740785598755,\n", " 0.5831630229949951,\n", " -0.11093272268772125,\n", " 0.2304338961839676,\n", " -0.9112204313278198,\n", " 0.17969875037670135,\n", " -1.0478620529174805,\n", " 0.5550486445426941,\n", " -0.7261440753936768,\n", " -1.0309748649597168,\n", " -0.1891491413116455,\n", " 0.1380051225423813,\n", " -0.338050901889801,\n", " 0.4300931394100189,\n", " -0.05921455845236778,\n", " -0.06361677497625351,\n", " 0.10510098189115524,\n", " 0.015387043356895447,\n", " -0.6834678053855896,\n", " -1.8665062189102173,\n", " 0.37427833676338196,\n", " -0.3336694538593292,\n", " -1.1239211559295654,\n", " -0.25503188371658325,\n", " 0.08487388491630554,\n", " 0.3096638023853302,\n", " -0.19361235201358795,\n", " 0.4022916853427887,\n", " 0.25155678391456604,\n", " 0.040782198309898376,\n", " -0.6731263399124146,\n", " -0.5131261944770813,\n", " -3.096902847290039,\n", " 0.06780391931533813,\n", " 1.0685887336730957,\n", " -0.17457181215286255,\n", " 
-0.9600634574890137,\n", " -1.480100393295288,\n", " -0.0808100551366806,\n", " -0.15465661883354187,\n", " 0.30859822034835815,\n", " -0.5793341398239136,\n", " -0.29521510004997253,\n", " -0.5119758248329163,\n", " 0.25260502099990845,\n", " 14.806681632995605,\n", " 0.6090054512023926,\n", " -0.46446526050567627,\n", " 0.2754199504852295,\n", " -0.7489909529685974,\n", " 0.3026597797870636,\n", " -0.5859566926956177,\n", " 0.6701865792274475,\n", " 2.08671498298645,\n", " 0.7813906073570251,\n", " -0.6511877179145813,\n", " 0.24493516981601715,\n", " -0.7520231604576111,\n", " -0.5965690016746521,\n", " 0.49424126744270325,\n", " 1.0625125169754028,\n", " -0.45748764276504517,\n", " 0.00562717579305172,\n", " 0.3482455611228943,\n", " 0.11113521456718445,\n", " -0.9454023838043213,\n", " -1.0063183307647705,\n", " -0.26960334181785583,\n", " -0.3617252707481384,\n", " 0.4135421812534332,\n", " 0.0784333199262619,\n", " 0.29563814401626587,\n", " 1.024407148361206,\n", " -0.6876221299171448,\n", " -0.06981656700372696,\n", " 0.7138344645500183,\n", " 0.7423627972602844,\n", " 0.2800251841545105,\n", " -0.5719850063323975,\n", " 0.27031826972961426,\n", " -0.2722645103931427,\n", " 0.26073387265205383,\n", " -1.0050592422485352,\n", " -0.3944884240627289,\n", " -0.8007617592811584,\n", " -0.20162196457386017,\n", " -0.40813207626342773,\n", " -0.4994469881057739,\n", " -0.6247358322143555,\n", " 0.14737221598625183,\n", " 0.5043387413024902,\n", " -0.699865996837616,\n", " -0.02449065074324608,\n", " 1.5786608457565308,\n", " 0.43364307284355164,\n", " 1.1439844369888306,\n", " -0.5856248736381531,\n", " -0.7980527877807617,\n", " 0.5757154226303101,\n", " 0.036841511726379395,\n", " 0.333440899848938,\n", " 1.0385944843292236,\n", " -0.6617634892463684,\n", " 0.3963220417499542,\n", " 0.7422211766242981,\n", " -0.38160935044288635,\n", " -0.7841805815696716,\n", " 1.1549186706542969,\n", " 0.0601218082010746,\n", " 0.5549145340919495,\n", " 
0.17364224791526794,\n", " -0.5826998949050903,\n", " -0.9026727676391602,\n", " -0.20420794188976288,\n", " -0.05808091163635254,\n", " 0.247501403093338,\n", " -0.16053634881973267,\n", " -0.6229061484336853,\n", " -0.37714883685112,\n", " -0.9221987128257751,\n", " 0.1428336799144745,\n", " -0.7975381016731262,\n", " -0.7336996793746948,\n", " -0.7335259914398193,\n", " -0.4215308427810669,\n", " 0.8862835168838501,\n", " 0.17378829419612885,\n", " -0.4076822102069855,\n", " 0.45626503229141235,\n", " -1.2035794258117676,\n", " -0.6900883316993713,\n", " -0.7599830031394958,\n", " 0.4042466878890991,\n", " -0.03392999246716499,\n", " 0.06922212243080139,\n", " -0.22810307145118713,\n", " -0.16998258233070374,\n", " 1.074487328529358,\n", " -0.1837942749261856,\n", " -0.013951834291219711,\n", " 0.1096404641866684,\n", " 0.3501183092594147,\n", " -0.8792479634284973,\n", " -2.0086472034454346,\n", " -0.20247526466846466,\n", " 0.6755514740943909,\n", " 0.2014104425907135,\n", " 0.10746446996927261,\n", " 0.48227766156196594,\n", " 0.729827344417572,\n", " 0.038770582526922226,\n", " -0.26406076550483704,\n", " -0.8906861543655396,\n", " 0.8175033330917358,\n", " -0.100845567882061,\n", " -0.8431618213653564,\n", " -0.9389227032661438,\n", " 0.3473314642906189,\n", " -0.48647403717041016,\n", " 0.4390827417373657,\n", " -0.29392558336257935,\n", " -0.6277146339416504,\n", " -1.0374670028686523,\n", " -0.6938822865486145,\n", " -0.5487024188041687,\n", " 1.4348708391189575,\n", " 0.5960367918014526,\n", " -1.3844335079193115,\n", " -0.4452952444553375,\n", " 0.8897888660430908,\n", " -1.1517540216445923,\n", " -0.04761616885662079,\n", " -0.3975827693939209,\n", " -0.18899205327033997,\n", " 0.05176090821623802,\n", " 0.5536333322525024,\n", " 0.060014013200998306,\n", " -0.2731142044067383,\n", " -0.6213244199752808,\n", " 0.0019015823490917683,\n", " -0.07379116117954254,\n", " 0.2752392292022705,\n", " 0.3508123755455017,\n", " -0.772514283657074,\n", " 
-0.27574068307876587,\n", " 0.5343725681304932,\n", " -0.35206612944602966,\n", " -0.20220153033733368,\n", " -1.9407036304473877,\n", " -0.02628588303923607,\n", " -0.28972572088241577,\n", " 0.28028538823127747,\n", " -0.5500861406326294,\n", " -0.17594845592975616,\n", " 0.4839322566986084,\n", " -1.369885802268982,\n", " -0.34468820691108704,\n", " -0.13607299327850342,\n", " -0.5910543203353882,\n", " -0.611413300037384,\n", " -0.27293211221694946,\n", " 0.2840813100337982,\n", " 0.17735804617404938,\n", " -1.1107754707336426,\n", " 0.21199874579906464,\n", " 0.9442760348320007,\n", " 1.1258407831192017,\n", " 0.7009125351905823,\n", " 0.10942572355270386,\n", " 0.7747233510017395,\n", " -0.08523924648761749,\n", " -0.48948904871940613,\n", " -0.0839976817369461,\n", " -0.09486719220876694,\n", " -0.10165372490882874,\n", " 0.7103846073150635,\n", " 1.9819456338882446,\n", " 0.29031047224998474,\n", " -0.5070116519927979,\n", " 0.21672241389751434,\n", " -0.48319151997566223,\n", " -0.48705288767814636,\n", " -0.05540275573730469,\n", " -0.29413890838623047,\n", " -1.5628045797348022,\n", " -0.9802005290985107,\n", " -0.10558260977268219,\n", " 0.7522594332695007,\n", " 1.0925337076187134,\n", " -2.6662333011627197,\n", " -0.3537704646587372,\n", " -0.29072850942611694,\n", " -0.38659220933914185,\n", " -0.07942889630794525,\n", " -0.5778862833976746,\n", " -0.04945552721619606,\n", " 0.5479450225830078,\n", " -0.08469293266534805,\n", " -0.13082793354988098,\n", " 0.33543476462364197,\n", " 0.4664786159992218,\n", " -0.058491021394729614,\n", " -0.48161348700523376,\n", " 0.2762199938297272,\n", " -1.1916759014129639,\n", " -0.3293285369873047,\n", " -0.4055858254432678,\n", " -0.1673295646905899,\n", " -0.5898811221122742,\n", " 0.20569205284118652,\n", " 1.062270164489746,\n", " 0.6869171261787415,\n", " -0.5571070313453674,\n", " -0.58970046043396,\n", " -0.360991895198822,\n", " 0.5100704431533813,\n", " -0.14549565315246582,\n", " 
-1.1706228256225586,\n", " -0.04149194061756134,\n", " 0.2038690447807312,\n", " -0.13545048236846924,\n", " 0.5271217823028564,\n", " -0.46835044026374817,\n", " -1.394780158996582,\n", " -1.281163215637207,\n", " -0.45064106583595276,\n", " -0.44844353199005127,\n", " 0.4811713397502899,\n", " 0.5549451112747192,\n", " 0.4399583637714386,\n", " 0.054237160831689835,\n", " -0.058604151010513306,\n", " 0.4888751208782196,\n", " -0.20230494439601898,\n", " 0.18096591532230377,\n", " -0.513340413570404,\n", " -0.1102394238114357,\n", " -0.3898923993110657,\n", " -0.6009892821311951,\n", " -0.725178599357605,\n", " -0.05625082179903984,\n", " -0.8542277216911316,\n", " -0.6884782314300537,\n", " -0.7019779682159424,\n", " -0.11025514453649521,\n", " 0.4197993576526642,\n", " 0.8522196412086487,\n", " 0.08304435014724731,\n", " 0.19630710780620575,\n", " 0.7074745297431946,\n", " -0.015370987355709076,\n", " 0.037080660462379456,\n", " -0.48519477248191833,\n", " 0.5797102451324463,\n", " -0.4608462452888489,\n", " -0.1773066222667694,\n", " 0.7433547973632812,\n", " -0.2954903841018677,\n", " 0.09348908066749573,\n", " 0.7225752472877502,\n", " -0.27521440386772156,\n", " 0.2525806128978729,\n", " -0.19455070793628693,\n", " -0.06710036844015121,\n", " -0.2522304356098175,\n", " -0.5218116044998169,\n", " 0.20594820380210876,\n", " -0.2851426899433136,\n", " -0.13955985009670258,\n", " -0.682921826839447,\n", " 0.48325657844543457,\n", " 0.2163250744342804,\n", " 0.11227460950613022,\n", " -0.6042129993438721,\n", " -1.2690485715866089,\n", " -0.029808329418301582,\n", " -0.13022933900356293,\n", " -0.06534045934677124,\n", " -0.7898696064949036,\n", " -1.067283272743225,\n", " -0.9876907467842102,\n", " 0.5449464321136475,\n", " 0.20234310626983643,\n", " -1.4014314413070679,\n", " -0.28375041484832764,\n", " 0.4547242224216461,\n", " 0.596598207950592,\n", " -0.1394488662481308,\n", " -0.2914172410964966,\n", " 0.6631377935409546,\n", " 0.44422823190689087,\n", " 
-0.09085384011268616,\n", " -0.43811509013175964,\n", " 0.3088453710079193,\n", " 0.060758139938116074,\n", " 0.4341190755367279,\n", " -0.8412942290306091,\n", " 0.0020509406458586454,\n", " -0.8093971610069275,\n", " -0.646979808807373,\n", " 1.1572284698486328,\n", " 0.4596129357814789,\n", " 0.21434593200683594,\n", " 0.381754606962204,\n", " 0.23834285140037537,\n", " -1.4315060377120972,\n", " -0.42201924324035645,\n", " 0.35704541206359863,\n", " 0.3821445107460022,\n", " -0.5810056924819946,\n", " 0.05293489620089531,\n", " -0.4321689307689667,\n", " 0.2871212959289551,\n", " -0.18105527758598328,\n", " -1.0111305713653564,\n", " 0.5809206366539001,\n", " 0.18537519872188568,\n", " -1.281412959098816,\n", " -0.5020036101341248,\n", " 0.6866092681884766,\n", " -1.7570589780807495,\n", " 0.605182945728302,\n", " 3.3644421100616455,\n", " 0.08233975619077682,\n", " -1.2022086381912231,\n", " -0.48621106147766113,\n", " 0.2822191119194031,\n", " -1.2259600162506104,\n", " -0.573685884475708,\n", " 0.5537064671516418,\n", " -0.7553575038909912,\n", " 0.056801341474056244,\n", " -0.5218964219093323,\n", " 0.724176824092865,\n", " 1.9177591800689697,\n", " -0.7821144461631775,\n", " 0.7417442798614502,\n", " -0.4669231176376343,\n", " -0.9275730848312378,\n", " -0.6659315824508667,\n", " -0.5725175738334656,\n", " -0.6109191179275513,\n", " 0.2824961245059967,\n", " -0.4488896429538727,\n", " 0.19218841195106506,\n", " 0.5667194128036499,\n", " -0.10712968558073044,\n", " 0.02623986080288887,\n", " 0.5902069807052612,\n", " -0.04075873643159866,\n", " -0.21326790750026703,\n", " 0.46892568469047546,\n", " -0.6187081933021545,\n", " -0.1876203715801239,\n", " -0.37642061710357666,\n", " 0.07667309045791626,\n", " -0.026581645011901855,\n", " 0.41338154673576355,\n", " 0.47654059529304504,\n", " 0.10144391655921936,\n", " 0.3112592399120331,\n", " -0.8311785459518433,\n", " -0.9617242813110352,\n", " -0.35983240604400635,\n", " -0.11548387259244919,\n", " 
-1.0810707807540894,\n", " -0.2783816456794739,\n", " 0.4145863354206085,\n", " -0.20730377733707428,\n", " -0.9726863503456116,\n", " 0.47991302609443665,\n", " 0.06358105689287186,\n", " -0.20456013083457947,\n", " 0.13724663853645325,\n", " 0.6748533844947815,\n", " 0.6572888493537903,\n", " -0.09440459311008453,\n", " -0.22326742112636566,\n", " -0.8103961944580078,\n", " -0.515312135219574,\n", " -0.2700837552547455,\n", " 0.8381712436676025,\n", " -0.2831617593765259,\n", " -0.40148037672042847,\n", " 0.45399364829063416,\n", " -0.14649711549282074,\n", " -0.24993129074573517,\n", " -1.934096336364746,\n", " 0.07009963691234589,\n", " -1.0383573770523071,\n", " 0.21950677037239075,\n", " 0.9118907451629639,\n", " -0.8030053973197937,\n", " -0.6415857076644897,\n", " -0.09130054712295532,\n", " 0.03330611065030098,\n", " 0.8294482231140137,\n", " ...]}]}" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "response_body" ] }, { "cell_type": "code", "execution_count": 37, "id": "fb6bfe55-e817-4c35-994a-07ad3c8afe8c", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'input_ids': [0, 31414, 6, 127, 2335, 16, 11962, 2], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1]}\n", "(1, 8)\n", "\n" ] } ], "source": [ "# # Local Invocation \n", "\n", "# import torch\n", "# from transformers import BartTokenizer, BartModel\n", "# import tritonclient.http as httpclient\n", "# from tritonclient.utils import *\n", "\n", "# tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')\n", "\n", "# client = httpclient.InferenceServerClient(url=\"localhost:8000\")\n", " \n", "# prompt = \"Hello, my dog is cute\"\n", "\n", "# inputs = tokenizer(prompt)\n", "\n", "# print(inputs)\n", "\n", "# text_obj = np.array(inputs[\"input_ids\"],dtype=np.int32).reshape(1,-1)\n", "# print(text_obj.shape)\n", "# input_text = httpclient.InferInput(\"input_ids\", text_obj.shape, 
np_to_triton_dtype(text_obj.dtype))\n", "# input_text.set_data_from_numpy(text_obj)\n", "\n", "# print(input_text)" ] }, { "cell_type": "code", "execution_count": 38, "id": "30bee0ad-2ff6-480b-856d-57f6f481175f", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(1, 8)\n" ] } ], "source": [ "# attention_mask_obj = np.array(inputs[\"attention_mask\"], dtype=np.int32).reshape(1,-1)\n", "# print(attention_mask_obj.shape)\n", "# attention_mask = httpclient.InferInput(\"attention_mask\", attention_mask_obj.shape, np_to_triton_dtype(attention_mask_obj.dtype))\n", "# attention_mask.set_data_from_numpy(attention_mask_obj)" ] }, { "cell_type": "code", "execution_count": 39, "id": "1e1bd629-0269-4716-9c73-1b330acc50e8", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# output_img = httpclient.InferRequestedOutput(\"output\")\n", "# output_img" ] }, { "cell_type": "code", "execution_count": 62, "id": "bcb96e4d-510f-4c65-bb9a-8332e089acb1", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[[ 0.55122316 0.83893114 -1.4706688 ... 1.3124448 -0.20466608\n", " 0.23921409]\n", " [ 0.55122286 0.83893126 -1.470669 ... 1.3124448 -0.20466569\n", " 0.23921481]\n", " [ 0.91427237 0.93994033 -1.2426258 ... 0.9183528 -0.18380232\n", " -0.99752015]\n", " ...\n", " [ 0.2560962 0.2253092 0.44698232 ... 0.3447002 0.00871746\n", " 1.5507985 ]\n", " [ 0.20772798 -1.3085785 -1.4295363 ... -0.29977536 0.18280452\n", " 0.46997055]\n", " [-0.48929775 2.5148034 -1.5512955 ... 
0.5782852 1.0960634\n", " 0.17355214]]]\n", "\n" ] } ], "source": [ "# result = client.infer(model_name=\"bart_pytorch\", inputs=[input_text, attention_mask], outputs=[output_img])\n", "# output = result.as_numpy(\"output\")\n", "# print(output)\n", "# print(type(output))" ] }, { "cell_type": "code", "execution_count": 19, "id": "87f0c871-cdb6-469d-beeb-d728068924f8", "metadata": { "tags": [] }, "outputs": [], "source": [ "# tokenizer.decode(output[0])" ] }, { "cell_type": "code", "execution_count": 20, "id": "d9b6f24e-b151-4cce-889c-96c9c74ca834", "metadata": { "tags": [] }, "outputs": [], "source": [ "# #Get the tensors back from query response. \n", "# # Read response body\n", "\n", "# header_length_prefix = \"application/vnd.sagemaker-triton.binary+json;json-header-size=\"\n", "\n", "# header_length_str = query_response[\"ContentType\"][len(header_length_prefix) :]\n" ] }, { "cell_type": "code", "execution_count": null, "id": "679119d5-bbbd-4ddb-83ab-7db608b28c3a", "metadata": {}, "outputs": [], "source": [ "# result = httpclient.InferenceServerClient.parse_response_body(\n", "# query_response[\"Body\"].read(), header_length=int(header_length_str)\n", "# )\n", "\n", "# print(result)" ] }, { "cell_type": "code", "execution_count": 7, "id": "3ac12ad1-f450-409f-bb5d-5cf4bf769543", "metadata": { "tags": [] }, "outputs": [], "source": [ "# start = time.time()\n", "# query_response = client.infer(model_name=\"bart_pytorch\", inputs=[input_text, attention_mask], outputs=[output_img])\n", "# print(f\"took {time.time()-start} seconds\")\n", "# print(query_response)" ] }, { "cell_type": "code", "execution_count": 24, "id": "ca3c7965", "metadata": { "tags": [] }, "outputs": [], "source": [ "# from transformers import BartTokenizer, BartModel\n", "\n", "\n", "# def get_tokenizer(model_name):\n", "# tokenizer = BartTokenizer.from_pretrained(model_name)\n", "# return tokenizer\n", "\n", "# def tokenize_text(model_name, text):\n", "# tokenizer = get_tokenizer(model_name)\n", 
"# tokenized_text = tokenizer(text, return_tensors=\"pt\")\n", "# return tokenized_text.input_ids\n", "\n", "\n", "# def get_text_payload(model_name, text):\n", "# input_ids = tokenize_text(model_name, text)\n", "# payload = {}\n", "# payload[\"inputs\"] = []\n", "# payload[\"inputs\"].append({\"name\": \"input_ids\", \"shape\": input_ids.shape, \"datatype\": \"INT32\", \"data\": input_ids.tolist()})\n", "# return payload\n" ] }, { "cell_type": "code", "execution_count": null, "id": "a91a735c", "metadata": {}, "outputs": [], "source": [ "# Clean up: delete the endpoint first, then its endpoint config, then the model.\n", "# sm_client.delete_endpoint(EndpointName=endpoint_name)\n", "# sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)\n", "# sm_client.delete_model(ModelName=sm_model_name)" ] } ], "metadata": { "kernelspec": { "display_name": "conda_amazonei_pytorch_latest_p37", "language": "python", "name": "conda_amazonei_pytorch_latest_p37" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.10" } }, "nbformat": 4, "nbformat_minor": 5 }