In [None]:
%load_ext nb_black

from problems.tsp.problem_tsp import TSP
from utils import load_model, move_to
from torch.utils.data import DataLoader
import boto3
import sagemaker

In [None]:
from inference import *

In [None]:
import pandas as pd

In [None]:
from sagemaker.serializers import JSONLinesSerializer
from sagemaker.deserializers import JSONLinesDeserializer

In [None]:
# SageMaker session, used here to look up the session's default S3 bucket.
session = sagemaker.Session()
# Bucket name reused below for the pretrained-model upload and the batch-transform input.
BUCKET = session.default_bucket() # default S3 bucket of the SageMaker session

In [None]:
# Set USE_PRETRAINED_MODEL to False if you have trained a model using pytorch_training.ipynb.
#   True  -> upload the pretrained archive at PRETRAINED_MODEL_PATH to S3 and use it.
#   False -> use the model artifact of the most recent completed training job.
USE_PRETRAINED_MODEL = True

In [None]:
# Local path of the pretrained model archive; it is only uploaded when USE_PRETRAINED_MODEL is True.
PRETRAINED_MODEL_PATH = "../learning-tsp/pretrained/tsp_20-50/rl-ar-var-20pnn-gnn-max_20200313T002243/model.tar.gz"

# 1. Test inference code locally

## Prepare data

In [None]:
# Settings for the small TSP dataset used to test the inference code locally.
dataset_path = None  # passed as `filename` to TSP.make_dataset
batch_size = 1  # used for both TSP.make_dataset and the DataLoader
accumulation_steps = 80  # NOTE(review): not referenced anywhere else in the visible notebook — confirm it is still needed
num_samples = 2 # only 2 samples for this quick test (the earlier note mentioned 1280 samples per TSP size)

# Neighborhood settings forwarded to TSP.make_dataset; `neighbors` is also attached to every request record.
neighbors = 0.20
knn_strat = "percentage"

In [None]:
# Build the TSP dataset for the local test from the settings defined above,
# with min_size == max_size == 10 and supervised=False.
# NOTE(review): presumably min_size/max_size are node counts per instance — confirm against TSP.make_dataset.
dataset = TSP.make_dataset(
 filename=dataset_path,
 batch_size=batch_size,
 num_samples=num_samples,
 min_size=10,
 max_size=10,
 neighbors=neighbors,
 knn_strat=knn_strat,
 supervised=False,
)
# Iterate the dataset in order (shuffle=False) without worker processes (num_workers=0).
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)

In [None]:
# Transform every DataLoader batch into one JSON-serializable request record:
#   "nodes"     -> the batch's "nodes" tensor converted to nested Python lists
#   "neighbors" -> the neighborhood setting configured above
# A single comprehension avoids shadowing the builtin `input` and does not
# leave loop variables behind in the kernel namespace.
data = [
    {"nodes": batch["nodes"].tolist(), "neighbors": neighbors}
    for batch in dataloader
]

In [None]:
# Display the request records built in the previous cell.
data

## Prepare the model

In [None]:
# Getting the latest model data from the training jobs
def get_latest_model():
    """Return the S3 URI of the model artifact of the newest completed training job.

    Lists completed SageMaker training jobs whose name contains
    "pytorch-smdataparallel-tsp", sorted by creation time (newest first), and
    reads the model artifact location of the first one.

    Returns:
        str: the job's ``ModelArtifacts.S3ModelArtifacts`` value.

    Raises:
        IndexError: if no completed training job matches the name filter. The
            exception type is the one the bare ``[0]`` lookup used to raise,
            now with an actionable message.
    """
    name_filter = "pytorch-smdataparallel-tsp"
    client = boto3.client("sagemaker")

    # Newest job first, so the first summary is the latest trained model.
    response = client.list_training_jobs(
        NameContains=name_filter,
        StatusEquals="Completed",
        SortBy="CreationTime",
        SortOrder="Descending",
    )
    summaries = response["TrainingJobSummaries"]
    if not summaries:
        raise IndexError(
            f"No completed training job with a name containing '{name_filter}' was found; "
            "train a model first or set USE_PRETRAINED_MODEL = True."
        )

    training_job_name = summaries[0]["TrainingJobName"]
    job_description = client.describe_training_job(TrainingJobName=training_job_name)
    return job_description["ModelArtifacts"]["S3ModelArtifacts"]


# Upload a pretrained model to s3
def upload_pretrained_model():
    """Upload the pretrained model archive to the default bucket and return its S3 URI.

    The S3 key is PRETRAINED_MODEL_PATH with its leading relative/absolute path
    segments ("../", "./", "/") removed.

    Returns:
        str: ``s3://{BUCKET}/{key}`` of the uploaded archive.
    """
    s3 = boto3.resource("s3")

    # str.lstrip("../") strips any leading '.' and '/' *characters*, which would
    # also eat the dot of a name such as ".models"; drop whole segments instead.
    s3_key = PRETRAINED_MODEL_PATH
    while s3_key.startswith(("../", "./", "/")):
        s3_key = s3_key.split("/", 1)[1]

    s3.meta.client.upload_file(PRETRAINED_MODEL_PATH, BUCKET, s3_key)
    return f"s3://{BUCKET}/{s3_key}"

In [None]:
# Choose the model artifact: upload the pretrained archive, or reuse the
# output of the most recent completed training job.
if USE_PRETRAINED_MODEL:
    model_data = upload_pretrained_model()
else:
    model_data = get_latest_model()

In [None]:
# Display the S3 URI of the selected model artifact.
model_data

## Download the model locally for testing

In [None]:
# Copy the selected model archive from S3 into the current directory.
!aws s3 cp $model_data ./

In [None]:
# Create the extraction directory; -p makes this a no-op if it already exists.
!mkdir -p model

In [None]:
# Unpack the downloaded archive into ./model/ (verbose listing of extracted files).
!tar -xvzf ./model.tar.gz -C ./model/

In [None]:
# List the extracted model files.
!ls model

## Load model

In [None]:
# Load the model from the directory the archive was extracted into.
# NOTE(review): model_fn is presumably provided by the `from inference import *` cell — confirm.
model_dir = "./model"
model = model_fn(model_dir)

In [None]:
# Display the loaded model object.
model

## Define input

In [None]:
# Serialize the request records with the SageMaker JSON Lines serializer.
serializer = JSONLinesSerializer()

data_jsonlines = serializer.serialize(data)

# Encode the payload as UTF-8 bytes and pass it through the handler's input parsing.
# NOTE(review): input_fn is presumably provided by the `from inference import *` cell — confirm.
request_body = data_jsonlines.encode("utf-8")

input_data = input_fn(request_body)

# Save the same payload to a local file; a later cell uploads it to S3 for batch transform.
with open("inference_input", "w") as file:
 file.write(data_jsonlines)

In [None]:
# Upload the payload file to S3; the batch transform step below reads it from
# s3://<BUCKET>/data/inference/inference_input.
!aws s3 cp inference_input s3://$BUCKET/data/inference/

## Prediction

In [None]:
# Display the value returned by input_fn for the request payload.
input_data

In [None]:
# Run the handler's predict step locally on the parsed input and the loaded model.
# NOTE(review): predict_fn is presumably provided by the `from inference import *` cell — confirm.
prediction = predict_fn(input_data, model)

In [None]:
# Display the local prediction result.
prediction

## Define output

In [None]:
# Convert the prediction into the handler's response format.
# NOTE(review): output_fn is presumably provided by the `from inference import *` cell — confirm.
output = output_fn(prediction)

In [None]:
# Display the value returned by output_fn.
output

In [None]:
# Write the first element of the output_fn result to a local file named "prediction".
# NOTE(review): assumes output[0] is the serialized response text — confirm against output_fn.
with open("prediction", "w") as file:
 file.write(output[0])

# 2. Test inference code via endpoints

In [None]:
import sagemaker
from sagemaker.pytorch import PyTorchModel

# Execution role under which the SageMaker model is created.
role = sagemaker.get_execution_role()

# Model definition: the artifact selected earlier in the notebook, served by
# inference.py taken from the ../src directory.
model_sm = PyTorchModel(
    entry_point="inference.py",
    source_dir="../src",
    model_data=model_data,
    role=role,
    framework_version="1.8.1",
    py_version="py36",
)

In [None]:
# Deploy the model to an endpoint backed by a single ml.m4.xlarge instance.
# Requests and responses both use JSON Lines, matching the local test above.
predictor = model_sm.deploy(
 initial_instance_count=1,
 instance_type="ml.m4.xlarge",
 serializer=JSONLinesSerializer(),
 deserializer=JSONLinesDeserializer(),
)

In [None]:
# Send the TSP request records built in section 1 to the endpoint for inference.
# Note: this rebinds `prediction`, replacing the result of the local test.
prediction = predictor.predict(data)

In [None]:
# Display the response returned by the endpoint.
prediction

# 3. Batch Transform

In [None]:
# Create a batch transformer from the same model definition, on one ml.m5.2xlarge instance.
# Settings passed: "MultiRecord" strategy, output assembled with "Line", JSON Lines accept type,
# max_concurrent_transforms=4, and SAGEMAKER_MODEL_SERVER_TIMEOUT="3600" via `env`.
transformer = model_sm.transformer(
 instance_count=1,
 instance_type="ml.m5.2xlarge",
 strategy="MultiRecord",
 assemble_with="Line",
 accept="application/jsonlines",
 max_concurrent_transforms=4,
 env={"SAGEMAKER_MODEL_SERVER_TIMEOUT": "3600"},
)

In [None]:
# Start the batch transform on the payload uploaded to S3 earlier, split by line.
# wait=False: the call is not asked to block until the job finishes.
transformer.transform(
 f"s3://{BUCKET}/data/inference/inference_input",
 content_type="application/jsonlines",
 split_type="Line",
 wait=False,
)

# 4. Clean up (Optional)

In [None]:
# NOTE(review): the batch transform above is started with wait=False — confirm the
# job has finished before running this cell, since it deletes the model.

# Delete the SageMaker endpoint
predictor.delete_endpoint()

# Delete the SageMaker model
model_sm.delete_model()