import os

import torch
import torch.nn as nn
import torch.nn.functional as F


class NetM3(nn.Module):
    """M3-style 1D CNN that classifies raw waveforms into 10 classes."""

    def __init__(self):
        super(NetM3, self).__init__()
        # Wide first filter (kernel 80, stride 4) reads the raw waveform directly.
        self.conv1 = nn.Conv1d(1, 128, 80, 4)
        self.bn1 = nn.BatchNorm1d(128)
        self.pool1 = nn.MaxPool1d(4)
        self.conv2 = nn.Conv1d(128, 128, 3)
        self.bn2 = nn.BatchNorm1d(128)
        self.pool2 = nn.MaxPool1d(4)
        self.conv3 = nn.Conv1d(128, 256, 3)
        self.bn3 = nn.BatchNorm1d(256)
        self.pool3 = nn.MaxPool1d(4)
        self.conv4 = nn.Conv1d(256, 512, 3)
        self.bn4 = nn.BatchNorm1d(512)
        self.pool4 = nn.MaxPool1d(4)
        self.avgPool = nn.AvgPool1d(30)  # input should be 512x30, so this outputs 512x1
        self.fc1 = nn.Linear(512, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(self.bn1(x))
        x = self.pool1(x)
        x = self.conv2(x)
        x = F.relu(self.bn2(x))
        x = self.pool2(x)
        x = self.conv3(x)
        x = F.relu(self.bn3(x))
        x = self.pool3(x)
        x = self.conv4(x)
        x = F.relu(self.bn4(x))
        x = self.pool4(x)
        x = self.avgPool(x)
        x = x.permute(0, 2, 1)  # change the 512x1 to 1x512
        x = self.fc1(x)
        return F.log_softmax(x, dim=2)  # output: torch.Size([N, 1, 10])


def model_fn(model_dir):
    """Load the serialized model for inference (SageMaker-style entry point)."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = NetM3()
    if torch.cuda.device_count() > 1:
        print("GPU count: {}".format(torch.cuda.device_count()))
        model = nn.DataParallel(model)
    with open(os.path.join(model_dir, "model.pth"), "rb") as f:
        # map_location keeps CPU-only hosts from failing on GPU-saved weights.
        model.load_state_dict(torch.load(f, map_location=device))
    return model.to(device)
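

# Illustrative smoke test (an addition, not part of the original entry point).
# The 32000-sample input length is an assumption derived from the layer sizes:
# conv1 (kernel 80, stride 4) maps 32000 -> 7981, and the alternating conv/pool
# stages reduce that to the 512x30 map that avgPool(30) expects.
if __name__ == "__main__":
    model = NetM3()
    model.eval()
    dummy = torch.randn(4, 1, 32000)  # (batch, channels, samples)
    with torch.no_grad():
        out = model(dummy)
    print(out.shape)  # torch.Size([4, 1, 10])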