from chainer.datasets import tuple_dataset
import numpy as np


def preprocess_mnist(raw, withlabel, ndim, scale, image_dtype, label_dtype,
                     rgb_format):
    """Reshape, rescale and optionally label raw MNIST arrays.

    Args:
        raw: Mapping holding the image array under ``'x'`` and, when
            ``withlabel`` is true, the label array under ``'y'``.
            NOTE(review): the shape of ``raw['x']`` is not visible here; the
            reshapes below assume its total size is a multiple of 28 * 28
            (presumably flat 784-pixel rows) - confirm against the loader.
        withlabel: If true, also read ``raw['y']`` and return an
            ``(images, labels)`` ``TupleDataset``.
        ndim: Target layout of each image. ``1`` leaves ``raw['x']`` in its
            stored shape, ``2`` reshapes to ``(-1, 28, 28)`` and ``3``
            reshapes to ``(-1, 1, 28, 28)``. Any other value raises
            ``ValueError``.
        scale: Pixel values are multiplied by ``scale / 255.``. Assuming
            8-bit input in ``[0, 255]`` (not checked here), the result lies
            in ``[0, scale]``.
        image_dtype: dtype the images are cast to before scaling. It must be
            able to hold the in-place float multiplication. An integer dtype
            makes ``images *= ...`` fail under NumPy's casting rules.
        label_dtype: dtype the labels are cast to (used only when
            ``withlabel`` is true).
        rgb_format: Only consulted when ``ndim == 3``. If true, the single
            channel is repeated so each image becomes ``(3, 28, 28)``.

    Returns:
        ``tuple_dataset.TupleDataset(images, labels)`` when ``withlabel`` is
        true, otherwise the processed image ``numpy.ndarray``.

    Raises:
        ValueError: If ``ndim`` is not 1, 2 or 3.
    """
    images = raw['x']
    if ndim == 2:
        images = images.reshape(-1, 28, 28)
    elif ndim == 3:
        images = images.reshape(-1, 1, 28, 28)
        if rgb_format:
            # broadcast_to returns a read-only view that repeats the channel
            # axis three times without copying; the astype below materializes
            # it into a real, writable array.
            images = np.broadcast_to(images,
                                     (len(images), 3) + images.shape[2:])
    elif ndim != 1:
        raise ValueError('invalid ndim for MNIST dataset')

    # astype copies by default, so the in-place multiply below never mutates
    # the caller's raw['x'] (reshape alone may return a view of it).
    images = images.astype(image_dtype)
    images *= scale / 255.

    if withlabel:
        labels = raw['y'].astype(label_dtype)
        return tuple_dataset.TupleDataset(images, labels)
    return images