'''SageMaker PyTorch inference container overrides to serve a plain ResNet18 model'''
# Python Built-Ins:
from io import BytesIO
import os

# External Dependencies:
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
from torchvision import models, transforms


def model_fn(model_dir):
    '''Load the fine-tuned ResNet18 from the model artifact directory'''
    # Create the base model:
    model = models.resnet18()

    # The traffic sign dataset has 43 classes, so replace the final FC layer:
    nfeatures = model.fc.in_features
    model.fc = nn.Linear(nfeatures, 43)

    # Load the trained weights (always onto CPU):
    weights = torch.load(
        f'{model_dir}/model/model.pt',
        map_location=lambda storage, loc: storage,
    )
    model.load_state_dict(weights)
    model.eval()
    model.cpu()
    return model


def transform_fn(model, data, content_type, output_content_type):
    '''Deserialize the request, run inference, and serialize the response'''
    transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    # The request body is a NumPy-serialized (.npy) image array:
    image = np.load(BytesIO(data))
    image = Image.fromarray(image)
    image = transform(image)
    image = image.unsqueeze(0)

    # Forward pass (no gradients needed at inference time):
    with torch.no_grad():
        prediction = model(image)

    # Take the highest-scoring class index as the prediction:
    predicted_class = prediction.argmax(dim=1, keepdim=True)
    response_body = predicted_class.cpu().numpy().tolist()
    return response_body, output_content_type
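

# ---------------------------------------------------------------------------
# Minimal local smoke test (not used by the SageMaker container). This is a
# sketch only: it assumes a trained artifact exists at <model_dir>/model/model.pt,
# and that the (hypothetical) SM_MODEL_DIR environment variable or the current
# directory points at it. It exercises model_fn/transform_fn with a random RGB
# image serialized the same way a client would send it (NumPy .npy bytes).
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    model_dir = os.environ.get('SM_MODEL_DIR', '.')  # assumed local path
    model = model_fn(model_dir)

    # Fake a 64x64 RGB image and serialize it as .npy, mirroring the expected payload:
    dummy = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
    buffer = BytesIO()
    np.save(buffer, dummy)

    body, content_type = transform_fn(
        model, buffer.getvalue(), 'application/x-npy', 'application/json',
    )
    print('Predicted class:', body, 'content type:', content_type)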