"""Local smoke test for the inference handlers defined in ``inference.py``.

Stages and unpacks a ``model.tar.gz`` model artifact, loads it with
``model_fn`` and pushes one JSON request through
``input_fn`` -> ``predict_fn`` -> ``output_fn``, i.e. the handler chain an
inference container would invoke for a request.
"""

import json
import os
import shutil
import tarfile
from urllib.parse import unquote, urlparse

import boto3
import botocore
import numpy as np
import sagemaker

from inference import input_fn, model_fn, output_fn, predict_fn

# Working directory the model archive is staged in and unpacked into.
DEFAULT_MODEL_DIR = "/home/ec2-user/SageMaker/tmp/model"
# File name of the staged archive inside the model directory.
MODEL_ARCHIVE_NAME = "model.tar.gz"
# Default value sent under the "inputs" key of the request body.
# NOTE(review): placeholder object key ("your-file.pkl") — replace with real data.
DEFAULT_INPUTS = "s3://sagemaker-us-east-1-037978457745/data/your-file.pkl"


def _stage_archive(model_data, archive_path):
    """Copy or download the archive referenced by ``model_data`` to ``archive_path``.

    Raises:
        ValueError: unsupported URL scheme or malformed S3 URI.
        FileNotFoundError: the local file or the S3 object does not exist.
        botocore.exceptions.ClientError: any other S3 error.
    """
    parsed = urlparse(model_data)
    if parsed.scheme in ("", "file"):
        # "file:///abs/path" -> "/abs/path"; a scheme-less value is a plain path.
        src = unquote(parsed.path) if parsed.scheme else model_data
        try:
            shutil.copy2(src, archive_path)
        except shutil.SameFileError:
            pass  # The archive is already staged in the model directory.
    elif parsed.scheme == "s3":
        bucket_name = parsed.netloc
        key = parsed.path.lstrip("/")
        if not bucket_name or not key:
            raise ValueError(f"Malformed S3 URI: {model_data!r}")
        try:
            boto3.resource("s3").Bucket(bucket_name).download_file(key, archive_path)
        except botocore.exceptions.ClientError as err:
            if err.response.get("Error", {}).get("Code") == "404":
                raise FileNotFoundError(
                    f"S3 object does not exist: {model_data}"
                ) from err
            raise
    else:
        raise ValueError(
            f"Unsupported model_data scheme {parsed.scheme!r}: {model_data!r}"
        )


def _safe_extract(archive_path, dest_dir):
    """Extract ``archive_path`` into ``dest_dir``, refusing members that escape it.

    Raises:
        tarfile.TarError: a member would be written outside ``dest_dir`` (or is
            otherwise rejected by the ``"data"`` extraction filter).
    """
    with tarfile.open(archive_path) as tar:
        if hasattr(tarfile, "data_filter"):
            # Python 3.12+ and security backports: the "data" filter blocks
            # path traversal, links pointing outside dest_dir and device files.
            tar.extractall(dest_dir, filter="data")
            return
        # Older interpreters: at least reject members resolving outside dest_dir.
        dest_root = os.path.realpath(dest_dir)
        for member in tar.getmembers():
            target = os.path.realpath(os.path.join(dest_dir, member.name))
            if os.path.commonpath([dest_root, target]) != dest_root:
                raise tarfile.TarError(
                    f"Refusing to extract {member.name!r} outside {dest_dir!r}"
                )
        tar.extractall(dest_dir)


def fetch_model(model_data, model_dir=DEFAULT_MODEL_DIR):
    """Stage ``model.tar.gz`` in ``model_dir`` and unpack it there.

    Args:
        model_data (str): location of the ``model.tar.gz`` archive: a
            ``file://`` URL (e.g. ``file:///path/to/model.tar.gz``), a plain
            local path, or an ``s3://bucket/key`` URI. If empty or ``None``,
            the archive is expected to already be at ``model_dir/model.tar.gz``.
        model_dir (str): directory to stage and extract the archive into;
            created if it does not exist.

    Returns:
        str: ``model_dir``, which now contains the uncompressed model files.

    Raises:
        ValueError: unsupported scheme or malformed S3 URI in ``model_data``.
        FileNotFoundError: the source archive (local or S3) does not exist.
        botocore.exceptions.ClientError: any other S3 download error.
        tarfile.TarError: the archive is corrupt or contains unsafe members.
    """
    os.makedirs(model_dir, exist_ok=True)
    archive_path = os.path.join(model_dir, MODEL_ARCHIVE_NAME)

    if model_data:
        _stage_archive(model_data, archive_path)

    _safe_extract(archive_path, model_dir)
    return model_dir


def test(model_data, inputs=DEFAULT_INPUTS):
    """Run one JSON request through the inference handlers and print the result.

    Args:
        model_data (str): archive location; see :func:`fetch_model`.
        inputs: value placed under the ``"inputs"`` key of the JSON request
            body; must be JSON-serialisable.

    Returns:
        Whatever ``output_fn`` returns for the prediction (it is also printed).
    """
    model_dir = fetch_model(model_data)
    net = model_fn(model_dir)

    # Encode the request body the way a JSON client would send it.
    content_type = "application/json"
    request_body = json.dumps({"inputs": inputs})

    # Call the handlers in the order the inference container chains them.
    # NOTE(review): the exact types input_fn/predict_fn/output_fn accept and
    # return are defined in inference.py — confirm there.
    input_object = input_fn(request_body, content_type)
    predictions = predict_fn(input_object, net)
    response = output_fn(predictions, content_type)
    print(response)
    return response


if __name__ == "__main__":
    # model_data may also be an S3 URI, e.g. "s3://<bucket>/<prefix>/model.tar.gz".
    model_data = "file:///home/ec2-user/SageMaker/vegetation-management-remars2022/remars2022-workshop/model.tar.gz"
    test(model_data)