# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
from collections import defaultdict

import boto3
from torch.utils.data import DataLoader

from awsio.python.lib.io.s3.s3dataset import (S3Dataset, S3IterableDataset,
                                              ShuffleDataset, tardata, zipdata)


def read_using_boto(bucket, prefix_list):
    """Build the ground-truth set of (filename, content) pairs by reading
    each object directly with boto3; tar and zip archives are expanded
    into their member files."""
    s3_client = boto3.client('s3')
    s3_obj_set = set()
    for prefix in prefix_list:
        fs = io.BytesIO()
        s3_client.download_fileobj(bucket, prefix, fs)
        file_content = fs.getvalue()
        if prefix.endswith("tar"):
            for fname, content in tardata(file_content):
                s3_obj_set.add((fname, content))
        elif prefix.endswith("zip"):
            for fname, content in zipdata(file_content):
                s3_obj_set.add((fname, content))
        else:
            s3_obj_set.add((prefix.split("/")[-1], file_content))
    return s3_obj_set


def get_file_list(bucket, files_prefix):
    """List all object keys under files_prefix. The first key is the
    prefix (folder) placeholder itself, so it is skipped."""
    s3 = boto3.resource('s3')
    my_bucket = s3.Bucket(bucket)
    file_list = [summary.key for summary in
                 my_bucket.objects.filter(Prefix=files_prefix)]
    return file_list[1:]


def run_workers(dataset_type, url_list, batch_size, boto_obj_set):
    """Load url_list with the given dataset class for several worker counts
    and assert that the union of all loaded samples matches the boto3
    ground truth."""
    epochs = 2
    dataset_class = {"S3Dataset": S3Dataset,
                     "S3IterableDataset": S3IterableDataset}[dataset_type]
    for num_workers in [0, 4, 16]:
        s3_obj_set = set()
        dataset = dataset_class(url_list)
        dataloader = DataLoader(dataset, batch_size=batch_size,
                                num_workers=num_workers)
        for epoch in range(epochs):
            print("\nTesting " + dataset_type +
                  " with {} workers for epoch {}".format(num_workers, epoch + 1))
            for fname, fobj in dataloader:
                # Strip the bucket/prefix so names match the boto3 baseline.
                fname = [x.split("/")[-1] for x in fname]
                s3_obj_set.update(zip(fname, fobj))
        assert s3_obj_set == boto_obj_set, \
            "Test fails for {} workers for ".format(num_workers) + dataset_type
        print("All data correctly loaded for " + dataset_type +
              " for {} workers".format(num_workers))


def test_tarfiles():
    bucket = "pt-s3plugin-test-data-west2"
    tarfiles_list = ["integration_tests/imagenet-train-000000.tar"]
    print("\nINITIATING: TARFILES READ TEST")
    boto_obj_set = read_using_boto(bucket, tarfiles_list)
    batch_size = 32
    url_list = ["s3://" + bucket + "/" + tarfile for tarfile in tarfiles_list]
    run_workers("S3IterableDataset", url_list, batch_size, boto_obj_set)


def test_files():
    bucket = "pt-s3plugin-test-data-west2"
    files_prefix = "integration_tests/files"
    assert files_prefix[-1] != "/", "files_prefix must not end with a trailing \"/\""
    prefix_list = get_file_list(bucket, files_prefix)
    boto_obj_set = read_using_boto(bucket, prefix_list)
    batch_size = 32

    print("\nINITIATING: INDIVIDUAL FILE READ TEST")
    url_list = ["s3://" + bucket + "/" + prefix for prefix in prefix_list]
    run_workers("S3IterableDataset", url_list, batch_size, boto_obj_set)
    run_workers("S3Dataset", url_list, batch_size, boto_obj_set)

    print("\nINITIATING: READ FILES FROM PREFIX TEST")
    url_list = ["s3://" + bucket + "/" + files_prefix]
    run_workers("S3IterableDataset", url_list, batch_size, boto_obj_set)
    run_workers("S3Dataset", url_list, batch_size, boto_obj_set)
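# A minimal usage sketch of the pattern the tests above exercise, with a
# hypothetical bucket URL (everything else mirrors the calls in this file).
# Both dataset classes yield (filename, bytes) pairs, which the DataLoader
# collates into (list-of-names, list-of-blobs) batches:
#
#     dataset = S3IterableDataset(["s3://my-bucket/my-prefix"])  # hypothetical URL
#     for fnames, fobjs in DataLoader(dataset, batch_size=32):
#         names = [x.split("/")[-1] for x in fnames]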
def test_shuffleurls():
    """
    Args:
        bucket: name of the bucket
        files_prefix: prefix of the location where the files are stored

    Logic:
        Loop over the dataloader twice, once with shuffle_urls=True and once
        with shuffle_urls=False. After both runs the data loaded should be
        the same but the loading order should differ. A dictionary of sets
        and a dictionary of lists are maintained, keyed by the state of
        shuffle_urls (True/False); the values are the set/list of samples.
        The test passes if the sets of samples loaded in both cases are
        equal while the lists differ (different loading order, i.e. the
        data was shuffled).
    """
    bucket = "pt-s3plugin-test-data-west2"
    files_prefix = "integration_tests/files"
    assert files_prefix[-1] != "/", "files_prefix must not end with a trailing \"/\""
    prefix_list = get_file_list(bucket, files_prefix)
    url_list = ["s3://" + bucket + "/" + prefix for prefix in prefix_list]
    batch_size = 32
    shuffled_sets = defaultdict(set)
    shuffled_lists = defaultdict(list)
    print("\nINITIATING SHUFFLE TEST")
    for shuffle_urls in [True, False]:
        dataset = S3IterableDataset(url_list, shuffle_urls=shuffle_urls)
        dataloader = DataLoader(dataset, batch_size=batch_size)
        for fname, fobj in dataloader:
            fname = [x.split("/")[-1] for x in fname]
            samples = list(zip(fname, fobj))
            shuffled_sets[str(shuffle_urls)].update(samples)
            shuffled_lists[str(shuffle_urls)].append(samples)
    assert shuffled_sets['True'] == shuffled_sets['False'] and \
           shuffled_lists['True'] != shuffled_lists['False'], \
           "Shuffling not working correctly"
    print("Shuffle test passed for S3IterableDataset")


def test_ShuffleDataset():
    """
    Args:
        bucket: name of the bucket
        tarfiles_list: list of all tarfiles under the prefix
        buffer_size: number of samples the ShuffleDataset object caches

    Logic:
        Loop over the ShuffleDataset dataloader twice. The corresponding
        batches from the two runs should not be identical, which ensures
        that shuffling happens within the constituents of a tarfile. After
        both runs the overall data loaded should be the same. If either
        condition fails, the test fails.
    """
    bucket = "pt-s3plugin-test-data-west2"
    tarfiles_list = ["integration_tests/imagenet-train-000000.tar",
                     "integration_tests/imagenet-train-000001.tar"]
    url_list = ["s3://" + bucket + "/" + tarfile for tarfile in tarfiles_list]
    batch_size = 32
    for num_workers in [0, 16]:
        for buffer_size in [30, 300, 3000]:
            dataset = ShuffleDataset(S3IterableDataset(url_list),
                                     buffer_size=buffer_size)
            dataloader = DataLoader(dataset, batch_size=batch_size,
                                    num_workers=num_workers)
            batch_list1 = get_batches(dataloader)
            batch_list2 = get_batches(dataloader)
            assert batches_shuffled(batch_list1, batch_list2), \
                "ShuffleDataset test fails: batches not shuffled"
            assert batches_congruent(batch_list1, batch_list2), \
                "ShuffleDataset test fails: data mismatch"
            print("ShuffleDataset test passes for buffer_size {} and "
                  "{} workers".format(buffer_size, num_workers))


def get_batches(dataloader):
    """
    Args:
        dataloader: a PyTorch DataLoader object

    Returns a list of batches of (filename, content) samples drawn from
    the dataloader.
    """
    batch_list = []
    for fname, fobj in dataloader:
        fname = [x.split("/")[-1] for x in fname]
        batch_list.append(list(zip(fname, fobj)))
    return batch_list


def batches_shuffled(batch_list1, batch_list2):
    """
    Args:
        batch_list1, batch_list2: two lists of batches

    Returns True if every pair of corresponding batches differs,
    False otherwise.
    """
    for b1, b2 in zip(batch_list1, batch_list2):
        if b1 == b2:
            return False
    return True
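# Illustrative values for batches_shuffled (hypothetical samples, not part
# of the test run):
#
#     batches_shuffled([[("a", b"1")], [("b", b"2")]],
#                      [[("b", b"2")], [("a", b"1")]])   # True: every pair differs
#     batches_shuffled([[("a", b"1")], [("b", b"2")]],
#                      [[("a", b"1")], [("c", b"3")]])   # False: first pair matches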
def batches_congruent(batch_list1, batch_list2):
    """
    Args:
        batch_list1, batch_list2: two lists of batches

    Returns True if both lists contain the same set of samples,
    False otherwise.
    """
    batches1_flat = [sample for batch in batch_list1 for sample in batch]
    batches2_flat = [sample for batch in batch_list2 for sample in batch]
    return set(batches1_flat) == set(batches2_flat)
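

if __name__ == "__main__":
    # Convenience entry point for running the module directly; the suite is
    # presumably collected by pytest in CI. Requires network access to the
    # pt-s3plugin-test-data-west2 bucket.
    test_tarfiles()
    test_files()
    test_shuffleurls()
    test_ShuffleDataset()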