# Deploy Flan-T5 XXL on SageMaker

Now we will deploy the model, which was trained on SageMaker with DeepSpeed across multiple nodes, to a SageMaker real-time endpoint.

In [None]:
import boto3
import sagemaker

# Shared SageMaker session and execution role used by the cells below.
sess = sagemaker.Session()
role = sagemaker.get_execution_role()

print("sagemaker role arn: {}".format(role))
print("sagemaker bucket: {}".format(sess.default_bucket()))
print("sagemaker session region: {}".format(sess.boto_region_name))


We trained Flan-T5-XXL and saved the model in BF16 format. We will use Hugging Face Accelerate to speed up model inference.

In [None]:
from pathlib import Path

# Idempotent: unlike a bare `mkdir`, re-running the notebook does not fail
# when the directory already exists.
Path("deploy_code").mkdir(exist_ok=True)

In [None]:
%%writefile deploy_code/requirements.txt
# Python dependencies for model.py, packaged in the same tarball (presumably installed by the serving container at startup; confirm for your LMI version).
accelerate==0.16.0
transformers==4.26.0
# bitsandbytes is only used if 8-bit loading (load_in_8bit=True) is enabled in model.py.
bitsandbytes==0.37.0

We use Hugging Face Accelerate for inference, so set `engine` to `Python`. We deploy on ml.g5.48xlarge, which has 8 GPUs, so `option.tensor_parallel_degree` is set to 8. Finally, set `option.s3url` to the S3 path of your model artifacts; the trailing `/` is required (for example, `s3://your_bucket/flan-t5-xxl/model/`).

In [None]:
%%writefile deploy_code/serving.properties
# Serve deploy_code/model.py with the Python engine (inference uses Hugging Face Accelerate).
engine=Python
# ml.g5.48xlarge has 8 GPUs.
option.tensor_parallel_degree=8
# S3 prefix of the trained model artifacts: replace the placeholder bucket and keep the trailing '/'.
option.s3url=s3://your_bucket/flan-t5-xxl/model/

In [None]:
%%writefile deploy_code/model.py
from djl_python import Input, Output
import torch
import logging
import math
import os
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer


def load_model(properties):
    """Load the tokenizer and the Flan-T5 model served by ``handle``.

    Args:
        properties: serving properties supplied by DJL Serving. ``model_dir``
            is always read; if ``model_id`` is present it overrides
            ``model_dir`` as the location passed to ``from_pretrained``.

    Returns:
        ``(model, tokenizer)``: a frozen, eval-mode ``AutoModelForSeq2SeqLM``
        loaded in bfloat16 and dispatched over the GPUs via ``device_map``,
        plus the matching ``AutoTokenizer``.
    """
    # tensor_parallel_degree from serving.properties is not needed here:
    # placement is decided by the Accelerate device_map below.
    model_location = properties["model_dir"]
    if "model_id" in properties:
        model_location = properties["model_id"]
    logging.info(f"Loading model in {model_location}")

    tokenizer = AutoTokenizer.from_pretrained(model_location)

    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_location,
        # Spread layers across GPUs while keeping GPU 0 lightly loaded
        # (room for generate() buffers, per the Accelerate docs).
        device_map="balanced_low_0",
        # The checkpoint was saved in BF16; without an explicit dtype,
        # from_pretrained materializes the weights in float32 (2x memory).
        torch_dtype=torch.bfloat16,
        # Optionally add load_in_8bit=True (requires bitsandbytes) to reduce
        # memory further.
    )
    # Inference only: freeze parameters and disable dropout.
    model.requires_grad_(False)
    model.eval()

    return model, tokenizer


# model and tokenizer are filled in by handle() on the first request;
# generator is not used anywhere in this file.
model = tokenizer = generator = None


def handle(inputs: Input):
    """DJL Serving entry point: lazily load the model, then run generation.

    Expects a JSON body ``{"inputs": <str or list of str>, "parameters": {...}}``.
    ``"parameters"`` is optional; when given, its entries are forwarded as
    keyword arguments to ``model.generate``.

    Returns:
        ``None`` for an empty request; otherwise an ``Output`` whose JSON is
        ``{"outputs": <str>}`` for a single string prompt, or
        ``{"outputs": [<str>, ...]}`` (one per generated sequence) for a list.
    """
    global model, tokenizer
    # Load once per worker; later requests reuse the module-level globals.
    if model is None:
        model, tokenizer = load_model(inputs.get_properties())

    if inputs.is_empty():
        return None
    data = inputs.get_as_json()

    input_sentences = data["inputs"]
    # "parameters" is optional; indexing it directly raised KeyError when absent.
    params = data.get("parameters") or {}

    # padding=True allows batching prompts of different lengths; the attention
    # mask keeps pad tokens from influencing generation.
    encoded = tokenizer(input_sentences, return_tensors="pt", padding=True)
    outputs = model.generate(
        encoded.input_ids,
        attention_mask=encoded.attention_mask,
        **params,
    )

    # postprocess: decode every sequence so batched requests lose nothing.
    predictions = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    if isinstance(input_sentences, str):
        # Single prompt: keep the original plain-string response shape.
        result = {"outputs": predictions[0]}
    else:
        result = {"outputs": predictions}
    return Output().add_as_json(result)

We will use the SageMaker LMI (Large Model Inference) container to serve the LLM.

In [None]:
# Reuse the session created in the first cell instead of building a second
# one; boto_region_name is the public accessor (not the private _region_name).
region = sess.boto_region_name

# LMI (DJL Serving + DeepSpeed) container image.
# NOTE(review): 763104351884 is the AWS Deep Learning Containers account for
# most commercial regions; some opt-in regions use a different account, in
# which case resolve the URI with sagemaker.image_uris.retrieve(...).
inference_image_uri = (
    f"763104351884.dkr.ecr.{region}.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.0-cu117"
)

print(f"Image going to be used is ---- > {inference_image_uri}")

In [None]:
!rm model-liang.tar.gz
!tar czvf model-liang.tar.gz -C deploy_code .

In [None]:
# Upload the packaged code/config tarball to the session's default bucket.
bucket = sess.default_bucket()
s3_code_prefix = "code_flan_t5_LMI_liang"
s3_code_artifact = sess.upload_data(
    path="model-liang.tar.gz",
    bucket=bucket,
    key_prefix=s3_code_prefix,
)
print("S3 Code or Model tar ball uploaded to --- > {}".format(s3_code_artifact))

In [None]:
import boto3
from sagemaker.utils import name_from_base

# Low-level clients: control plane (create/describe) and runtime (invoke).
sm_client = boto3.client("sagemaker")
smr_client = boto3.client("sagemaker-runtime")

# name_from_base adds a timestamp suffix so repeated runs get unique names.
model_name = name_from_base("flan-t5-xxl-accelerate-LMI")
print(model_name)
print("Image going to be used is ---- > {}".format(inference_image_uri))

# The LMI image serves the code tarball uploaded in the previous cell.
primary_container = {
    "Image": inference_image_uri,
    "ModelDataUrl": s3_code_artifact,
}
create_model_response = sm_client.create_model(
    ModelName=model_name,
    ExecutionRoleArn=role,
    PrimaryContainer=primary_container,
)
model_arn = create_model_response["ModelArn"]

print("Created Model: {}".format(model_arn))

In [None]:
endpoint_config_name = f"{model_name}-config-88"
endpoint_name = f"{model_name}-endpoint"

# One ml.g5.48xlarge (8 GPUs, matching option.tensor_parallel_degree=8).
# Loading an XXL model takes a while, so the container startup health check
# is given 2400 s. ModelDataDownloadTimeoutInSeconds can also be set if
# downloading ModelDataUrl is slow.
production_variant = {
    "VariantName": "variant1",
    "ModelName": model_name,
    "InstanceType": "ml.g5.48xlarge",
    "InitialInstanceCount": 1,
    "ContainerStartupHealthCheckTimeoutInSeconds": 2400,
}

endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[production_variant],
)
endpoint_config_response

In [None]:
# Start endpoint creation; the next cell polls until provisioning finishes.
create_endpoint_response = sm_client.create_endpoint(
    EndpointName=endpoint_name,
    EndpointConfigName=endpoint_config_name,
)
print("Created Endpoint: {}".format(create_endpoint_response["EndpointArn"]))

In [None]:
import time

POLL_SECONDS = 60  # interval between describe_endpoint calls

# Poll until the endpoint leaves the "Creating" state.
while True:
    resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
    status = resp["EndpointStatus"]
    print("Status: " + status)
    if status != "Creating":
        break
    time.sleep(POLL_SECONDS)

print("Arn: " + resp["EndpointArn"])
print("Status: " + status)

# Stop here with the reason instead of letting the invoke cell fail obscurely.
if status != "InService":
    raise RuntimeError(
        f"Endpoint {endpoint_name} ended in status {status}: "
        f"{resp.get('FailureReason', 'no FailureReason reported')}"
    )

Use the low-level boto3 API to invoke the endpoint and generate text.

In [None]:
%%time
import json
import boto3

smr_client = boto3.client("sagemaker-runtime")

prompts = """Summarize the following news article:
Peter and Elizabeth took a taxi to attend the night party in the city. While in the party, Elizabeth collapsed and was rushed to the hospital.
Since she was diagnosed with a brain injury, the doctor told Peter to stay besides her until she gets well. Therefore, Peter stayed with her at the hospital for 3 days without leaving.
Summary:
"""

parameters = {
  #"early_stopping": True,
  #"length_penalty": 2.0,
  "max_new_tokens": 50,
  "temperature": 0,
  "min_length": 10,
  "no_repeat_ngram_size": 2,
}


response_model = smr_client.invoke_endpoint(
            EndpointName=endpoint_name,
            Body=json.dumps(
            {
                "inputs": prompts,
                "parameters": parameters
            }
            ),
            ContentType="application/json",
        )

response_model['Body'].read().decode('utf8')