# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# Copyright (c) 2018-2019 NVIDIA CORPORATION. All rights reserved.
from torch.utils.data.sampler import BatchSampler


class IterationBasedBatchSampler(BatchSampler):
    """
    Wraps a BatchSampler, resampling from it until a specified number of
    iterations have been sampled.
    """

    def __init__(self, batch_sampler, num_iterations, start_iter=0,
                 random_number_generator=None):
        self.batch_sampler = batch_sampler
        self.num_iterations = num_iterations
        self.start_iter = start_iter
        self.random_number_generator = random_number_generator

    def __iter__(self):
        iteration = self.start_iter
        while iteration <= self.num_iterations:
            # If the underlying sampler has a set_epoch method, like
            # DistributedSampler, used for making each process see
            # a different split of the dataset, then set it.
            if hasattr(self.batch_sampler.sampler, "set_epoch"):
                if self.random_number_generator is not None:
                    # Draw a fresh seed so the shuffle order differs on each
                    # pass while staying reproducible from the generator.
                    iteration_seed = self.random_number_generator.randint(
                        0, 2 ** 32 - 1
                    )
                    self.batch_sampler.sampler.set_epoch(iteration_seed)
                else:
                    self.batch_sampler.sampler.set_epoch(iteration)
            for batch in self.batch_sampler:
                iteration += 1
                if iteration > self.num_iterations:
                    break
                yield batch

    def __len__(self):
        return self.num_iterations
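

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the original
    # module): wrap a plain BatchSampler so a DataLoader is bounded by a
    # fixed number of iterations rather than by epochs. The toy dataset and
    # all names below are assumptions made purely for demonstration. With a
    # DistributedSampler in place of RandomSampler, the set_epoch branch in
    # __iter__ would additionally reshuffle between passes.
    import torch
    from torch.utils.data import DataLoader, RandomSampler, TensorDataset

    dataset = TensorDataset(torch.arange(10))
    sampler = RandomSampler(dataset)
    # 10 items at batch_size=4 gives 3 batches per pass over the dataset.
    batch_sampler = BatchSampler(sampler, batch_size=4, drop_last=False)

    # Resample from the underlying sampler until 7 batches have been yielded.
    iter_sampler = IterationBasedBatchSampler(batch_sampler, num_iterations=7)
    loader = DataLoader(dataset, batch_sampler=iter_sampler)

    for i, (batch,) in enumerate(loader):
        print(i, batch.tolist())  # 7 batches in total, cycling the dataset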