In [None]:
%reload_ext autoreload
%autoreload 2


%matplotlib inline

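# Amazon SageMaker

Deploy a fastai U-Net model, packaged as a TorchServe `.mar` archive, on Amazon SageMaker. The steps are: build and push the serving container to Amazon ECR, run a batch transform job, and then host the model on a real-time inference endpoint.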

In [None]:
import base64
import json
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

## Boilerplate

### Session

In [None]:
import boto3
import json
import time

# One boto3 session provides the region and the SageMaker control-plane client.
sess = boto3.Session()
region = sess.region_name
sm = sess.client("sagemaker")

# The AWS account id is needed to build the ECR image URI further down.
caller_identity = boto3.client("sts").get_caller_identity()
account = caller_identity.get("Account")

### IAM Role

**Note**: make sure the IAM role has the following policies attached:
- `AmazonS3FullAccess` 
- `AmazonEC2ContainerRegistryFullAccess` 
- `AmazonSageMakerFullAccess` 

In [None]:
import sagemaker

# IAM role of the current SageMaker environment; it is passed as
# ExecutionRoleArn when the model is created below.
role = sagemaker.get_execution_role()

# Display the role ARN so it can be checked against the note above.
role

### Amazon Elastic Container Registry (ECR)

**Note**: create the ECR repository if it doesn't already exist.

In [None]:
registry_name = "fastai-torchserve-sagemaker"

# Create the ECR repository only when it is missing, so this cell is safe to
# re-run. Without the repository, the `docker push` further down fails.
ecr = sess.client("ecr")
try:
    ecr.describe_repositories(repositoryNames=[registry_name])
except ecr.exceptions.RepositoryNotFoundException:
    ecr.create_repository(repositoryName=registry_name)

In [None]:
# Fully-qualified ECR URI that the locally built image is tagged and pushed as.
image = "{acct}.dkr.ecr.{reg}.amazonaws.com/{repo}:latest".format(
    acct=account, reg=region, repo=registry_name
)
image

### PyTorch Model Artifact

Amazon SageMaker expects the model artifact as a compressed `*.tar.gz` file. Package the `*.mar` file into a `*.tar.gz` archive and upload it to your Amazon S3 bucket.

In [None]:
model_file_name = "fastunet"

# TODO: set this to an S3 bucket you own. Every S3 URI below (model artifact,
# batch input/output) is built from it.
s3_bucket_name = ""
# Fail fast: an empty name silently yields URIs such as "s3:///fastunet.tar.gz",
# which only error much later inside SageMaker.
assert s3_bucket_name, "Set s3_bucket_name to the S3 bucket that holds the model and batch data"

# Package the TorchServe archive as the .tar.gz SageMaker expects and upload it
# (uncomment to run; requires {model_file_name}.mar in the working directory):
# !tar cvzf {model_file_name}.tar.gz {model_file_name}.mar
# !aws s3 cp {model_file_name}.tar.gz s3://{s3_bucket_name}/

### Build a fastai + TorchServe Docker image and push it to Amazon ECR

In [None]:
!aws ecr get-login-password --region {region} | docker login --username AWS --password-stdin {account}.dkr.ecr.{region}.amazonaws.com
!docker build -t {registry_name} ../
!docker tag {registry_name}:latest {image}
!docker push {image}

### Model

In [None]:
model_data = "s3://{}/{}.tar.gz".format(s3_bucket_name, model_file_name)
sm_model_name = "fastai-unet-torchserve-sagemaker"

# Register the model in SageMaker as the custom serving image plus the model
# archive uploaded to S3.
container = dict(Image=image, ModelDataUrl=model_data)
create_model_response = sm.create_model(
    ModelName=sm_model_name,
    ExecutionRoleArn=role,
    PrimaryContainer=container,
)

print(create_model_response["ModelArn"])

## Batch Transform

### S3 Input and Output

In [None]:
# The transform job reads inputs from one prefix and writes results to a
# sibling prefix with an "_output" suffix in the same bucket.
batch_prefix = "batch_transform_fastai_torchserve_sagemaker"
batch_input = "s3://" + s3_bucket_name + "/" + batch_prefix + "/"
batch_output = "s3://" + s3_bucket_name + "/" + batch_prefix + "_output/"

In [None]:
# List the objects under the batch input prefix, which the transform job reads from.
!aws s3 ls {batch_input}

In [None]:
# A UTC timestamp suffix keeps transform job names unique across runs. The
# "<prefix>-<timestamp>" shape matches the endpoint config and endpoint names.
batch_job_name = "fastunet-batch-" + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())
batch_job_name

### Batch transform jobs

In [None]:
# Arguments for sm.create_transform_job: run the model over every object under
# the `batch_input` prefix and write the results under `batch_output`.
model_client_config = {
    "InvocationsTimeoutInSeconds": 3600,  # timeout for each invocation request
    "InvocationsMaxRetries": 1,
}
transform_output = {"S3OutputPath": batch_output, "AssembleWith": "Line"}
transform_input = {
    "DataSource": {
        "S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": batch_input},
    },
    "CompressionType": "None",
}
# NOTE(review): ml.p2.xlarge is an older GPU family that is not offered in every
# region, while the endpoint below uses ml.g4dn.xlarge. Confirm availability.
transform_resources = {"InstanceType": "ml.p2.xlarge", "InstanceCount": 1}

request = {
    "ModelClientConfig": model_client_config,
    "TransformJobName": batch_job_name,
    "ModelName": sm_model_name,
    "BatchStrategy": "MultiRecord",
    "TransformOutput": transform_output,
    "TransformInput": transform_input,
    "TransformResources": transform_resources,
}

In [None]:
%%time
sm.create_transform_job(**request)

while True:
 response = sm.describe_transform_job(TransformJobName=batch_job_name)
 status = response["TransformJobStatus"]
 if status == "Completed":
 print("Transform job ended with status: " + status)
 break
 if status == "Failed":
 message = response["FailureReason"]
 print("Transform failed with the following error: {}".format(message))
 raise Exception("Transform job failed")
 print("Transform job is still in status: " + status)
 time.sleep(30)

### Testing

In [None]:
# Fetch one batch prediction and the image it was computed from, for a visual check.
s3 = boto3.resource("s3")
bucket = s3.Bucket(s3_bucket_name)
downloads = [
    (
        "batch_transform_fastai_torchserve_sagemaker_output/street_view_of_a_small_neighborhood.png.out",
        "street_view_of_a_small_neighborhood.txt",
    ),
    (
        "batch_transform_fastai_torchserve_sagemaker/street_view_of_a_small_neighborhood.png",
        "street_view_of_a_small_neighborhood.png",
    ),
]
for key, local_path in downloads:
    bucket.download_file(key, local_path)

In [None]:
# Parse the JSON prediction that the batch transform job wrote for this image.
with open("street_view_of_a_small_neighborhood.txt") as prediction_file:
    response = json.load(prediction_file)

In [None]:
def decode_prediction(base64_prediction, shape=(96, 128)):
    """Decode the model's base64-encoded uint8 prediction into a 2-D array.

    Args:
        base64_prediction: base64 text from the ``"base64_prediction"`` field
            of the model's JSON response.
        shape: (rows, cols) to reshape the decoded bytes into. The default is
            the size this notebook has always used; TODO confirm against the model.

    Returns:
        numpy.ndarray of dtype uint8 with the given ``shape``.

    Raises:
        ValueError: if the number of decoded bytes does not match ``shape``.
    """
    raw = base64.b64decode(base64_prediction)
    expected = int(np.prod(shape))
    if len(raw) != expected:
        # Clearer than numpy's generic "cannot reshape array" message.
        raise ValueError(
            f"prediction has {len(raw)} bytes, expected {expected} for shape {tuple(shape)}"
        )
    return np.frombuffer(raw, dtype=np.uint8).reshape(shape)


pred_decoded = decode_prediction(response["base64_prediction"])
fig, ax = plt.subplots()
ax.imshow(pred_decoded)
ax.set_title("Batch transform prediction");

## Inference Endpoint

### Endpoint configuration

**Note**: choose your preferred `InstanceType` (see [Amazon SageMaker pricing](https://aws.amazon.com/sagemaker/pricing/)).

In [None]:
import time

# A timestamped name lets each run create a fresh endpoint configuration.
endpoint_config_name = "torchserve-endpoint-config-{}".format(
    time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())
)
print(endpoint_config_name)

# A single production variant with weight 1 serves all traffic from one instance.
production_variant = {
    "VariantName": "AllTraffic",
    "ModelName": sm_model_name,
    "InstanceType": "ml.g4dn.xlarge",
    "InitialInstanceCount": 1,
    "InitialVariantWeight": 1,
}
create_endpoint_config_response = sm.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[production_variant],
)

print("Endpoint Config Arn: " + create_endpoint_config_response["EndpointConfigArn"])

### Endpoint

In [None]:
timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())
endpoint_name = "fastunet-torchserve-endpoint-" + timestamp
print(endpoint_name)

# Start provisioning the real-time endpoint. create_endpoint returns before the
# endpoint is ready; the next cell polls its status.
create_endpoint_response = sm.create_endpoint(
    EndpointName=endpoint_name,
    EndpointConfigName=endpoint_config_name,
)
print(create_endpoint_response["EndpointArn"])

In [None]:
%%time
resp = sm.describe_endpoint(EndpointName=endpoint_name)
status = resp["EndpointStatus"]
print("Status: " + status)

while status == "Creating":
 time.sleep(60)
 resp = sm.describe_endpoint(EndpointName=endpoint_name)
 status = resp["EndpointStatus"]
 print("Status: " + status)

print("Arn: " + resp["EndpointArn"])
print("Status: " + status)

### Testing

In [None]:
file_name = "../sample/street_view_of_a_small_neighborhood.png"

# The raw file bytes are sent unchanged as the request body in the next cell.
with open(file_name, "rb") as image_file:
    payload = image_file.read()

# Show the input image (the last expression is rendered by Jupyter).
Image.open(file_name)

In [None]:
%%time
client = boto3.client("runtime.sagemaker")
response = client.invoke_endpoint(
 EndpointName=endpoint_name, ContentType="application/x-image", Body=payload
)
response = json.loads(response["Body"].read())

In [None]:
# The endpoint returns its prediction as base64 text. Decode it to uint8 values
# laid out as a 96 x 128 array and display it.
encoded_prediction = response["base64_prediction"].encode("utf-8")
pred_decoded_byte = base64.decodebytes(encoded_prediction)
pred_decoded = np.frombuffer(pred_decoded_byte, dtype=np.uint8).reshape(96, 128)
plt.imshow(pred_decoded);

### Cleanup

In [None]:
# Tear down the hosted resources in reverse order of creation: the endpoint
# uses the endpoint config, which in turn references the model.
# Reuse the `sm` SageMaker client from the Session cell rather than creating
# another client and rebinding a generic name.
sm.delete_endpoint(EndpointName=endpoint_name)
sm.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
sm.delete_model(ModelName=sm_model_name)