# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""CIFAR-10 data-loading and image-display helpers."""

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms

# Human-readable CIFAR-10 class names; index with an integer label yielded
# by the loaders below.
classes = ("plane", "car", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck")

# Per-channel normalization applied by _get_transform(). show_img() inverts
# exactly these values, so keep the two in sync by editing them only here.
_NORM_MEAN = (0.5, 0.5, 0.5)
_NORM_STD = (0.5, 0.5, 0.5)


def _get_transform():
    """Return the transform applied to every CIFAR-10 image.

    Converts a PIL image to a tensor, then normalizes each channel with
    ``_NORM_MEAN`` / ``_NORM_STD``.
    """
    return torchvision.transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize(_NORM_MEAN, _NORM_STD)]
    )


def _make_loader(train, *, root, download, batch_size, shuffle, num_workers):
    """Build a DataLoader over the CIFAR-10 train (``train=True``) or test split."""
    dataset = torchvision.datasets.CIFAR10(
        root=root, train=train, download=download, transform=_get_transform()
    )
    return torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers
    )


def train_data_loader(
    root="./data", *, download=False, batch_size=4, shuffle=True, num_workers=2
):
    """Return a DataLoader over the CIFAR-10 training split.

    Args:
        root: Directory holding the CIFAR-10 data.
        download: If False (the default), the data must already exist under
            ``root``; no download is attempted.
        batch_size: Samples per batch.
        shuffle: Whether to reshuffle the data each epoch (on by default for
            training).
        num_workers: Number of DataLoader worker processes.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding normalized image batches
        and their integer labels.
    """
    return _make_loader(
        True,
        root=root,
        download=download,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )


def test_data_loader(
    root="./data", *, download=False, batch_size=4, shuffle=False, num_workers=2
):
    """Return a DataLoader over the CIFAR-10 test split.

    Takes the same arguments as ``train_data_loader``, except that
    ``shuffle`` defaults to False so evaluation order is stable.
    """
    return _make_loader(
        False,
        root=root,
        download=download,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )


def show_img(img):
    """Draw a normalized image tensor with matplotlib.

    ``img`` is expected to be channel-first (C, H, W), as produced by
    ``_get_transform`` (or a grid of such images). The normalization is
    undone before drawing. Only ``plt.imshow`` is called, so the caller is
    responsible for ``plt.show()`` (or an inline notebook backend).
    """
    # Reshape mean/std to (C, 1, 1) so they broadcast over height and width.
    mean = torch.tensor(_NORM_MEAN).view(-1, 1, 1)
    std = torch.tensor(_NORM_STD).view(-1, 1, 1)
    img = img.detach().cpu() * std + mean  # unnormalize
    npimg = img.numpy()
    # matplotlib expects channel-last (H, W, C).
    plt.imshow(np.transpose(npimg, (1, 2, 0)))