#!/usr/bin/env python
import io
import sys
from collections import defaultdict

import boto3
from torch.utils.data import DataLoader

from awsio.python.lib.io.s3.s3dataset import S3Dataset, S3IterableDataset, ShuffleDataset
from awsio.python.lib.io.s3.s3dataset import tardata, zipdata


def read_using_boto(bucket, prefix_list):
    """Build the ground-truth set of (filename, bytes) pairs by downloading
    every object with boto3, expanding tar and zip archives into their
    member files."""
    client = boto3.client('s3')
    s3_obj_set = set()
    for prefix in prefix_list:
        fs = io.BytesIO()
        client.download_fileobj(bucket, prefix, fs)
        file_content = fs.getvalue()
        if prefix.endswith("tar"):
            for fname, content in tardata(file_content):
                s3_obj_set.add((fname, content))
        elif prefix.endswith("zip"):
            for fname, content in zipdata(file_content):
                s3_obj_set.add((fname, content))
        else:
            s3_obj_set.add((prefix.split("/")[-1], file_content))
    return s3_obj_set


def get_file_list(bucket, files_prefix):
    """Return the keys of all objects in the bucket under the given prefix."""
    s3 = boto3.resource('s3')
    my_bucket = s3.Bucket(bucket)
    return [summary.key for summary in my_bucket.objects.filter(Prefix=files_prefix)]


def run_workers(dataset_type, url_list, batch_size, boto_obj_set):
    """Load the dataset with varying numbers of DataLoader workers and check
    that the union of all loaded batches matches the boto3 ground truth."""
    epochs = 2
    dataset_classes = {"S3Dataset": S3Dataset, "S3IterableDataset": S3IterableDataset}
    dataset_class = dataset_classes[dataset_type]
    for num_workers in [0, 4, 16]:
        s3_obj_set = set()
        dataset = dataset_class(url_list)
        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
        for epoch in range(epochs):
            print("\nTesting " + dataset_type + " with {} workers for epoch {}".format(
                num_workers, epoch + 1))
            for fname, fobj in dataloader:
                fname = [x.split("/")[-1] for x in fname]
                s3_obj_set.update(zip(fname, fobj))
        assert s3_obj_set == boto_obj_set, \
            "Test fails for {} workers for {}".format(num_workers, dataset_type)
        print("All data correctly loaded for " + dataset_type
              + " for {} workers".format(num_workers))


def test_tarfiles():
    bucket = "pt-s3plugin-test-data-west2"
    tarfiles_list = ["integration_tests/imagenet-train-000000.tar"]
    print("\nINITIATING: TARFILES READ TEST")
    boto_obj_set = read_using_boto(bucket, tarfiles_list)
    batch_size = 32
    url_list = ["s3://" + bucket + "/" + tarfile for tarfile in tarfiles_list]
    run_workers("S3IterableDataset", url_list, batch_size, boto_obj_set)


def test_files():
    bucket = "pt-s3plugin-test-data-west2"
    files_prefix = "integration_tests/files"
    assert files_prefix[-1] != "/", "Enter prefix without a trailing \"/\""
    prefix_list = get_file_list(bucket, files_prefix)
    boto_obj_set = read_using_boto(bucket, prefix_list)
    batch_size = 32

    print("\nINITIATING: INDIVIDUAL FILE READ TEST")
    url_list = ["s3://" + bucket + "/" + prefix for prefix in prefix_list]
    run_workers("S3IterableDataset", url_list, batch_size, boto_obj_set)
    run_workers("S3Dataset", url_list, batch_size, boto_obj_set)

    print("\nINITIATING: READ FILES FROM PREFIX TEST")
    url_list = ["s3://" + bucket + "/" + files_prefix]
    run_workers("S3IterableDataset", url_list, batch_size, boto_obj_set)
    run_workers("S3Dataset", url_list, batch_size, boto_obj_set)
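
# The two dataset flavors exercised by test_files differ in access pattern:
# S3Dataset is map-style (supports len() and integer indexing), while
# S3IterableDataset streams objects in sequence. A minimal sketch of the
# map-style access, assuming indexing returns a (filename, bytes) pair as in
# the plugin's documented examples; demo_map_style_access is illustrative
# only and is not called by the tests in this file.
def demo_map_style_access(url_list):
    dataset = S3Dataset(url_list)
    fname, fobj = dataset[0]  # map-style: fetch a single object by index
    print("first object:", fname, "({} bytes)".format(len(fobj)))
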
def test_shuffleurls():
    """
    Args:
        bucket: name of the bucket
        files_prefix: prefix of the location where the files are stored

    Logic:
        Loop over the dataloader twice, once with shuffle_urls=True and once
        with shuffle_urls=False. After both runs the data loaded should be
        the same, but the loading order should differ. Maintains one
        dictionary of sets and one of lists, keyed by the shuffle_urls state
        (True/False); the values are the set/list of loaded samples. The test
        passes if the sets of samples are equal in both cases while the lists
        differ (different loading order, i.e. the data is being shuffled).
    """
    bucket = "pt-s3plugin-test-data-west2"
    files_prefix = "integration_tests/files"
    assert files_prefix[-1] != "/", "Enter prefix without a trailing \"/\""
    prefix_list = get_file_list(bucket, files_prefix)
    url_list = ["s3://" + bucket + "/" + prefix for prefix in prefix_list]
    batch_size = 32
    shuffled_sets = defaultdict(set)
    shuffled_lists = defaultdict(list)
    print("\nINITIATING SHUFFLE TEST")
    for shuffle_urls in [True, False]:
        dataset = S3IterableDataset(url_list, shuffle_urls=shuffle_urls)
        dataloader = DataLoader(dataset, batch_size=batch_size)
        for fname, fobj in dataloader:
            fname = [x.split("/")[-1] for x in fname]
            shuffled_sets[str(shuffle_urls)].update(zip(fname, fobj))
            shuffled_lists[str(shuffle_urls)].append(list(zip(fname, fobj)))
    assert shuffled_sets['True'] == shuffled_sets['False'] \
        and shuffled_lists['True'] != shuffled_lists['False'], \
        "Shuffling not working correctly"
    print("Shuffle test passed for S3IterableDataset")


def test_ShuffleDataset():
    """
    Args:
        bucket: name of the bucket
        tarfiles_list: list of all tarfiles with the prefix
        buffer_size: number of files the ShuffleDataset object caches

    Logic:
        Loop over the ShuffleDataset dataloader twice. Across the two runs
        the corresponding batches should not be the same, which ensures that
        shuffling happens within tarfile constituents. After both runs the
        overall data loaded should be the same. If either condition fails,
        the test fails.
    """
    bucket = "pt-s3plugin-test-data-west2"
    tarfiles_list = ["integration_tests/imagenet-train-000000.tar",
                     "integration_tests/imagenet-train-000001.tar"]
    url_list = ["s3://" + bucket + "/" + tarfile for tarfile in tarfiles_list]
    batch_size = 32
    for num_workers in [0, 16]:
        for buffer_size in [30, 300, 3000]:
            dataset = ShuffleDataset(S3IterableDataset(url_list), buffer_size=buffer_size)
            dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
            batch_list1 = get_batches(dataloader)
            batch_list2 = get_batches(dataloader)
            assert batches_shuffled(batch_list1, batch_list2), \
                "ShuffleDataset test fails: batches not shuffled"
            assert batches_congruent(batch_list1, batch_list2), \
                "ShuffleDataset test fails: data mismatch"
            print("ShuffleDataset test passes for {} buffer_size & {} workers".format(
                buffer_size, num_workers))


def get_batches(dataloader):
    """
    Args:
        dataloader: a PyTorch DataLoader object

    Returns a list of the batches of (filename, bytes) samples it yields.
    """
    batch_list = []
    for fname, fobj in dataloader:
        fname = [x.split("/")[-1] for x in fname]
        batch_list.append(list(zip(fname, fobj)))
    return batch_list


def batches_shuffled(batch_list1, batch_list2):
    """
    Args:
        two lists of batches

    Returns True if every pair of corresponding batches differs,
    False otherwise.
    """
    for b1, b2 in zip(batch_list1, batch_list2):
        if b1 == b2:
            return False
    return True


def batches_congruent(batch_list1, batch_list2):
    """
    Args:
        two lists of batches

    Returns True if the samples in both lists match (ignoring order),
    False otherwise.
    """
    batches1_flat = [sample for batch in batch_list1 for sample in batch]
    batches2_flat = [sample for batch in batch_list2 for sample in batch]
    return set(batches1_flat) == set(batches2_flat)
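
# ShuffleDataset, as tested above, shuffles samples through a bounded buffer
# rather than materializing the whole stream, so buffer_size bounds both the
# memory used and how far apart a sample can move. A minimal pure-Python
# sketch of that buffered-shuffle idea (an illustration of the technique,
# not the plugin's actual implementation; buffered_shuffle_sketch is
# hypothetical and unused by the tests):
import random

def buffered_shuffle_sketch(iterable, buffer_size):
    """Yield items in a randomized order while holding at most
    buffer_size items in memory."""
    buf = []
    for item in iterable:
        buf.append(item)
        if len(buf) >= buffer_size:
            # Emit a random buffered element once the buffer is full.
            yield buf.pop(random.randrange(len(buf)))
    while buf:  # Drain the remaining buffered items in random order.
        yield buf.pop(random.randrange(len(buf)))
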
def run_test():
    test_shuffleurls()
    test_ShuffleDataset()
    test_files()
    test_tarfiles()


if __name__ == '__main__':
    try:
        import torch.multiprocessing as mp
        # Use 'spawn' so multi-worker DataLoader processes start cleanly.
        mp.set_start_method('spawn')
        sys.exit(run_test())
    except KeyboardInterrupt:
        pass