# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=
"""Dataset generator."""
__all__ = ['DataLoader']

import multiprocessing
import multiprocessing.queues
from multiprocessing.reduction import ForkingPickler
import pickle
import io
import sys

import numpy as np

from . import sampler as _sampler
from ... import nd, context


def rebuild_ndarray(*args):
    """Rebuild ndarray from pickled shared memory"""
    # pylint: disable=no-value-for-parameter
    return nd.NDArray(nd.ndarray._new_from_shared_mem(*args))


def reduce_ndarray(data):
    """Reduce ndarray to shared memory handle"""
    return rebuild_ndarray, data._to_shared_mem()


# Register a custom reducer so that NDArrays crossing process boundaries are
# passed through shared memory instead of being serialized byte-by-byte.
ForkingPickler.register(nd.NDArray, reduce_ndarray)


class ConnectionWrapper(object):
    """Connection wrapper for multiprocessing that supports sending
    NDArray via shared memory."""

    def __init__(self, conn):
        self._conn = conn

    def send(self, obj):
        """Send object"""
        buf = io.BytesIO()
        ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(obj)
        self.send_bytes(buf.getvalue())

    def recv(self):
        """Receive object"""
        buf = self.recv_bytes()
        return pickle.loads(buf)

    def __getattr__(self, name):
        """Emulate conn"""
        attr = self.__dict__.get('_conn', None)
        return getattr(attr, name)


class Queue(multiprocessing.queues.Queue):
    """Wrapper for multiprocessing queue that dumps NDArray with shared memory."""

    def __init__(self, *args, **kwargs):
        if sys.version_info[0] <= 2:
            super(Queue, self).__init__(*args, **kwargs)
        else:
            super(Queue, self).__init__(*args, ctx=multiprocessing.get_context(), **kwargs)
        self._reader = ConnectionWrapper(self._reader)
        self._writer = ConnectionWrapper(self._writer)
        self._send = self._writer.send
        self._recv = self._reader.recv


def default_batchify_fn(data):
    """Collate data into batch."""
    if isinstance(data[0], nd.NDArray):
        return nd.stack(*data)
    elif isinstance(data[0], tuple):
        data = zip(*data)
        return [default_batchify_fn(i) for i in data]
    else:
        data = np.asarray(data)
        return nd.array(data, dtype=data.dtype)
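# A minimal sketch of how `default_batchify_fn` collates a list of samples,
# using hypothetical toy data in an interactive session:
#
#   >>> samples = [nd.array([1, 2]), nd.array([3, 4]), nd.array([5, 6])]
#   >>> default_batchify_fn(samples).shape   # NDArrays are stacked on a new axis 0
#   (3, 2)
#   >>> pairs = [(nd.array([1]), 0), (nd.array([2]), 1)]
#   >>> batch = default_batchify_fn(pairs)   # tuples are collated field by field
#   >>> batch[0].shape, batch[1].shape
#   ((2, 1), (2,))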
def default_mp_batchify_fn(data):
    """Collate data into batch. Use shared memory for stacking."""
    if isinstance(data[0], nd.NDArray):
        out = nd.empty((len(data),) + data[0].shape, dtype=data[0].dtype,
                       ctx=context.Context('cpu_shared', 0))
        return nd.stack(*data, out=out)
    elif isinstance(data[0], tuple):
        data = zip(*data)
        return [default_mp_batchify_fn(i) for i in data]
    else:
        data = np.asarray(data)
        return nd.array(data, dtype=data.dtype,
                        ctx=context.Context('cpu_shared', 0))


def worker_loop(dataset, key_queue, data_queue, batchify_fn):
    """Worker loop for multiprocessing DataLoader."""
    while True:
        idx, samples = key_queue.get()
        if idx is None:
            break
        batch = batchify_fn([dataset[i] for i in samples])
        data_queue.put((idx, batch))


class _MultiWorkerIter(object):
    """Internal multi-worker iterator for DataLoader."""

    def __init__(self, num_workers, dataset, batchify_fn, batch_sampler):
        assert num_workers > 0, "_MultiWorkerIter is not for {} workers".format(num_workers)
        self._num_workers = num_workers
        self._dataset = dataset
        self._batchify_fn = batchify_fn
        self._batch_sampler = batch_sampler
        self._key_queue = Queue()
        self._data_queue = Queue(2 * self._num_workers)
        self._data_buffer = {}
        self._rcvd_idx = 0
        self._sent_idx = 0
        self._iter = iter(self._batch_sampler)
        self._shutdown = False

        workers = []
        for _ in range(self._num_workers):
            worker = multiprocessing.Process(
                target=worker_loop,
                args=(self._dataset, self._key_queue, self._data_queue, self._batchify_fn))
            worker.daemon = True
            worker.start()
            workers.append(worker)

        # pre-fetch two batches per worker to keep the pipeline busy
        for _ in range(2 * self._num_workers):
            self._push_next()

    def __len__(self):
        return len(self._batch_sampler)

    def __del__(self):
        self.shutdown()

    def _push_next(self):
        """Assign next batch workload to workers."""
        r = next(self._iter, None)
        if r is None:
            return
        self._key_queue.put((self._sent_idx, r))
        self._sent_idx += 1

    def __next__(self):
        assert not self._shutdown, "call __next__ after shutdown is forbidden"
        if self._rcvd_idx == self._sent_idx:
            assert not self._data_buffer, "Data buffer should be empty at this moment"
            self.shutdown()
            raise StopIteration

        while True:
            if self._rcvd_idx in self._data_buffer:
                batch = self._data_buffer.pop(self._rcvd_idx)
                self._rcvd_idx += 1
                self._push_next()
                return batch
            idx, batch = self._data_queue.get()
            self._data_buffer[idx] = batch

    def next(self):
        return self.__next__()

    def __iter__(self):
        return self

    def shutdown(self):
        """Shutdown internal workers by pushing terminate signals."""
        if not self._shutdown:
            # a (None, None) key tells each worker_loop to exit
            for _ in range(self._num_workers):
                self._key_queue.put((None, None))
            try:
                while not self._data_queue.empty():
                    self._data_queue.get()
            except IOError:
                pass
            self._shutdown = True
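# A minimal sketch of the queue protocol implemented by worker_loop and
# _MultiWorkerIter, assuming a hypothetical list-backed dataset (any object
# with __len__ and __getitem__ works):
#
#   >>> batches = _sampler.BatchSampler(
#   ...     _sampler.SequentialSampler(10), batch_size=5, last_batch='keep')
#   >>> it = _MultiWorkerIter(2, list(range(10)), default_mp_batchify_fn, batches)
#   >>> [b.shape for b in it]   # _data_buffer re-orders results, so batches
#   ...                         # arrive in order even if workers finish out of order
#   [(5,), (5,)]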
class DataLoader(object):
    """Loads data from a dataset and returns mini-batches of data.

    Parameters
    ----------
    dataset : Dataset
        Source dataset. Note that numpy and mxnet arrays can be directly used
        as a Dataset.
    batch_size : int
        Size of mini-batch.
    shuffle : bool
        Whether to shuffle the samples.
    sampler : Sampler
        The sampler to use. Either specify sampler or shuffle, not both.
    last_batch : {'keep', 'discard', 'rollover'}
        How to handle the last batch if batch_size does not evenly divide
        `len(dataset)`.

        keep - A batch with fewer samples than the previous batches is returned.
        discard - The last batch is discarded if it is incomplete.
        rollover - The remaining samples are rolled over to the next epoch.
    batch_sampler : Sampler
        A sampler that returns mini-batches. Do not specify batch_size,
        shuffle, sampler, and last_batch if batch_sampler is specified.
    batchify_fn : callable
        Callback function that allows users to specify how to merge samples
        into a batch. Defaults to `default_batchify_fn`::

            def default_batchify_fn(data):
                if isinstance(data[0], nd.NDArray):
                    return nd.stack(*data)
                elif isinstance(data[0], tuple):
                    data = zip(*data)
                    return [default_batchify_fn(i) for i in data]
                else:
                    data = np.asarray(data)
                    return nd.array(data, dtype=data.dtype)

    num_workers : int, default 0
        The number of multiprocessing workers to use for data preprocessing.
        `num_workers > 0` is not supported on Windows yet.
    """
    def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
                 last_batch=None, batch_sampler=None, batchify_fn=None,
                 num_workers=0):
        self._dataset = dataset

        if batch_sampler is None:
            if batch_size is None:
                raise ValueError("batch_size must be specified unless "
                                 "batch_sampler is specified")
            if sampler is None:
                if shuffle:
                    sampler = _sampler.RandomSampler(len(dataset))
                else:
                    sampler = _sampler.SequentialSampler(len(dataset))
            elif shuffle:
                raise ValueError("shuffle must not be specified if sampler is specified")

            batch_sampler = _sampler.BatchSampler(
                sampler, batch_size, last_batch if last_batch else 'keep')
        elif batch_size is not None or shuffle or sampler is not None or \
                last_batch is not None:
            raise ValueError("batch_size, shuffle, sampler and last_batch must "
                             "not be specified if batch_sampler is specified.")

        self._batch_sampler = batch_sampler
        self._num_workers = num_workers if num_workers >= 0 else 0
        if batchify_fn is None:
            if num_workers > 0:
                self._batchify_fn = default_mp_batchify_fn
            else:
                self._batchify_fn = default_batchify_fn
        else:
            self._batchify_fn = batchify_fn

    def __iter__(self):
        if self._num_workers == 0:
            # single-worker case: build a fresh generator per call so that
            # iterating the DataLoader twice yields independent iterators
            generator = lambda: [(yield self._batchify_fn([self._dataset[idx] for idx in batch]))
                                 for batch in self._batch_sampler]
            return generator()

        # multi-worker
        return _MultiWorkerIter(self._num_workers, self._dataset,
                                self._batchify_fn, self._batch_sampler)

    def __len__(self):
        return len(self._batch_sampler)
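# A minimal usage sketch with hypothetical data; as the docstring notes, an
# mxnet array can be used directly as a Dataset:
#
#   >>> data = nd.ones((100, 3, 32, 32))
#   >>> loader = DataLoader(data, batch_size=32, shuffle=True, last_batch='discard')
#   >>> len(loader)   # 100 // 32 full batches when the remainder is discarded
#   3
#   >>> for batch in loader:
#   ...     assert batch.shape == (32, 3, 32, 32)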