In [None]:
# Automatically reload edited local modules before each cell executes.
%reload_ext autoreload
%autoreload 2

# Render matplotlib figures inline in the notebook.
%matplotlib inline

# TorchServe

In [None]:
# Print (syntax-highlighted) the custom handler that is packaged into the .mar below.
!pygmentize ../deployment/handler.py

In [None]:
%%bash
# Fail fast: otherwise a failed `cp` or `torch-model-archiver` is masked by the final
# `ls` succeeding, and a stale twin.mar would silently be packaged in a later cell.
set -euo pipefail
cd ../

# Copy the selected checkpoints to the generic names the archive is built from.
cp model/resnet50_0.962_head.pth model/head_weight.pth
cp model/resnet50_0.962_encoder.pth model/encoder_weight.pth

# Make sure the export directory exists (no-op if it already does).
mkdir -p model_store

# Package the encoder weights (serialized file), the head weights (extra file) and the
# custom handler into model_store/twin.mar; -f overwrites an existing archive.
torch-model-archiver --model-name twin \
--version 1.0 --serialized-file ./model/encoder_weight.pth \
--export-path model_store --handler ./deployment/handler.py \
-f --extra-files ./model/head_weight.pth

ls -lh ./model_store/

# Amazon SageMaker

## Boilerplate

In [None]:
# Uncomment to install the AWS SDKs into this kernel's environment
# (%pip targets the running kernel; pin versions for reproducibility):
# %pip install boto3
# %pip install sagemaker

In [None]:
import ast
import time

import boto3
import requests

### Session

In [None]:
# boto3 session: its region, plus a SageMaker API client used for all create/describe calls.
sess = boto3.Session()
region = sess.region_name
sm = sess.client("sagemaker")

# Account ID of the current credentials (used to build the ECR image URI).
caller_identity = boto3.client("sts").get_caller_identity()
account = caller_identity.get("Account")

region, account

### IAM Role

**Note**: make sure the IAM execution role has these managed policies attached:
- `AmazonS3FullAccess`
- `AmazonEC2ContainerRegistryFullAccess`
- `AmazonSageMakerFullAccess`

In [None]:
import sagemaker

# IAM role ARN passed to create_model as ExecutionRoleArn.
# NOTE(review): get_execution_role() resolves the role of the current SageMaker
# notebook/Studio environment; when running elsewhere, set `role` to a role ARN explicitly.
role = sagemaker.get_execution_role()
role

## Amazon Elastic Container Registry (ECR)

**Note**: create the ECR repository if it doesn't exist yet.

In [None]:
registry_name = "twin-pytorch"
# ECR image URI: <account>.dkr.ecr.<region>.amazonaws.com/<repository>:<tag>
ecr_registry = f"{account}.dkr.ecr.{region}.amazonaws.com"
image = f"{ecr_registry}/{registry_name}:latest"
image

In [None]:
!aws ecr create-repository --repository-name {registry_name} --region {region}

### PyTorch Model Artifact

Amazon SageMaker expects the model artifact as a compressed `*.tar.gz` archive, so package the `*.mar` file into one and upload it to your Amazon S3 bucket.

In [None]:
# Base name of the archive: model_store/<model_file_name>.mar -> <model_file_name>.tar.gz
model_file_name = "twin"

# S3 bucket for the model artifact. Leave empty to use the SageMaker SDK default bucket
# (sagemaker-<region>-<account>, created on first use); an empty name would otherwise
# produce the invalid URI "s3:///".
s3_bucket_name = ""
if not s3_bucket_name:
    s3_bucket_name = sagemaker.Session().default_bucket()
s3_bucket_name

In [None]:
%%bash -s "$model_file_name" "$s3_bucket_name"
# $1 = archive base name, $2 = S3 bucket name.
set -euo pipefail
# Fail with a clear message instead of uploading to the invalid URI "s3:///".
: "${2:?s3_bucket_name is empty; set it in the previous cell}"
cd ../model_store/
# SageMaker expects the model artifact as a .tar.gz, so wrap the .mar in one.
tar cvfz "$1.tar.gz" "$1.mar"
aws s3 cp "$1.tar.gz" "s3://$2/"

### Build the TorchServe Docker container and push it to Amazon ECR

In [None]:
# Authenticate the local Docker client against this account's ECR registry; the
# password is piped through stdin rather than passed as a command-line argument.
!aws ecr get-login-password --region {region} | docker login --username AWS --password-stdin {account}.dkr.ecr.{region}.amazonaws.com

In [None]:
%%bash -s "$registry_name" "$image"
# $1 = local image tag, $2 = full ECR image URI.
# Stop at the first failure: otherwise a failed build could still tag and push an
# older local image carrying the same tag as `latest`.
set -euo pipefail
cd ../
docker build -t "$1" .
docker tag "$1" "$2"
docker push "$2"

### SageMaker Model

In [None]:
model_data = f"s3://{s3_bucket_name}/{model_file_name}.tar.gz"
# Timestamp suffix (like the endpoint config and endpoint names) so re-running the
# notebook does not fail because a model with this name already exists.
sm_model_name = "torchserve-twin-v1-" + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

# TorchServe image pushed to ECR plus the .tar.gz artifact uploaded to S3.
container = {"Image": image, "ModelDataUrl": model_data}

create_model_response = sm.create_model(
    ModelName=sm_model_name, ExecutionRoleArn=role, PrimaryContainer=container
)

print(create_model_response["ModelArn"])

## Inference Endpoint

Create an endpoint configuration that sets the **InstanceType** and, optionally, enables **Model Monitoring** data capture (disabled by default).

In [None]:
# Choose your preferred instance type.
instance_type = "ml.g4dn.xlarge"

timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())
endpoint_config_name = f"torchserve-endpoint-config-{timestamp}"
print(endpoint_config_name)

# A single variant that receives all traffic, served by one instance.
production_variant = {
    "VariantName": "AllTraffic",
    "ModelName": sm_model_name,
    "InstanceType": instance_type,
    "InitialInstanceCount": 1,
    "InitialVariantWeight": 1,
}

# Model Monitoring (disabled): to capture requests/responses to S3, uncomment and pass
# DataCaptureConfig=data_capture_config to create_endpoint_config below.
# data_capture_config = {
#     "EnableCapture": True,
#     "InitialSamplingPercentage": 100,
#     "DestinationS3Uri": f"s3://{s3_bucket_name}/monitor/",
#     "CaptureOptions": [{"CaptureMode": "Input"}, {"CaptureMode": "Output"}],
# }

create_endpoint_config_response = sm.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[production_variant],
)

endpoint_config_arn = create_endpoint_config_response["EndpointConfigArn"]
print("Endpoint Config Arn: " + endpoint_config_arn)

### Endpoint

In [None]:
# Timestamped endpoint name so each run creates a distinct endpoint.
endpoint_name = "torchserve-endpoint-{}".format(
    time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())
)
print(endpoint_name)

create_endpoint_response = sm.create_endpoint(
    EndpointName=endpoint_name,
    EndpointConfigName=endpoint_config_name,
)
endpoint_arn = create_endpoint_response["EndpointArn"]
print(endpoint_arn)

In [None]:
%%time
resp = sm.describe_endpoint(EndpointName=endpoint_name)
status = resp["EndpointStatus"]
print("Status: " + status)

while status == "Creating":
 time.sleep(60)
 resp = sm.describe_endpoint(EndpointName=endpoint_name)
 status = resp["EndpointStatus"]
 print("Status: " + status)

print("Arn: " + resp["EndpointArn"])
print("Status: " + status)

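### Testing

`requests` is used only to build the `multipart/form-data` payload and its `Content-Type` header; the `localhost` URL is never contacted. The payload is then sent to the endpoint through the SageMaker runtime client.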

In [None]:
# Sent to the handler as the `cam` form field; the plotting cell below only draws
# CAM maps when this is True.
cam = True

# requests is used only to encode the multipart/form-data body and its Content-Type
# header; the URL is never contacted. The file contents are read when the request is
# prepared, so the image files can be closed right afterwards.
with open("../sample/c1.jpg", "rb") as left_file, open("../sample/c3.jpg", "rb") as right_file:
    r = requests.Request(
        "POST",
        "http://localhost:8080/invocations",
        files={"left": left_file, "right": right_file},
        data={"cam": str(cam)},
    ).prepare()

content_type = r.headers["Content-Type"]
payload = r.body
content_type, type(payload)

In [None]:
client = boto3.client("sagemaker-runtime")
response = client.invoke_endpoint(
    EndpointName=endpoint_name, ContentType=content_type, Body=payload
)
res = response["Body"].read()

# Parse the response as a Python literal. Unlike eval(), ast.literal_eval cannot execute
# code embedded in a (possibly tampered) response body.
# NOTE(review): assumes the body is a list literal [neg, pos, *cam_values] — confirm
# against deployment/handler.py.
neg, pos, *maps = ast.literal_eval(res.decode("utf-8"))
neg, pos

In [None]:
import torch
import matplotlib.pyplot as plt


if cam:
    # The first half of `maps` is used as the left map and the second half as the right.
    # NOTE(review): assumes each half is a list of equal-length rows (a 2-D map) —
    # confirm against deployment/handler.py.
    half = len(maps) // 2
    cam_map_left, cam_map_right = maps[:half], maps[half:]

    cam_map_left = torch.tensor(cam_map_left)
    cam_map_right = torch.tensor(cam_map_right)

    # Identical rendering for both panels; titles follow the multipart field names.
    _, axes = plt.subplots(1, 2, figsize=(10, 5))
    for ax, cam_map, title in zip(axes, (cam_map_left, cam_map_right), ("left", "right")):
        ax.imshow(
            cam_map,
            alpha=0.6,
            extent=(0, 224, 224, 0),  # 224x224 pixel grid with the origin at the top-left
            interpolation="bilinear",
            cmap="jet",
        )
        ax.set_title(title)
        ax.axis("off")

    plt.show()

## Cleanup

In [None]:
# Set to True to delete the endpoint, endpoint config and model created above.
# A deployed endpoint keeps its instance running (and billed) until it is deleted.
CLEANUP = False

if CLEANUP:
    # Reuse the SageMaker client from the Session cell rather than rebinding `client`,
    # which holds the sagemaker-runtime client used for inference.
    sm.delete_endpoint(EndpointName=endpoint_name)
    sm.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
    sm.delete_model(ModelName=sm_model_name)