# BART Large model deployment on Amazon SageMaker Multi-model endpoints (MME) with GPU 



Amazon SageMaker multi-model endpoints(MME) provide a scalable and cost-effective way to deploy large number of deep learning models. Previously, customers had limited options to deploy 100s of deep learning models that need accelerated compute with GPUs. Now customers can deploy 1000s of deep learning models behind one SageMaker endpoint. Now, MME will run multiple models on a GPU, share GPU instances behind an endpoint across multiple models and dynamically load/unload models based on the incoming traffic. With this, customers can significantly save cost and achieve best price performance.



<div class="alert alert-info"> üí° <strong> Note </strong>
This notebook was tested with the `conda_python3` kernel on an Amazon SageMaker notebook instance of type `g5.xlarge`.
</div>

In this notebook, we will walk you through how to use NVIDIA Triton Inference Server on Amazon SageMaker MME with GPU feature to deploy a **BART** NLP model for **Translation**. 

## Installs

Installs the dependencies required to package the model and run inferences using Triton server. Update SageMaker, boto3, awscli etc

In [1]:
!pip install -qU pip awscli boto3 sagemaker
!pip install nvidia-pyindex --quiet
!pip install tritonclient[http] --quiet
!pip install transformers[sentencepiece] --quiet
!pip install torch --extra-index-url https://download.pytorch.org/whl/cu116
!pip install transformers --quiet

Looking in indexes: https://pypi.org/simple, https://pip.repos.neuron.amazonaws.com, https://pypi.ngc.nvidia.com, https://download.pytorch.org/whl/cu116


## Imports and variables

In [2]:
import boto3, json, sagemaker, time
from sagemaker import get_execution_role
import numpy as np
import os
import json
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# sagemaker variables
role = get_execution_role()
sm_client = boto3.client(service_name="sagemaker")
runtime_sm_client = boto3.client("sagemaker-runtime")
sagemaker_session = sagemaker.Session(boto_session=boto3.Session())
s3_client = boto3.client('s3')
bucket = sagemaker.Session().default_bucket()
prefix = "bart"

# account mapping for SageMaker MME Triton Image
account_id_map = {
    "us-east-1": "785573368785",
    "us-east-2": "007439368137",
    "us-west-1": "710691900526",
    "us-west-2": "301217895009",
    "eu-west-1": "802834080501",
    "eu-west-2": "205493899709",
    "eu-west-3": "254080097072",
    "eu-north-1": "601324751636",
    "eu-south-1": "966458181534",
    "eu-central-1": "746233611703",
    "ap-east-1": "110948597952",
    "ap-south-1": "763008648453",
    "ap-northeast-1": "941853720454",
    "ap-northeast-2": "151534178276",
    "ap-southeast-1": "324986816169",
    "ap-southeast-2": "355873309152",
    "cn-northwest-1": "474822919863",
    "cn-north-1": "472730292857",
    "sa-east-1": "756306329178",
    "ca-central-1": "464438896020",
    "me-south-1": "836785723513",
    "af-south-1": "774647643957",
}

region = boto3.Session().region_name
if region not in account_id_map.keys():
    raise ("UNSUPPORTED REGION")

base = "amazonaws.com.cn" if region.startswith("cn-") else "amazonaws.com"
triton_image_uri = (
    "{account_id}.dkr.ecr.{region}.{base}/sagemaker-tritonserver:22.12-py3".format(
        account_id=account_id_map[region], region=region, base=base
    )
)
print(triton_image_uri)

301217895009.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tritonserver:22.12-py3


## Workflow Overview

This section presents overview of main steps for preparing a BART Pytorch model (served using Python backend) using Triton Inference Server.
### 1. Generate Model Artifacts



#### BART PyTorch Model

In case of BART HuggingFace PyTorch Model, since we are serving it using Triton's [python backend](https://github.com/triton-inference-server/python_backend#usage) we have python script [model.py](./workspace/model.py) which implements all the logic to initialize the BART model and execute inference for the translation task.

### 2. Build Model Respository

Using Triton on SageMaker requires us to first set up a [model repository](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_repository.md) folder containing the models we want to serve. For each model we need to create a model directory consisting of the model artifact and define config.pbtxt file to specify [model configuration](https://github.com/triton-inference-server/server/blob/main/docs/model_configuration.md) which Triton uses to load and serve the model. 



#### BART Python Backend Model

Model repository structure for BART Model.

```
bart_pytorch
‚îú‚îÄ‚îÄ 1
‚îÇ   ‚îî‚îÄ‚îÄ model.py
‚îî‚îÄ‚îÄ config.pbtxt
```


Next we set up the BART PyTorch Python Backend Model in the model repository:

In [13]:
!pwd

/home/ec2-user/SageMaker


In [14]:
!mkdir -p model_repository/bart_pytorch/1
!cp workspace/model.py model_repository/bart_pytorch/1/

mkdir: cannot create directory ‚Äòmodel_repository/bart_pytorch‚Äô: Permission denied
cp: cannot stat ‚Äòworkspace/model.py‚Äô: No such file or directory


##### Create Conda Environment for Dependencies

For serving the HuggingFace BART PyTorch Model using Triton's Python backend we have PyTorch and HuggingFace transformers as dependencies.

We follow the instructions from the [Triton documentation for packaging dependencies](https://github.com/triton-inference-server/python_backend#2-packaging-the-conda-environment) to be used in the python backend as conda env tar file. Running the bash script [create_hf_env.sh]('./workspace/create_hf_env.sh') creates the conda environment containing PyTorch and HuggingFace transformers, packages it as tar file and then we move it into the bart-pytorch model directory. This can take a few minutes.

In [13]:
!bash workspace/create_hf_env.sh
!mv hf_env.tar.gz model_repository/bart_pytorch/

Collecting package metadata (current_repodata.json): done
Solving environment: done


  current version: 22.9.0
  latest version: 23.1.0

Please update conda by running

    $ conda update -n base -c conda-forge conda



## Package Plan ##

  environment location: /home/ec2-user/anaconda3/envs/hf_env

  added / updated specs:
    - python=3.8


The following packages will be downloaded:

    package                    |            build
    ---------------------------|-----------------
    openssl-3.1.0              |       h0b41bf4_0         2.5 MB  conda-forge
    wheel-0.40.0               |     pyhd8ed1ab_0          54 KB  conda-forge
    ------------------------------------------------------------
                                           Total:         2.6 MB

The following NEW packages will be INSTALLED:

  _libgcc_mutex      conda-forge/linux-64::_libgcc_mutex-0.1-conda_forge None
  _openmp_mutex      conda-forge/linux-64::_openmp_mutex-4.5-2_gnu None
  bzip2              cond

After creating the tar file from the conda environment and placing it in model folder, you need to tell Python backend to use that environment for your model. We do this by including the lines below in the model `config.pbtxt` file:

```
parameters: {
  key: "EXECUTION_ENV_PATH",
  value: {string_value: "$$TRITON_MODEL_DIRECTORY/hf_env.tar.gz"}
}
```
Here, `$$TRITON_MODEL_DIRECTORY` helps provide environment path relative to the model folder in model repository and is resolved to `$pwd/model_repository/bart_pytorch`. Finally `hf_env.tar.gz` is the name we gave to our conda env file.

Now we are ready to define the config file for bart pytorch model being served through Triton's Python Backend:

In [14]:
%%writefile model_repository/bart_pytorch/config.pbtxt
name: "bart_pytorch"
backend: "python"
max_batch_size: 8
input: [
    {
        name: "input_ids"
        data_type: TYPE_INT32
        dims: [ -1 ]
    },
    {
        name: "attention_mask"
        data_type: TYPE_INT32
        dims: [ -1 ]
    }
]
output [
  {
    name: "output"
    data_type: TYPE_FP32
    dims: [ -1, -1 ]
  }
]
instance_group {
  count: 1
  kind: KIND_GPU
}
dynamic_batching {
}
parameters: {
  key: "EXECUTION_ENV_PATH",
  value: {string_value: "$$TRITON_MODEL_DIRECTORY/hf_env.tar.gz"}
}

Overwriting model_repository/bart_pytorch/config.pbtxt


### 3. Package models and upload to S3

Next, we will package our model as `*.tar.gz` files for uploading to S3. 

In [86]:
! pwd

/home/ec2-user/SageMaker


In [100]:
!tar -C BART-Triton-PyTorch/model_repository/ -czf BART-Triton-PyTorch/bart_pytorch_7.tar.gz bart_pytorch
model_uri_bart_pytorch = sagemaker_session.upload_data(path="BART-Triton-PyTorch/bart_pytorch_7.tar.gz", key_prefix=prefix)

### 4. Create SageMaker Endpoint

Now that we have uploaded the model artifacts to S3, we can create a SageMaker endpoint.

#### Define the serving container
In the container definition, define the `ModelDataUrl` to specify the S3 directory that contains all the models that SageMaker multi-model endpoint will use to load and serve predictions. Set `Mode` to `MultiModel` to indicate SageMaker would create the endpoint with MME container specifications. We set the container with an image that supports deploying multi-model endpoints with GPU

In [71]:
print(model_uri_bart_pytorch)
model_data_url = f"s3://{bucket}/{prefix}/"
print(model_data_url)

container = {
    "Image": triton_image_uri,
    "ModelDataUrl": model_data_url,
    "Mode": "MultiModel", 
}
print(container)

s3://sagemaker-us-west-2-757967535041/bart/bart_pytorch.tar.gz
s3://sagemaker-us-west-2-757967535041/bart/
{'Image': '301217895009.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tritonserver:22.12-py3', 'ModelDataUrl': 's3://sagemaker-us-west-2-757967535041/bart/', 'Mode': 'MultiModel'}


#### Create a multi-model object

Once the image, data location are set we create the model using `create_model` by specifying the `ModelName` and the Container definition

In [59]:
ts = time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())
sm_model_name = f"{prefix}-{ts}"
print(sm_model_name)

bart-2023-03-17-21-19-34


In [60]:
create_model_response = sm_client.create_model(
    ModelName=sm_model_name, ExecutionRoleArn=role, PrimaryContainer=container
)

print("Model Arn: " + create_model_response["ModelArn"])

Model Arn: arn:aws:sagemaker:us-west-2:757967535041:model/bart-2023-03-17-21-19-34


#### Define configuration for the multi-model endpoint

Using the model above, we create an [endpoint configuration](https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateEndpointConfig.html) where we can specify the type and number of instances we want in the endpoint. Here we are deploying to `g5.2xlarge` NVIDIA GPU instance.

In [61]:
endpoint_config_name = f"{prefix}-epc-{ts}"

create_endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        {
            "InstanceType": "ml.g5.2xlarge",
            "InitialVariantWeight": 1,
            "InitialInstanceCount": 1,
            "ModelName": sm_model_name,
            "VariantName": "AllTraffic",
        }
    ],
)

print("Endpoint Config Arn: " + create_endpoint_config_response["EndpointConfigArn"])


Endpoint Config Arn: arn:aws:sagemaker:us-west-2:757967535041:endpoint-config/bart-epc-2023-03-17-21-19-34


#### Create SageMaker Endpoint

Using the above endpoint configuration we create a new sagemaker endpoint and wait for the deployment to finish. The status will change to **InService** once the deployment is successful.

In [62]:
endpoint_name = f"{prefix}-ep-{ts}"

create_endpoint_response = sm_client.create_endpoint(
    EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name
)

print("Endpoint Arn: " + create_endpoint_response["EndpointArn"])
print("endpointname: " + endpoint_name)

Endpoint Arn: arn:aws:sagemaker:us-west-2:757967535041:endpoint/bart-ep-2023-03-17-21-19-34
endpointname: bart-ep-2023-03-17-21-19-34


In [63]:
resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
status = resp["EndpointStatus"]
print("Status: " + status)

while status == "Creating":
    time.sleep(60)
    resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
    status = resp["EndpointStatus"]
    print("Status: " + status)

print("Arn: " + resp["EndpointArn"])
print("Status: " + status)

Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: InService
Arn: arn:aws:sagemaker:us-west-2:757967535041:endpoint/bart-ep-2023-03-17-21-19-34
Status: InService


### 5. Run Inference

Once we have the endpoint running we can use some sample raw data to do an inference using JSON as the payload format. For the inference request format, Triton uses the KFServing community standard [inference protocols](https://github.com/triton-inference-server/server/blob/main/docs/protocol/README.md).

#### Add utility methods for preparing JSON request payload



We'll use the following utility methods to convert our inference request for BART models into a json payload.

In [6]:
#helper functions
import tritonclient.http as httpclient
from transformers import BartTokenizer, BartModel
from tritonclient.utils import *

def get_tokenizer(model_name):
    tokenizer = BartTokenizer.from_pretrained(model_name)
    return tokenizer

# def tokenize_text(model_name, text):
#     tokenizer = get_tokenizer(model_name)
#     tokenized_text = tokenizer(text, padding=True, return_tensors="pt")
#     #tokenized_text = tokenizer(text)
#     return tokenized_text


def tokenize_text(model_name, text):
    tokenizer = get_tokenizer(model_name)
    tokenized_text = tokenizer(text, padding=True, return_tensors="np")
    return tokenized_text.input_ids, tokenized_text.attention_mask

#V1
# def get_text_payload(model_name, text):
    
#     inputs = []
#     outputs = []
    
#     inputs = tokenize_text(model_name, text)
    
#     text_obj = np.array(inputs["input_ids"],dtype=np.int32).reshape(1,-1)
#     print(text_obj.shape)
#     input_text = httpclient.InferInput("input_ids", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
#     input_text.set_data_from_numpy(text_obj)
#     print(input_text)

#     attention_mask_obj = np.array(inputs["attention_mask"], dtype=np.int32).reshape(1,-1)
#     print(attention_mask_obj.shape)
#     attention_mask = httpclient.InferInput("attention_mask", attention_mask_obj.shape, np_to_triton_dtype(attention_mask_obj.dtype))
#     attention_mask.set_data_from_numpy(attention_mask_obj)
#     print(attention_mask)
    
#     inputs=[input_text, attention_mask]
#     return inputs

#v2 

def get_text_payload(model_name, text):
    input_ids, attention_mask = tokenize_text(model_name, text)
    payload = {}
    payload["inputs"] = []
    payload["inputs"].append({"name": "input_ids", "shape": input_ids.shape, "datatype": "INT32", "data": input_ids.tolist()})
    payload["inputs"].append({"name": "attention_mask", "shape": attention_mask.shape, "datatype": "INT32", "data": attention_mask.tolist()})
    
    return payload

# text_input = "Hello, my dog is cute"
# bart_payload = get_text_payload('facebook/bart-large', text_input)

# print("bart_payload is", bart_payload)
# print(" payload type is", type(bart_payload))
    

In [4]:
endpoint_name = "bart-ep-2023-03-17-21-19-34"
print(endpoint_name)

sm_client.describe_endpoint(EndpointName=endpoint_name)

texts = ["Hello, my dog is cute"]
batch_size = len(texts)
print(batch_size)

bart-ep-2023-03-17-21-19-34
1


In [7]:
bart_payload = get_text_payload('facebook/bart-large', texts)
print(bart_payload)

{'inputs': [{'name': 'input_ids', 'shape': (1, 8), 'datatype': 'INT32', 'data': [[0, 31414, 6, 127, 2335, 16, 11962, 2]]}, {'name': 'attention_mask', 'shape': (1, 8), 'datatype': 'INT32', 'data': [[1, 1, 1, 1, 1, 1, 1, 1]]}]}


In [14]:
%time

response = runtime_sm_client.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType="application/octet-stream",
    Body=json.dumps(bart_payload),
    TargetModel="bart_pytorch_7.tar.gz",
)
print(response)

CPU times: user 1 ¬µs, sys: 0 ns, total: 1 ¬µs
Wall time: 3.81 ¬µs
{'ResponseMetadata': {'RequestId': '690c8bd4-5ccf-461e-a85b-b051de3845ea', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amzn-requestid': '690c8bd4-5ccf-461e-a85b-b051de3845ea', 'x-amzn-invoked-production-variant': 'AllTraffic', 'date': 'Fri, 17 Mar 2023 23:00:06 GMT', 'content-type': 'application/json', 'content-length': '162666'}, 'RetryAttempts': 0}, 'ContentType': 'application/json', 'InvokedProductionVariant': 'AllTraffic', 'Body': <botocore.response.StreamingBody object at 0x7f04cc7a03d0>}


In [9]:
response_body = json.loads(response["Body"].read().decode("utf8"))
print(type(response_body))

<class 'dict'>


In [10]:
response_body

{'model_name': 'ccdfac5647468461e480c1c55bba575f',
 'model_version': '1',
 'outputs': [{'name': 'output',
   'datatype': 'FP32',
   'shape': [1, 8, 1024],
   'data': [0.5512231588363647,
    0.838931143283844,
    -1.4706687927246094,
    0.36186015605926514,
    -0.16138087213039398,
    -0.7445544600486755,
    -0.5286368131637573,
    -0.9802274107933044,
    -1.5751498937606812,
    -0.1964043527841568,
    -0.6759628057479858,
    -1.1872944831848145,
    -0.9507772922515869,
    -0.4348897933959961,
    0.2833757698535919,
    0.3985210955142975,
    0.5549306869506836,
    0.05500718951225281,
    0.1372915804386139,
    0.20988517999649048,
    4.532204627990723,
    -0.1279691457748413,
    1.0849722623825073,
    -0.8920677304267883,
    -0.7142594456672668,
    0.24808241426944733,
    -1.0593713521957397,
    -0.30108046531677246,
    0.920939564704895,
    -0.11048457026481628,
    0.3940015137195587,
    -0.6587729454040527,
    0.9290228486061096,
    -0.0422490872442722

In [37]:
# # Local Invocation 

# import torch
# from transformers import BartTokenizer, BartModel
# import tritonclient.http as httpclient
# from tritonclient.utils import *

# tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')

# client = httpclient.InferenceServerClient(url="localhost:8000")
 
# prompt = "Hello, my dog is cute"

# inputs = tokenizer(prompt)

# print(inputs)

# text_obj = np.array(inputs["input_ids"],dtype=np.int32).reshape(1,-1)
# print(text_obj.shape)
# input_text = httpclient.InferInput("input_ids", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
# input_text.set_data_from_numpy(text_obj)

# print(input_text)

{'input_ids': [0, 31414, 6, 127, 2335, 16, 11962, 2], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1]}
(1, 8)
<tritonclient.http.InferInput object at 0x7efd64473a30>


In [38]:
# attention_mask_obj = np.array(inputs["attention_mask"], dtype=np.int32).reshape(1,-1)
# print(attention_mask_obj.shape)
# attention_mask = httpclient.InferInput("attention_mask", attention_mask_obj.shape, np_to_triton_dtype(attention_mask_obj.dtype))
# attention_mask.set_data_from_numpy(attention_mask_obj)

(1, 8)


In [39]:
# output_img = httpclient.InferRequestedOutput("output")
# output_img

<tritonclient.http.InferRequestedOutput at 0x7efd6448f2b0>

In [62]:
# result = client.infer(model_name="bart_pytorch", inputs=[input_text, attention_mask], outputs=[output_img])
# output = result.as_numpy("output")
# print(output)
# print(type(output))

[[[ 0.55122316  0.83893114 -1.4706688  ...  1.3124448  -0.20466608
    0.23921409]
  [ 0.55122286  0.83893126 -1.470669   ...  1.3124448  -0.20466569
    0.23921481]
  [ 0.91427237  0.93994033 -1.2426258  ...  0.9183528  -0.18380232
   -0.99752015]
  ...
  [ 0.2560962   0.2253092   0.44698232 ...  0.3447002   0.00871746
    1.5507985 ]
  [ 0.20772798 -1.3085785  -1.4295363  ... -0.29977536  0.18280452
    0.46997055]
  [-0.48929775  2.5148034  -1.5512955  ...  0.5782852   1.0960634
    0.17355214]]]
<class 'numpy.ndarray'>


In [19]:
# tokenizer.decode(output[0])

In [20]:
# #Get the tensors back from query response. 
# # Read response body

# header_length_prefix = "application/vnd.sagemaker-triton.binary+json;json-header-size="

# header_length_str = query_response["ContentType"][len(header_length_prefix) :]


In [None]:
# result = httpclient.InferenceServerClient.parse_response_body(
#     query_response["Body"].read(), header_length=int(header_length_str)
# )

# print(result)

In [7]:
# start = time.time()
# query_response = client.infer(model_name="bart_pytorch", inputs=[input_text, attention_mask], outputs=[output_img])
# print(f"took {time.time()-start} seconds")
# print(query_response)

In [24]:
# from transformers import BartTokenizer, BartModel


# def get_tokenizer(model_name):
#     tokenizer = BartTokenizer.from_pretrained(model_name)
#     return tokenizer

# def tokenize_text(model_name, text):
#     tokenizer = get_tokenizer(model_name)
#     tokenized_text = tokenizer(text, return_tensors="pt")
#     return tokenized_text.input_ids


# def get_text_payload(model_name, text):
#     input_ids = tokenize_text(model_name, text)
#     payload = {}
#     payload["inputs"] = []
#     payload["inputs"].append({"name": "input_ids", "shape": input_ids.shape, "datatype": "INT32", "data": input_ids.tolist()})
#     return payload


In [None]:
# sm_client.delete_model(ModelName=sm_model_name)
# sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
# sm_client.delete_endpoint(EndpointName=endpoint_name)