# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     https://aws.amazon.com/apache-2-0/
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Downloads CIFAR-10 and converts each split to TFRecord files."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os
import shutil
import sys
import tarfile

from six.moves import cPickle as pickle
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

CIFAR_FILENAME = 'cifar-10-python.tar.gz'
CIFAR_DOWNLOAD_URL = 'https://www.cs.toronto.edu/~kriz/' + CIFAR_FILENAME
CIFAR_LOCAL_FOLDER = 'cifar-10-batches-py'


def download_and_extract(data_dir):
    # Download CIFAR-10 if it has not already been downloaded.
    tf.contrib.learn.datasets.base.maybe_download(CIFAR_FILENAME, data_dir,
                                                  CIFAR_DOWNLOAD_URL)
    tarfile.open(os.path.join(data_dir, CIFAR_FILENAME),
                 'r:gz').extractall(data_dir)


def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _get_file_names():
    """Returns the file names expected to exist in the input_dir."""
    file_names = {}
    # Batches 1-4 are used for training; batch 5 is held out for validation.
    file_names['train'] = ['data_batch_%d' % i for i in xrange(1, 5)]
    file_names['validation'] = ['data_batch_5']
    file_names['eval'] = ['test_batch']
    return file_names


def read_pickle_from_file(filename):
    with tf.gfile.Open(filename, 'rb') as f:
        if sys.version_info >= (3, 0):
            data_dict = pickle.load(f, encoding='bytes')
        else:
            data_dict = pickle.load(f)
    return data_dict


def convert_to_tfrecord(input_files, output_file):
    """Converts the given CIFAR-10 pickle files to a single TFRecord file."""
    print('Generating %s' % output_file)
    with tf.python_io.TFRecordWriter(output_file) as record_writer:
        for input_file in input_files:
            data_dict = read_pickle_from_file(input_file)
            data = data_dict[b'data']
            labels = data_dict[b'labels']
            num_entries_in_batch = len(labels)
            for i in range(num_entries_in_batch):
                example = tf.train.Example(features=tf.train.Features(
                    feature={
                        'image': _bytes_feature(data[i].tobytes()),
                        'label': _int64_feature(labels[i])
                    }))
                record_writer.write(example.SerializeToString())


def main(data_dir):
    print('Download from {} and extract.'.format(CIFAR_DOWNLOAD_URL))
    download_and_extract(data_dir)
    file_names = _get_file_names()
    input_dir = os.path.join(data_dir, CIFAR_LOCAL_FOLDER)
    for mode, files in file_names.items():
        input_files = [os.path.join(input_dir, f) for f in files]
        mode_dir = os.path.join(data_dir, mode)
        output_file = os.path.join(mode_dir, mode + '.tfrecords')
        if not os.path.exists(mode_dir):
            os.makedirs(mode_dir)
        try:
            os.remove(output_file)
        except OSError:
            pass
        # Convert to tf.train.Example protos and write them to TFRecords.
        convert_to_tfrecord(input_files, output_file)
    print('Done!')
    # Clean up the extracted batches and the original .tar.gz archive.
    shutil.rmtree(os.path.join(data_dir, CIFAR_LOCAL_FOLDER))
    os.remove(os.path.join(data_dir, CIFAR_FILENAME))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--data-dir',
        type=str,
        default='',
        help='Directory to download and extract CIFAR-10 to.')
    args = parser.parse_args()
    main(args.data_dir)
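

# A minimal sketch (not part of the original script) of how one serialized
# example written above could be parsed back, using the same TF 1.x API style
# as the rest of this file. The feature keys 'image' and 'label' mirror
# convert_to_tfrecord; the function name and the channel-first reshape are
# illustrative assumptions, not something this script itself requires.
def _parse_record(serialized_example):
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
        })
    # Each CIFAR-10 row is 3072 uint8 values laid out channel-first (3x32x32).
    image = tf.reshape(tf.decode_raw(features['image'], tf.uint8), [3, 32, 32])
    label = tf.cast(features['label'], tf.int32)
    return image, label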