import io
import json
import logging
import os
import pickle
import sys

import boto3
import numpy as np
import torch
import torch.nn as nn
from torchvision.models import vgg16

s3_client = boto3.client('s3')

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.StreamHandler(sys.stdout))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class ConvRelu(nn.Module):
    """3x3 convolution followed by an in-place ReLU."""

    def __init__(self, in_, out):
        super().__init__()
        self.conv = nn.Conv2d(in_, out, 3, padding=1)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.activation(x)
        return x


class DecoderBlock(nn.Module):
    """Decoder stage: 2x bilinear upsampling followed by two ConvRelu blocks."""

    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        self.in_channels = in_channels
        self.block = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            ConvRelu(in_channels, middle_channels),
            ConvRelu(middle_channels, out_channels),
        )

    def forward(self, x):
        return self.block(x)


class VGG16Unet(nn.Module):
    """U-Net with a VGG16 encoder and skip connections at every resolution."""

    def __init__(self, in_channels=3, num_filters=32, pretrained=False):
        super().__init__()
        # Use the VGG16 feature extractor as the encoder
        self.encoder = vgg16(pretrained=pretrained).features
        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU(inplace=True)

        # Replace the first conv layer so the encoder accepts `in_channels` inputs
        self.encoder[0] = nn.Conv2d(
            in_channels, 64, kernel_size=3, stride=1, padding=1)

        # Group the encoder convolutions into the five VGG16 stages
        self.conv1 = nn.Sequential(
            self.encoder[0], self.relu, self.encoder[2], self.relu)
        self.conv2 = nn.Sequential(
            self.encoder[5], self.relu, self.encoder[7], self.relu)
        self.conv3 = nn.Sequential(
            self.encoder[10], self.relu, self.encoder[12], self.relu,
            self.encoder[14], self.relu)
        self.conv4 = nn.Sequential(
            self.encoder[17], self.relu, self.encoder[19], self.relu,
            self.encoder[21], self.relu)
        self.conv5 = nn.Sequential(
            self.encoder[24], self.relu, self.encoder[26], self.relu,
            self.encoder[28], self.relu)

        # Build the decoder; each stage consumes the previous decoder output
        # concatenated with the matching encoder feature map
        self.center = DecoderBlock(512, num_filters * 8 * 2, num_filters * 8)
        self.dec5 = DecoderBlock(
            512 + num_filters * 8, num_filters * 8 * 2, num_filters * 8)
        self.dec4 = DecoderBlock(
            512 + num_filters * 8, num_filters * 8 * 2, num_filters * 8)
        self.dec3 = DecoderBlock(
            256 + num_filters * 8, num_filters * 4 * 2, num_filters * 2)
        self.dec2 = DecoderBlock(
            128 + num_filters * 2, num_filters * 2 * 2, num_filters)
        self.dec1 = ConvRelu(64 + num_filters, num_filters)

        # Final 1x1 layer outputs logits, not probabilities
        self.final = nn.Conv2d(num_filters, 1, kernel_size=1)

    def forward(self, x):
        conv1 = self.conv1(x)
        conv2 = self.conv2(self.pool(conv1))
        conv3 = self.conv3(self.pool(conv2))
        conv4 = self.conv4(self.pool(conv3))
        conv5 = self.conv5(self.pool(conv4))

        center = self.center(self.pool(conv5))

        dec5 = self.dec5(torch.cat([center, conv5], 1))
        dec4 = self.dec4(torch.cat([dec5, conv4], 1))
        dec3 = self.dec3(torch.cat([dec4, conv3], 1))
        dec2 = self.dec2(torch.cat([dec3, conv2], 1))
        dec1 = self.dec1(torch.cat([dec2, conv1], 1))

        x_out = self.final(dec1)
        return x_out


def get_modified_vgg16_unet(in_channels=3):
    """Return a VGG16-Unet subclass with a customized number of input channels."""
    class Modified_VGG16Unet(VGG16Unet):
        def __init__(self):
            super().__init__(in_channels=in_channels)
    return Modified_VGG16Unet
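# Quick shape sanity check for the architecture above (a sketch; the 4-channel
# input and the 224x224 size are assumptions -- any spatial size divisible by
# 32 works, since the network pools and upsamples five times):
#
#   net = get_modified_vgg16_unet(in_channels=4)()
#   net(torch.zeros(1, 4, 224, 224)).shape  # -> torch.Size([1, 1, 224, 224])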
# Model loading: SageMaker calls model_fn once to deserialize the weights
def model_fn(model_dir):
    model = get_modified_vgg16_unet(in_channels=4)()
    with open(os.path.join(model_dir, "model.pth"), "rb") as f:
        model.load_state_dict(torch.load(f, map_location='cpu'))
    model.to(device).eval()
    return model


# Data preprocessing
def input_fn(request_body, request_content_type):
    # assert request_content_type == "application/json"
    # The request carries an S3 URI; parse out the bucket name and object key
    s3_path = json.loads(request_body)["inputs"]
    # Globals carry the S3 location and tiling metadata over to output_fn
    global path_parts
    path_parts = s3_path.replace("s3://", "").split("/")
    global BUCKET_NAME
    BUCKET_NAME = path_parts.pop(0)
    BUCKET_FILE_NAME = "/".join(path_parts)

    # Download the pickled payload from S3 without touching disk
    my_array_data2 = io.BytesIO()
    s3_client.download_fileobj(BUCKET_NAME, BUCKET_FILE_NAME, my_array_data2)
    my_array_data2.seek(0)
    global idx_refs, src_im_height, src_im_width
    data, idx_refs, (src_im_height, src_im_width) = pickle.load(my_array_data2)
    data = torch.tensor(data, dtype=torch.float32, device=device)
    return data


# Inference
def predict_fn(data, model):
    with torch.no_grad():
        model.eval()
        subarr_preds_list = []
        batch_size = 1
        # Slicing past the end of a tensor is safe, so the final (possibly
        # smaller) batch needs no special casing
        for batch_i in range(0, data.shape[0], batch_size):
            subarr_pred = model(data[batch_i:batch_i + batch_size, ...])
            subarr_preds_list.append(subarr_pred.cpu().numpy())
    subarr_preds = np.concatenate(subarr_preds_list, axis=0)
    return subarr_preds


# Postprocessing
def output_fn(subarr_preds, content_type):
    # assert content_type == "application/json"
    # Upload the pickled predictions back to S3, again without using disk
    my_array_data = io.BytesIO()
    pickle.dump(
        [subarr_preds, idx_refs, src_im_height, src_im_width], my_array_data)
    my_array_data.seek(0)
    OUTPUT_FILE = "/".join(path_parts[:-1]) + '/output_pred.pkl'
    s3_client.upload_fileobj(my_array_data, BUCKET_NAME, OUTPUT_FILE)
    file_name = 's3://' + BUCKET_NAME + '/' + OUTPUT_FILE
    return json.dumps(file_name)
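
# Local smoke test (a sketch, not part of the SageMaker handler contract; the
# batch of four 4-channel 256x256 tiles is an assumption for illustration):
if __name__ == "__main__":
    model = get_modified_vgg16_unet(in_channels=4)().to(device).eval()
    dummy = torch.zeros(4, 4, 256, 256, device=device)
    preds = predict_fn(dummy, model)
    print(preds.shape)  # expected: (4, 1, 256, 256)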