# ByteTrack Inference with Amazon SageMaker

This notebook demonstrates how to create an asynchronous inference endpoint with the trained ByteTrack model.

## 1. SageMaker Initialization 
First we upgrade SageMaker to the latest version. If your notebook is already using the latest SageMaker 2.x API, you may skip the next cell.

In [None]:
! pip install --upgrade pip
! python3 -m pip install --upgrade sagemaker
! pip install cython_bbox

In [None]:
import boto3
import json
import time
import numpy as np

import sagemaker
from sagemaker import get_execution_role
from sagemaker.pytorch.model import PyTorchModel

from time import strftime,gmtime

# IAM role the notebook runs under; a pre-existing role ARN can be assigned here instead.
role = get_execution_role()
print(f"SageMaker Execution Role:{role}")

# Ask STS which AWS account the current credentials belong to.
client = boto3.client("sts")
account = client.get_caller_identity()["Account"]
print(f'AWS account:{account}')

# SageMaker SDK session (reused by later cells) and a plain boto3 session
# from which the configured region name is read.
sm_session = sagemaker.session.Session()
session = boto3.session.Session()
aws_region = session.region_name
print(f"AWS region:{aws_region}")

## 2. Deploy an Asynchronous Inference Endpoint

You need to complete the training job in [bytetrack-training.ipynb](bytetrack-training.ipynb) before running the following steps.

In [None]:
# Restore `s3_model_uri` from IPython's %store database; it must have been
# saved with `%store s3_model_uri` in an earlier session.
%store -r s3_model_uri

#### Prepare model.tar.gz

To use the ByteTrack scripts on the endpoint, we need to put the tracking scripts and the model into the same folder, compress the folder as `model.tar.gz`, and then upload it to the S3 bucket to create a model. The following is the structure of `model.tar.gz`:
<img src="img/async_inference_model.png"></img>

In [None]:
# Download the model archive referenced by s3_model_uri into sagemaker-serving-async/.
!aws s3 cp $s3_model_uri ./sagemaker-serving-async/model.tar.gz
# Unpack it in place; the archive is deleted only if extraction succeeds.
!cd sagemaker-serving-async && tar -xvf model.tar.gz && rm model.tar.gz

In [None]:
%%writefile download_tracking_async_inference.sh
#!/usr/bin/env bash
# Fetch only the `yolox` directory of the ByteTrack repository and copy it,
# together with the local byte_tracker.py, into the serving code folder.

# Abort on the first failing command so that later steps never run from the
# wrong directory or against a partial checkout.
set -euo pipefail

# Remove a checkout left over from a previous run so `git clone` can succeed.
rm -rf ByteTrack

# Sparse, blob-less, shallow clone: only `yolox` is materialized on checkout.
git clone --filter=blob:none --no-checkout --depth 1 --sparse https://github.com/ifzhang/ByteTrack.git
git -C ByteTrack sparse-checkout set yolox
git -C ByteTrack checkout

# Make sure the destination exists before copying into it.
mkdir -p sagemaker-serving-async/code
cp -r ByteTrack/yolox sagemaker-serving-async/code/
cp container-batch-inference/byte_tracker.py sagemaker-serving-async/code/yolox/tracker/

# The clone was created by the current user, so sudo is not required to delete it.
rm -rf ByteTrack

In [None]:
# Execute the download script that was written to disk with %%writefile.
!bash download_tracking_async_inference.sh

In [None]:
# Re-pack the folder contents as model.tar.gz, upload it to s3_model_uri
# (the same URI the archive was downloaded from, so that object is overwritten),
# then delete the local archive.
!cd sagemaker-serving-async && tar -cvzf model.tar.gz * && aws s3 cp model.tar.gz $s3_model_uri && rm model.tar.gz

To handle large video files, we need to explicitly set the maximum payload size and the response timeout through environment variables in `PyTorchModel`.

In [None]:
# Describe the model: the repacked artifact in S3 plus the inference.py entry
# point, served on the PyTorch 1.7.1 / Python 3 container.
pytorch_model = PyTorchModel(
    entry_point="inference.py",
    model_data=s3_model_uri,
    role=role,
    sagemaker_session=sm_session,
    framework_version="1.7.1",
    py_version="py3",
    env={
        # Raise the TorchServe request/response size limits to 1,000,000,000 bytes
        # so that large video payloads are accepted.
        "TS_MAX_REQUEST_SIZE": "1000000000",
        "TS_MAX_RESPONSE_SIZE": "1000000000",
        # Response timeout in seconds (900 s = 15 minutes).
        "TS_DEFAULT_RESPONSE_TIMEOUT": "900",
        # presumably read by inference.py to size the network input - TODO confirm
        "INPUT_WIDTH": "1440",
        "INPUT_HEIGHT": "800",
    },
)

# Create the SageMaker model resource for the given instance type.
pytorch_model.create(instance_type="ml.p3.2xlarge")

In [None]:
# Endpoint-config name suffixed with the current UTC timestamp.
async_endpoint_config_name = f"YoloxAsyncEndpointConfig-{strftime('%Y-%m-%d-%H-%M-%S', gmtime())}"

# NOTE: `<bucket_name>` and `<prefix>` are placeholders, not valid Python.
# Replace each with a quoted string before running this cell.
bucket_name = <bucket_name> # name of the S3 bucket used for inputs and outputs
prefix = <prefix> # key prefix inside that bucket

# S3 location passed to the endpoint config as the asynchronous inference output path.
s3_output_path = f"s3://{bucket_name}/{prefix}/output/async"

# Low-level SageMaker client obtained from the SDK session's boto3 session.
sagemaker_session = sagemaker.Session()
boto_session = sagemaker_session.boto_session
sagemaker_client = boto_session.client('sagemaker')

# Create an endpoint configuration with a single production variant and an
# AsyncInferenceConfig section.
create_endpoint_config_response = sagemaker_client.create_endpoint_config(
    EndpointConfigName=async_endpoint_config_name,
    ProductionVariants=[
        {
            "VariantName": "variant1",
            "ModelName": pytorch_model.name,
            "InstanceType": "ml.p3.2xlarge",
            "InitialInstanceCount": 1
        }
    ],
    AsyncInferenceConfig={
        "OutputConfig": {
            "S3OutputPath": s3_output_path,
            #  Optionally specify Amazon SNS topics for success/error notifications
            #"NotificationConfig": {
            #  "SuccessTopic": success_topic,
            #  "ErrorTopic": error_topic,
            #}
        },
        "ClientConfig": {
            "MaxConcurrentInvocationsPerInstance": 2
        }
    }
)
print(f"Created EndpointConfig: {create_endpoint_config_response['EndpointConfigArn']}")

In [None]:
# Endpoint name suffixed with the current UTC timestamp.
async_endpoint_name = f"bytetrack-{strftime('%Y-%m-%d-%H-%M-%S', gmtime())}"

# Start endpoint creation from the configuration created for it; the call
# returns the endpoint ARN, which is printed below.
create_endpoint_response = sagemaker_client.create_endpoint(
    EndpointConfigName=async_endpoint_config_name,
    EndpointName=async_endpoint_name,
)
print(f"Creating Endpoint: {create_endpoint_response['EndpointArn']}")

In [None]:
# Reuse the existing `sagemaker_client` rather than building a second ad-hoc
# client, so the waiter and describe_endpoint share the same session and region.
waiter = sagemaker_client.get_waiter("endpoint_in_service")
print("Waiting for endpoint to create...")
# Blocks until the endpoint reports InService; botocore raises WaiterError
# if the endpoint creation fails or the waiter runs out of attempts.
waiter.wait(EndpointName=async_endpoint_name)
resp = sagemaker_client.describe_endpoint(EndpointName=async_endpoint_name)
print(f"Endpoint Status: {resp['EndpointStatus']}")

## 3. Test Asynchronous Inference Endpoint

In [None]:
data_path = "datasets/MOT16-03.mp4"
input_s3_path = f"s3://{bucket_name}/{prefix}/inputs/MOT16-03.mp4"
!mkdir datasets
!wget https://raw.githubusercontent.com/ifzhang/FairMOT/master/videos/MOT16-03.mp4 -O $data_path
!aws s3 cp $data_path $input_s3_path

In [None]:
# Runtime client used to invoke the endpoint.
sm_runtime = boto3.Session().client("sagemaker-runtime")

# Submit the request: the input is passed as an S3 location, and the response
# carries the S3 location where the result will be written.
response = sm_runtime.invoke_endpoint_async(
    InputLocation=input_s3_path,
    EndpointName=async_endpoint_name,
)
output_location = response["OutputLocation"]
print(f"OutputLocation: {output_location}")

In [None]:
from botocore.exceptions import ClientError
import urllib
import sys

def get_output(output_location, poll_interval=2, timeout=None):
    """Poll S3 until the asynchronous inference result exists, then return it.

    Args:
        output_location: ``s3://bucket/key`` URI of the expected result object.
        poll_interval: Seconds to sleep between polls. Defaults to 2.
        timeout: Optional maximum number of seconds to wait. ``None`` (the
            default) waits indefinitely.

    Returns:
        Whatever ``sm_session.read_s3_file`` returns for the object.

    Raises:
        TimeoutError: If ``timeout`` is set and the object did not appear in time.
        botocore.exceptions.ClientError: For any S3 error other than ``NoSuchKey``.
    """
    # Import the submodule explicitly: a bare `import urllib` does not
    # guarantee that `urllib.parse` has been loaded.
    from urllib.parse import urlparse

    output_url = urlparse(output_location)
    bucket = output_url.netloc
    key = output_url.path[1:]  # drop the leading "/" of the URL path

    deadline = None if timeout is None else time.monotonic() + timeout
    while True:
        try:
            return sm_session.read_s3_file(bucket=bucket, key_prefix=key)
        except ClientError as e:
            # Only "object not there yet" is retried; everything else propagates.
            if e.response["Error"]["Code"] != "NoSuchKey":
                raise
        # NOTE(review): if the invocation fails, the output object may never be
        # created; pass `timeout` to avoid waiting forever.
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(
                f"No output at {output_location} after {timeout} seconds"
            )
        print("waiting for output...")
        time.sleep(poll_interval)

In [None]:
output = get_output(output_location)
# sys.getsizeof reports the in-memory size of the Python object, not the size of
# the payload. Measure the encoded content instead (text is encoded as UTF-8;
# a bytes result is measured as-is).
print(f"Output size in bytes: {len(output.encode('utf-8') if isinstance(output, str) else output)}")

In [None]:
# Local file the inference result is saved to.
tracking_res = "./datasets/tracking_res.txt"
# Copy the result object from its S3 output location to that local file.
!aws s3 cp $output_location $tracking_res

### Visualize the tracking result

In [None]:
import cv2
import time
import os.path as osp
import os
import io
from yolox.utils.visualize import plot_tracking

# Load the parsed records into their own name. Rebinding `tracking_res` (the
# file path) to the parsed data made this cell fail when run a second time.
with open(tracking_res, "r") as f:
    tracking_records = json.load(f)

# Group the records by frame id. Each record is a comma-separated string whose
# first field is the frame id; the remaining fields are kept as one row.
frame_dict = {}
for record in tracking_records:
    values = [float(field) for field in record.split(",")]
    frame_dict.setdefault(values[0], []).append(values[1:])

In [None]:
from yolox.tracking_utils.timer import Timer

# Open the source video and fail loudly if it cannot be read; otherwise an
# empty output file would be written without any error.
cap = cv2.VideoCapture(data_path)
if not cap.isOpened():
    raise IOError(f"Cannot open video file: {data_path}")

width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)  # float
height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)  # float

fps = cap.get(cv2.CAP_PROP_FPS)
save_path = "datasets/tracking_res.mp4"

# The output video keeps the frame rate and frame size of the input.
vid_writer = cv2.VideoWriter(
    save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (int(width), int(height))
)

frame_id = 0

try:
    while True:
        ret_val, frame = cap.read()
        if not ret_val:
            # No more frames to read.
            break

        if frame_id in frame_dict:
            bboxes = frame_dict[frame_id]
            # Each row is laid out as [track id, then four box values, ...].
            online_tlwhs = [bbox[1:5] for bbox in bboxes]
            online_ids = [bbox[0] for bbox in bboxes]

            # NOTE(review): fps is hard-coded to 25 here rather than taken from
            # the video - confirm what plot_tracking uses it for.
            online_im = plot_tracking(
                frame, online_tlwhs, online_ids, frame_id=frame_id + 1, fps=25
            )
        else:
            # No tracking rows for this frame: write it unchanged.
            online_im = frame

        vid_writer.write(online_im)
        frame_id += 1
finally:
    # Release both handles even if drawing or writing raised, so the output
    # file is finalized and the capture is closed.
    cap.release()
    vid_writer.release()

You can download `datasets/tracking_res.mp4` and check the visualized result.