from __future__ import print_function import os import torch from model_def import Net def model_fn(model_dir): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = Net() with open(os.path.join(model_dir, "model.pth"), "rb") as f: model.load_state_dict(torch.load(f)) return model.to(device)