## Democratize Document Summarization in Media and Entertainment with Hugging Face Transformers on Amazon SageMaker
### Contents
1. [Introduction](#Introduction)
2. [Model Setup](#Model-Setup)
3. [Model Deployment](#Model-Deployment)
4. [Test Model Endpoint](#Test-the-endpoint)
5. [Clean Up](#Clean-Up)

## Introduction
In this notebook we'll walk through an example of deploying a document summarization model from the [Hugging Face Hub](https://huggingface.co/models). The model takes a long text document as input and outputs a concise summary. We will deploy it on an [asynchronous endpoint](https://docs.aws.amazon.com/sagemaker/latest/dg/async-inference.html) because most summarization use cases do not require low latency, which makes an asynchronous endpoint a good fit. This notebook was tested on SageMaker Studio running the Python Data Science kernel on an ml.t3.medium instance.

To begin, let's install the required dependencies:

In [None]:
%pip install -Uqq sagemaker==2.59.7
%pip install -Uqq boto3

## Model Setup

In this section we will configure the model package that we will deploy. The package includes a pretrained `distilbart-cnn-12-6` model and the source code used to serve it. Hugging Face provides a convenient Python function for pulling a model from its Hub. We could invoke this function directly from our inference script, but this is not good practice: it introduces a potential point of failure if the endpoint cannot reach the Hugging Face Hub during a deployment or autoscaling operation. We will therefore use git-lfs to download the pretrained model from the Hub and host it in our own S3 bucket.

In [None]:
# install git LFS and clone the distilbart model from HF Hub
! curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh |  bash
! apt-get install -y git-lfs
! git clone https://huggingface.co/sshleifer/distilbart-cnn-12-6

In [None]:
# create the initial model artifact tar file
! cd distilbart-cnn-12-6 && tar --exclude=".*" -cvf  model.tar * && mv model.tar ../model.tar
! rm -r distilbart-cnn-12-6/

In [None]:
import sagemaker
from sagemaker import image_uris
import boto3
import os
import time
import json
from pathlib import Path

In [None]:
role = sagemaker.get_execution_role()      # execution role for the endpoint
sess = sagemaker.session.Session()         # sagemaker session for interacting with different AWS APIs
bucket = sess.default_bucket()             # bucket to house artifacts
key_prefix = "hugginface_summarization"    # folder within bucket where all artifacts will go

region = sess.boto_region_name             # public accessor for the session's AWS region

s3_client = boto3.client("s3")
sm_client = boto3.client("sagemaker")
smr_client = boto3.client("sagemaker-runtime")

In [None]:
# create the directory that will hold the inference script (no-op if it already exists)
os.makedirs("code", exist_ok=True)

Below is our inference code. The `model_fn` function is responsible for loading the model, while `transform_fn` contains the actual inference logic. For simplicity, we'll use the Hugging Face `pipeline` abstraction to combine the tokenizer with the model.

In [None]:
%%writefile code/sum_entrypoint.py
from transformers import BartTokenizer, BartForConditionalGeneration
from transformers import pipeline
import json

def model_fn(model_dir):
    
    tokenizer = BartTokenizer.from_pretrained(model_dir)
    model = BartForConditionalGeneration.from_pretrained(model_dir)
    nlp=pipeline("summarization", model=model, tokenizer=tokenizer)
    
    return nlp


def transform_fn(nlp, request_body, input_content_type, output_content_type="text/csv"):
    
    if input_content_type == "text/csv":
        result = nlp(request_body, truncation=True)[0]
    
    else:
        raise Exception("content type not supported")
    
    return json.dumps(result)

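Before packaging, the handler contract can be exercised locally without downloading the model. The sketch below is a hypothetical smoke test, not part of the deployed script: `fake_pipeline` is an invented stand-in for the object returned by `model_fn`, and it only mimics the list-of-dicts shape that a Hugging Face summarization pipeline returns. The handler body is repeated from the script above so the cell runs on its own.

```python
import json


def fake_pipeline(text, truncation=True):
    """Hypothetical stand-in for the summarization pipeline returned by model_fn."""
    # A summarization pipeline returns a list with one dict per input document.
    return [{"summary_text": text[:20]}]


def transform_fn(nlp, request_body, input_content_type, output_content_type="text/csv"):
    # Same logic as code/sum_entrypoint.py, repeated so this cell is self-contained.
    if input_content_type == "text/csv":
        result = nlp(request_body, truncation=True)[0]
    else:
        raise Exception("content type not supported")
    return json.dumps(result)


# Supported content type: the handler returns a JSON string.
payload = transform_fn(fake_pipeline, "A long article about streaming media.", "text/csv")
print(payload)

# Unsupported content type: the handler raises.
try:
    transform_fn(fake_pipeline, "{}", "application/json")
    rejected = False
except Exception as err:
    rejected = True
    print(f"rejected: {err}")
```

A check like this only validates the request and response plumbing; summary quality still has to be evaluated against the real model.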

Now we add the inference code to our tar model package and compress it with gzip to produce the final `model.tar.gz` artifact.

In [None]:
# add the inference code to the model artifact, then gzip the archive.
# gzip replaces model.tar with model.tar.gz, so no separate delete is needed.
# The compression step could take 5 to 10 minutes as this is a large model
! tar rvf  model.tar  code/sum_entrypoint.py
! gzip  model.tar

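Since a misplaced file in the archive typically only surfaces when the endpoint starts, it can be worth checking the archive layout before uploading. The helper below is a minimal sketch: `missing_members` is a hypothetical name, and the file names in the demo (`config.json`, `pytorch_model.bin`) are assumptions about what the cloned model repository contains. The demo runs on a throwaway archive so it does not depend on the real artifact.

```python
import io
import tarfile
import tempfile
from pathlib import Path


def missing_members(archive_path, required):
    """Return the required member names that are absent from a tar archive."""
    # Mode "r:*" lets tarfile detect gzip compression automatically.
    with tarfile.open(archive_path, "r:*") as archive:
        names = set(archive.getnames())
    return [name for name in required if name not in names]


# Demo on a throwaway archive; for the real artifact pass "model.tar.gz".
with tempfile.TemporaryDirectory() as tmp:
    demo_path = Path(tmp) / "demo.tar.gz"
    with tarfile.open(demo_path, "w:gz") as archive:
        for name in ("config.json", "code/sum_entrypoint.py"):
            data = b"placeholder"
            info = tarfile.TarInfo(name=name)
            info.size = len(data)
            archive.addfile(info, io.BytesIO(data))
    demo_missing = missing_members(
        demo_path, ["config.json", "code/sum_entrypoint.py", "pytorch_model.bin"]
    )

print(demo_missing)
```

An empty list for the real `model.tar.gz` would indicate that the weights sit at the archive root and the entry point sits under `code/`, which is the layout this notebook builds.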
The final step before we're ready to deploy is to upload the model artifact to S3 and fetch the ECR URI of the Hugging Face inference container.

In [None]:
s3_model_data = sess.upload_data("model.tar.gz", bucket, key_prefix)
inference_image_uri = image_uris.retrieve(
            "huggingface",
            region,
            version="4.6.1",
            py_version="py36",
            instance_type="ml.m5.xlarge",
            image_scope="inference",
            base_framework_version="pytorch1.8.1"
        )

## Model Deployment
The model is now ready to be deployed. The first step is to create a SageMaker Model resource.

In [None]:
model_name = "document-summarization"
create_model_response = sm_client.create_model(
    ModelName=model_name,
    ExecutionRoleArn=role,
    PrimaryContainer={
        "Image": inference_image_uri,
        "ModelDataUrl": s3_model_data,
        "Environment": {
            "SAGEMAKER_PROGRAM": "sum_entrypoint.py",
            "SAGEMAKER_CONTAINER_LOG_LEVEL": "20",
            "SAGEMAKER_REGION": region,
        },
    },
)
model_arn = create_model_response["ModelArn"]

print(f"Created Model: {model_arn}")

Now we create the endpoint configuration with asynchronous inference enabled. We can optionally provide SNS topics to receive a completion notification for each inference request.

In [None]:
endpoint_config_name = f"{model_name}-config"
endpoint_name = f"{model_name}-endpoint"

endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        {
            "VariantName": "variant1",
            "ModelName": model_name,
            "InstanceType": "ml.m5.xlarge",
            "InitialInstanceCount": 1,
        }
    ],
    AsyncInferenceConfig={
        "OutputConfig": {
            "S3OutputPath": f"s3://{bucket}/{key_prefix}/async-output",
            # Optionally specify Amazon SNS topics
            # "NotificationConfig": {
            #   "SuccessTopic": "arn:aws:sns:us-east-2:123456789012:MyTopic",
            #   "ErrorTopic": "arn:aws:sns:us-east-2:123456789012:MyTopic",
            # }
        },
        "ClientConfig": {"MaxConcurrentInvocationsPerInstance": 4},
    },
)

In [None]:
create_endpoint_response = sm_client.create_endpoint(
    EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name
)
print(f"Created Endpoint: {create_endpoint_response['EndpointArn']}")

Note: if your endpoint fails to create, you may need to add ECR permissions such as `AmazonEC2ContainerRegistryReadOnly` to your execution role.

In [None]:
waiter = sm_client.get_waiter("endpoint_in_service")
print("Waiting for endpoint to create...")
waiter.wait(EndpointName=endpoint_name)

resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
print(f"Endpoint Status: {resp['EndpointStatus']}")

## Test the endpoint

We'll upload some sample articles from the `data` directory to test our endpoint. Feel free to add your own text documents as well.

In [None]:
!aws s3 cp --recursive ./data s3://{bucket}/{key_prefix}

In [None]:
# invoke the endpoint with one of the uploaded documents
response = smr_client.invoke_endpoint_async(
    EndpointName=endpoint_name, 
    InputLocation=f"s3://{bucket}/{key_prefix}/article1.txt",
    ContentType="text/csv"
)
output_key = "/".join(response["OutputLocation"].split("/")[3:])

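The `split("/")[3:]` expression above relies on the output location always having the form `s3://bucket/key`. A slightly more defensive alternative is to parse the URI with the standard library, as sketched below. `split_s3_uri` is a hypothetical helper, and the example URI is made up for illustration.

```python
from urllib.parse import urlparse


def split_s3_uri(uri):
    """Split an s3://bucket/key URI into (bucket, key)."""
    parsed = urlparse(uri)
    if parsed.scheme != "s3" or not parsed.netloc:
        raise ValueError(f"not an S3 URI: {uri!r}")
    # The path keeps its leading slash, which is not part of the object key.
    return parsed.netloc, parsed.path.lstrip("/")


# Made-up output location for illustration only.
example_bucket, example_key = split_s3_uri(
    "s3://example-bucket/prefix/async-output/result.out"
)
print(example_bucket, example_key)
```

With this helper, the bucket and key could both be taken from `response["OutputLocation"]` rather than assuming the output lands in the default bucket.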
In [None]:
# wait until the results become available

from botocore.exceptions import ClientError

result = None      # stays None if the poll times out or fails
total_sleep = 0

while True:
    try:
        # check if result is ready
        s3_client.head_object(Bucket=bucket, Key=output_key)
        result = sess.read_s3_file(bucket=bucket, key_prefix=output_key)
        print("Results are ready")
        break
    except ClientError as e:
        if e.response["Error"]["Code"] == "404":
            if total_sleep < 60:
                print("Results are not yet ready. Sleeping for 5s")
                time.sleep(5)
                total_sleep += 5
                continue
            else:
                print("Been waiting for 60s, terminating the poll. Check the endpoint logs to see if there are any issues")
                break
        else:
            # any other error code is unexpected, so stop polling
            print(f"Unexpected error encountered. Please review the CloudWatch logs for {endpoint_name}")
            break

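The polling pattern used here (check, sleep, give up after a deadline) can be factored into a small reusable function. The sketch below is one possible shape, not the notebook's own implementation: `poll_until` and its parameters are hypothetical names, and the `sleep` argument exists only so the demo can run instantly.

```python
import time


def poll_until(check, timeout=60, interval=5, sleep=time.sleep):
    """Call check() until it returns a non-None value or `timeout` seconds pass."""
    waited = 0
    while True:
        value = check()
        if value is not None:
            return value
        if waited >= timeout:
            raise TimeoutError(f"no result after {waited}s")
        sleep(interval)
        waited += interval


# Demo with a fake check that succeeds on the third call; sleeping is stubbed out.
attempts = iter([None, None, "summary ready"])
outcome = poll_until(lambda: next(attempts), timeout=60, interval=5, sleep=lambda s: None)
print(outcome)
```

In this notebook, `check` would wrap the `head_object` call and return `None` when S3 responds with a 404, letting any other `ClientError` propagate.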
In [None]:
print(result)

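Because `transform_fn` serializes the pipeline output with `json.dumps`, the printed result should be a JSON string rather than plain text. The sketch below shows one way to pull out the summary, assuming the payload carries the `summary_text` key that Hugging Face summarization pipelines produce. `extract_summary` is a hypothetical helper and the sample payload is invented.

```python
import json


def extract_summary(raw_output):
    """Parse the endpoint's JSON output and return the summary text."""
    payload = json.loads(raw_output)
    # Assumes the payload has the summarization pipeline's "summary_text" key.
    return payload["summary_text"].strip()


# Invented payload with the assumed shape, for illustration only.
sample_output = json.dumps({"summary_text": " A short summary of the article. "})
print(extract_summary(sample_output))
```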
## Clean Up

In [None]:
sm_client.delete_endpoint(EndpointName=endpoint_name)
sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
sm_client.delete_model(ModelName=model_name)
! rm model.tar.gz
! aws s3 rm s3://{bucket}/{key_prefix} --recursive > /dev/null