# FairMOT Model Inference in Amazon SageMaker

This notebook demonstrates how to create an endpoint for real-time inference with the trained FairMOT model.

## 1. SageMaker Initialization
First, we upgrade the SageMaker Python SDK to the latest version. If your notebook already uses the latest SageMaker 2.x API, you may skip the next cell.

In [None]:
! pip install --upgrade pip
! python3 -m pip install --upgrade sagemaker

In [None]:
import boto3
import sagemaker
from sagemaker import get_execution_role

role = get_execution_role()  # provide a pre-existing role ARN as an alternative to creating a new role
print(f"SageMaker Execution Role:{role}")

client = boto3.client('sts')
account = client.get_caller_identity()['Account']
print(f'AWS account:{account}')

session = boto3.session.Session()
aws_region = session.region_name
print(f"AWS region:{aws_region}")

container_name = "container-serving"

## 2. Build and Push Amazon SageMaker Serving Container Images

For this step, the [IAM Role](https://docs.aws.amazon.com/IAM/latest/UserGuide/id_roles.html) attached to this notebook instance needs full access to [Amazon ECR service](https://aws.amazon.com/ecr/). We use the implementation of [FairMOT](https://github.com/ifzhang/FairMOT) to create our own container.

### 2.1 Docker Environment Preparation

The container image may be larger than the free space on the root volume of the notebook instance, so we keep Docker's data under `/home/ec2-user/SageMaker/docker` instead.

By default, Docker's root directory is `/var/lib/docker/`. We need to change it to `/home/ec2-user/SageMaker/docker`.

In [None]:
!bash ./prepare-docker.sh
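
To see why the move matters, it can help to compare the free space on the two locations. The snippet below is a minimal sketch, not part of the original notebook: `free_gib` is a hypothetical helper, and on a machine where these paths do not exist it simply reports the nearest existing parent directory.

```python
import shutil
from pathlib import Path


def free_gib(path):
    """Return free space in GiB on the filesystem holding `path`.

    Walks up to the nearest existing parent so the call also works
    before the target directory has been created.
    """
    p = Path(path)
    while not p.exists():
        p = p.parent
    return shutil.disk_usage(p).free / 2**30


for label, path in [
    ("default Docker root", "/var/lib/docker"),
    ("notebook volume", "/home/ec2-user/SageMaker/docker"),
]:
    print(f"{label:>20}: {free_gib(path):.1f} GiB free ({path})")
```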

### 2.2 Build and Push FairMOT Serving Container Image

Use the [`./container-serving/build_tools/build_and_push.sh`](./container-serving/build_tools/build_and_push.sh) script to build the [FairMOT](https://github.com/ifzhang/FairMOT) serving container image and push it to Amazon ECR.

In [None]:
!cat ./{container_name}/build_tools/build_and_push.sh

Run the cell below, which passes your *AWS region* to the script as its argument.

In [None]:
%%time
! ./{container_name}/build_tools/build_and_push.sh {aws_region}

In [None]:
fairmot_image = f"{account}.dkr.ecr.{aws_region}.amazonaws.com/fairmot-sagemaker:pytorch1.8-serving"
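
The image URI above follows the usual ECR pattern of account, region, repository and tag. As a small illustration, the hypothetical helper below (not part of the original notebook) composes the same string; the account ID and region in the example are placeholders.

```python
def ecr_image_uri(account, region, repository, tag):
    """Compose an ECR image URI in the same form as `fairmot_image`."""
    return f"{account}.dkr.ecr.{region}.amazonaws.com/{repository}:{tag}"


# Placeholder account ID and region, for illustration only.
example_uri = ecr_image_uri(
    "123456789012", "us-east-1", "fairmot-sagemaker", "pytorch1.8-serving"
)
print(example_uri)
```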

## 3. Create Inference Endpoint

### 3.1 Define Amazon SageMaker Model
Next, we define an Amazon SageMaker Model that we will serve from an Amazon SageMaker Endpoint. 

In [None]:
model_name = "fairmot-model-1" # set the name of the model, like fairmot-model-1

Once the training job (launched in [`fairmot-training.ipynb`](fairmot-training.ipynb)) has finished, you can find the S3 URI of the trained model in the training job console. Set `ModelDataUrl` to that URI.

In [None]:
# Restore the S3 URI of the trained model
%store -r s3_model_uri
# Alternatively, define the model URI manually:
# s3_model_uri = "s3://{bucket_name}/{prefix_model}/model.tar.gz"

serving_container_def = {
    "Image": fairmot_image,
    "ModelDataUrl": s3_model_uri,
    "Mode": "SingleModel",
    "Environment": {
        "SM_MODEL_DIR": "/opt/ml/model",
    },
}

sagemaker_session = sagemaker.session.Session(boto_session=session)
create_model_response = sagemaker_session.create_model(
    name=model_name,
    role=role,
    container_defs=serving_container_def,
)
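
If the stored `s3_model_uri` is not available (for example, in a fresh kernel), the URI can also be read from the training job description instead of the console. The following is a sketch under assumptions: `model_uri_from_training_job` is a hypothetical helper, `sm_client` is expected to behave like `boto3.client("sagemaker")`, and the job name in the usage comment is a placeholder.

```python
def model_uri_from_training_job(sm_client, training_job_name):
    """Return the S3 URI of the model artifact produced by a training job.

    `sm_client` is expected to behave like boto3.client("sagemaker").
    """
    description = sm_client.describe_training_job(TrainingJobName=training_job_name)
    return description["ModelArtifacts"]["S3ModelArtifacts"]


# Usage in the notebook (the job name is a hypothetical placeholder):
#   import boto3
#   s3_model_uri = model_uri_from_training_job(
#       boto3.client("sagemaker"), "my-fairmot-training-job"
#   )
```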

### 3.2 Create Endpoint Configuration
Next, we set the name of the Amazon SageMaker hosted service endpoint configuration.

In [None]:
endpoint_config_name = f"{model_name}-endpoint-config"
print(endpoint_config_name)

Then create the Amazon SageMaker hosted service endpoint configuration that uses one instance of `ml.p3.2xlarge` to serve the model.

In [None]:
epc = sagemaker_session.create_endpoint_config(
    name=endpoint_config_name,
    model_name=model_name,
    initial_instance_count=1,
    instance_type="ml.p3.2xlarge",
)
print(epc)
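
For readers working with the low-level boto3 client rather than the SageMaker session, the same configuration can be expressed as a `CreateEndpointConfig` request. This is only a sketch: `endpoint_config_request` is a hypothetical helper, and the variant name `AllTraffic` is an assumption chosen for illustration.

```python
def endpoint_config_request(config_name, model_name,
                            instance_type="ml.p3.2xlarge", instance_count=1):
    """Build keyword arguments for the boto3 `create_endpoint_config` call."""
    return {
        "EndpointConfigName": config_name,
        "ProductionVariants": [
            {
                "VariantName": "AllTraffic",  # assumed name, any valid name works
                "ModelName": model_name,
                "InitialInstanceCount": instance_count,
                "InstanceType": instance_type,
            }
        ],
    }


request = endpoint_config_request("fairmot-model-1-endpoint-config", "fairmot-model-1")
print(request)

# With AWS credentials available, the request could be sent as:
#   import boto3
#   boto3.client("sagemaker").create_endpoint_config(**request)
```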

Next, we specify the name of the Amazon SageMaker endpoint used to serve the model.

In [None]:
endpoint_name = f"{model_name}-endpoint"
print(endpoint_name)

### 3.3 Create Endpoint
In this step, we create the Amazon SageMaker endpoint using the endpoint configuration we created above.

In [None]:
ep = sagemaker_session.create_endpoint(
    endpoint_name=endpoint_name,
    config_name=endpoint_config_name,
    wait=True,
)
print(ep)
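
With `wait=True`, the call above blocks until the endpoint is ready. If the endpoint is created elsewhere, or the kernel is restarted while it is still being created, its status can be polled separately. The sketch below assumes `sm_client` behaves like `boto3.client("sagemaker")`; `wait_for_endpoint` and its polling defaults are hypothetical.

```python
import time


def wait_for_endpoint(sm_client, endpoint_name, poll_seconds=30, timeout_seconds=1800):
    """Poll the endpoint until it is InService or Failed, then return the status."""
    deadline = time.monotonic() + timeout_seconds
    while True:
        status = sm_client.describe_endpoint(EndpointName=endpoint_name)["EndpointStatus"]
        if status in ("InService", "Failed"):
            return status
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Endpoint {endpoint_name} is still {status}")
        time.sleep(poll_seconds)


# Usage in the notebook:
#   import boto3
#   print(wait_for_endpoint(boto3.client("sagemaker"), endpoint_name))
```

boto3 also ships an `endpoint_in_service` waiter on the SageMaker client, which may be preferable to a hand-written loop.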

## 4. Test Endpoint
### 4.1 Visualization Helper Functions
Draw the bounding box and ID for each tracked object in the raw frames.

In [None]:
import cv2


def get_color(idx):
    # Derive a repeatable color from the track ID
    idx = idx * 3
    color = ((37 * idx) % 255, (17 * idx) % 255, (29 * idx) % 255)

    return color


def draw_res(tracker_dict, frame, frame_id, image_w):
    # Scale the text and line widths with the frame width
    text_scale = max(1, image_w / 1600.0)
    text_thickness = 1
    line_thickness = max(1, int(image_w / 500.0))

    for track_id, tlwh in tracker_dict.items():
        # tlwh = (top-left x, top-left y, width, height)
        x1, y1, w, h = tlwh
        intbox = tuple(map(int, (x1, y1, x1 + w, y1 + h)))
        color = get_color(abs(int(track_id)))
        cv2.rectangle(frame, intbox[0:2], intbox[2:4], color=color, thickness=line_thickness)
        cv2.putText(frame, str(track_id), (intbox[0], intbox[1] + 30), cv2.FONT_HERSHEY_PLAIN,
                    text_scale, (0, 0, 0), thickness=text_thickness)

    # Stamp the frame number once, in the top-left corner
    cv2.putText(frame, "frame:{}".format(frame_id), (25, 25), cv2.FONT_HERSHEY_SIMPLEX,
                text_scale, (0, 0, 255), 1)
    return frame
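
The two calculations behind the drawing helper can be tried without OpenCV. The sketch below reproduces them in isolation; `tlwh_to_corners` and `track_color` are hypothetical names that mirror the box conversion and the color formula used above.

```python
def tlwh_to_corners(tlwh):
    """Convert (top-left x, top-left y, width, height) to integer corner points."""
    x, y, w, h = tlwh
    return tuple(int(v) for v in (x, y, x + w, y + h))


def track_color(track_id):
    """Map a track ID to a repeatable color triple, as get_color does."""
    idx = abs(int(track_id)) * 3
    return ((37 * idx) % 255, (17 * idx) % 255, (29 * idx) % 255)


print(tlwh_to_corners((10.5, 20.2, 30, 40)))  # (10, 20, 40, 60)
print(track_color(1))                          # (111, 51, 87)
```

Because the color depends only on the ID, the same person keeps the same box color from frame to frame.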

### 4.2 Invoke endpoint
Next, we use [MOT16-03](https://motchallenge.net/sequenceVideos/MOT16-04-raw.webm) from the MOT Challenge to test our endpoint. We create a `datasets` directory in the root directory of this project to hold the test video and the processed result, and then download the video in MP4 format from [FairMOT](https://raw.githubusercontent.com/ifzhang/FairMOT/master/videos/MOT16-03.mp4) into that directory.

In [None]:
!mkdir -p datasets
!wget https://raw.githubusercontent.com/ifzhang/FairMOT/master/videos/MOT16-03.mp4 -O datasets/MOT16-03.mp4

After preparing the test data, we invoke the endpoint to run real-time inference on the test video. The full run takes about 150 seconds.

In [None]:
import base64
import json
import os
import time

import boto3
import cv2

client = boto3.client("sagemaker-runtime")

data_path = "datasets/MOT16-03.mp4"
cap = cv2.VideoCapture(data_path)
frame_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

# Write the annotated frames to datasets/test.mp4 at 25 fps
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
file_path = os.path.join("datasets", "test.mp4")
out = cv2.VideoWriter(file_path, fourcc, 25, (frame_w, frame_h))

processing_time = 0
frame_id = 0

while True:
    ret, frame = cap.read()
    if not ret:
        break

    # Save the frame as a temporary JPEG and base64-encode it for the JSON payload
    frame_path = f"datasets/{frame_id}.jpg"
    cv2.imwrite(frame_path, frame)

    with open(frame_path, "rb") as image_file:
        img_data = base64.b64encode(image_file.read())
    os.remove(frame_path)

    data = {"frame_id": frame_id, "frame_data": img_data.decode("utf-8")}
    if frame_id == 0:
        # The first request also carries the stream settings
        data["frame_w"] = frame_w
        data["frame_h"] = frame_h
        data["batch_size"] = 1  # for multiple streams
    body = json.dumps(data).encode("utf-8")

    request_time = time.time()
    response = client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Accept="application/json",
        Body=body,
    )
    elapsed = time.time() - request_time
    if frame_id > 0:
        # The first request is excluded from the average
        processing_time += elapsed
    print(f"frame-{frame_id} Processing time: {elapsed}")

    data = json.loads(response["Body"].read().decode("utf-8"))
    frame_res = draw_res(data[0], frame, frame_id, frame_w)
    out.write(frame_res)
    frame_id += 1

out.release()
cap.release()

# Frame 0 is not counted in processing_time, so average over the remaining frames
print("average processing time: ", processing_time / max(frame_id - 1, 1))
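
The request body sent for each frame is a small JSON document. The sketch below isolates that step so the payload shape is easier to see; `build_request_body` is a hypothetical helper, and the field names are the ones used in the loop above.

```python
import base64
import json


def build_request_body(frame_id, jpeg_bytes, frame_w=None, frame_h=None, batch_size=1):
    """Encode one JPEG frame as the JSON body expected by the endpoint."""
    payload = {
        "frame_id": frame_id,
        "frame_data": base64.b64encode(jpeg_bytes).decode("utf-8"),
    }
    if frame_id == 0:
        # Only the first request carries the frame size and batch size
        payload.update(frame_w=frame_w, frame_h=frame_h, batch_size=batch_size)
    return json.dumps(payload).encode("utf-8")


# Dummy bytes stand in for real JPEG data here.
first = json.loads(build_request_body(0, b"abc", frame_w=1920, frame_h=1080))
later = json.loads(build_request_body(1, b"abc"))
print(sorted(first), sorted(later))
```

One possible refinement is to encode frames in memory with `cv2.imencode(".jpg", frame)` and pass the resulting buffer's `.tobytes()` as `jpeg_bytes`, which would avoid writing and deleting a temporary file per frame.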

The response from the endpoint includes the bounding box and ID of each tracked person. The processed video is saved as `datasets/test.mp4`; download it from the notebook instance to check the result locally.
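
To work with the tracking results directly rather than through the rendered video, the response body can be parsed into plain Python values. This sketch assumes the format implied by the call `draw_res(data[0], ...)`: a JSON list whose first element maps each track ID to `[x, y, w, h]`. The helper name `parse_tracks` and the sample body are hypothetical.

```python
import json


def parse_tracks(response_body):
    """Return {track_id: (x, y, w, h)} from an endpoint response body."""
    result = json.loads(response_body)
    tracks = result[0]  # assumed: first element holds the tracks of the stream
    return {
        int(track_id): tuple(float(v) for v in tlwh)
        for track_id, tlwh in tracks.items()
    }


# Hand-written sample in the assumed format, not real endpoint output.
sample_body = b'[{"1": [10, 20, 30, 40], "2": [5.5, 6.5, 7, 8]}]'
tracks = parse_tracks(sample_body)
for track_id, (x, y, w, h) in tracks.items():
    print(f"person {track_id}: top-left=({x}, {y}), size={w}x{h}")
```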

## 5. Delete SageMaker Endpoint, Endpoint Config and Model
**If you are done testing, delete the deployed Amazon SageMaker endpoint, endpoint config, and the model below.**

In [None]:
sagemaker_session.delete_endpoint(endpoint_name=endpoint_name)
sagemaker_session.delete_endpoint_config(endpoint_config_name=endpoint_config_name)
sagemaker_session.delete_model(model_name=model_name)