# Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. # coding: utf-8 # pylint: disable=ungrouped-imports """Dataset generator.""" __all__ = ['DataLoader'] import pickle import io import sys import multiprocessing import multiprocessing.queues from multiprocessing.reduction import ForkingPickler import threading import numpy as np try: import multiprocessing.resource_sharer except ImportError: pass from . import sampler as _sampler from ... import nd, context from ...recordio import MXRecordIO if sys.platform == 'darwin' or sys.platform == 'win32': def rebuild_ndarray(*args): """Rebuild ndarray from pickled shared memory""" # pylint: disable=no-value-for-parameter return nd.NDArray(nd.ndarray._new_from_shared_mem(*args)) def reduce_ndarray(data): """Reduce ndarray to shared memory handle""" return rebuild_ndarray, data._to_shared_mem() else: def rebuild_ndarray(pid, fd, shape, dtype): """Rebuild ndarray from pickled shared memory""" # pylint: disable=no-value-for-parameter if sys.version_info[0] == 2: fd = multiprocessing.reduction.rebuild_handle(fd) else: fd = fd.detach() return nd.NDArray(nd.ndarray._new_from_shared_mem(pid, fd, shape, dtype)) def reduce_ndarray(data): """Reduce ndarray to shared memory handle""" # keep a local ref before duplicating fd data = data.as_in_context(context.Context('cpu_shared', 0)) pid, fd, shape, dtype = data._to_shared_mem() if sys.version_info[0] == 2: fd = multiprocessing.reduction.reduce_handle(fd) else: fd = multiprocessing.reduction.DupFd(fd) return rebuild_ndarray, (pid, fd, shape, dtype) ForkingPickler.register(nd.NDArray, reduce_ndarray) class ConnectionWrapper(object): """Connection wrapper for multiprocessing that supports sending NDArray via shared memory.""" def __init__(self, conn): self._conn = conn def send(self, obj): """Send object""" buf = io.BytesIO() ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(obj) self.send_bytes(buf.getvalue()) def recv(self): """Receive object""" buf = self.recv_bytes() return pickle.loads(buf) def __getattr__(self, name): """Emmulate conn""" attr = self.__dict__.get('_conn', None) return getattr(attr, name) class Queue(multiprocessing.queues.Queue): """Wrapper for multiprocessing queue that dumps NDArray with shared memory.""" def __init__(self, *args, **kwargs): if sys.version_info[0] <= 2: super(Queue, self).__init__(*args, **kwargs) else: super(Queue, self).__init__(*args, ctx=multiprocessing.get_context(), **kwargs) self._reader = ConnectionWrapper(self._reader) self._writer = ConnectionWrapper(self._writer) self._send = self._writer.send self._recv = self._reader.recv class SimpleQueue(multiprocessing.queues.SimpleQueue): """Wrapper for multiprocessing SimpleQueue that dumps NDArray with shared memory. SimpleQueue don't use threading internally. """ def __init__(self, *args, **kwargs): if sys.version_info[0] <= 2: super(SimpleQueue, self).__init__(*args, **kwargs) else: super(SimpleQueue, self).__init__(*args, ctx=multiprocessing.get_context(), **kwargs) self._reader = ConnectionWrapper(self._reader) self._writer = ConnectionWrapper(self._writer) self._send = self._writer.send self._recv = self._reader.recv def default_batchify_fn(data): """Collate data into batch.""" if isinstance(data[0], nd.NDArray): return nd.stack(*data) elif isinstance(data[0], tuple): data = zip(*data) return [default_batchify_fn(i) for i in data] else: data = np.asarray(data) return nd.array(data, dtype=data.dtype) def default_mp_batchify_fn(data): """Collate data into batch. Use shared memory for stacking.""" if isinstance(data[0], nd.NDArray): out = nd.empty((len(data),) + data[0].shape, dtype=data[0].dtype, ctx=context.Context('cpu_shared', 0)) return nd.stack(*data, out=out) elif isinstance(data[0], tuple): data = zip(*data) return [default_mp_batchify_fn(i) for i in data] else: data = np.asarray(data) return nd.array(data, dtype=data.dtype, ctx=context.Context('cpu_shared', 0)) def _as_in_context(data, ctx): """Move data into new context.""" if isinstance(data, nd.NDArray): return data.as_in_context(ctx) elif isinstance(data, (list, tuple)): return [_as_in_context(d, ctx) for d in data] return data def _recursive_fork_recordio(obj, depth, max_depth=1000): """Recursively find instance of MXRecordIO and reset file handler. This is required for MXRecordIO which holds a C pointer to a opened file after fork. """ if depth >= max_depth: return if isinstance(obj, MXRecordIO): obj.close() obj.open() # re-obtain file hanlder in new process elif (hasattr(obj, '__dict__')): for _, v in obj.__dict__.items(): _recursive_fork_recordio(v, depth + 1, max_depth) def worker_loop(dataset, key_queue, data_queue, batchify_fn): """Worker loop for multiprocessing DataLoader.""" # re-fork a new recordio handler in new process if applicable # for a dataset with transform function, the depth of MXRecordIO is 1 # for a lazy transformer, the depth is 2 # for a user defined transformer, the depth is unknown, try a reasonable depth limit = sys.getrecursionlimit() max_recursion_depth = min(limit - 5, max(10, limit // 2)) _recursive_fork_recordio(dataset, 0, max_recursion_depth) while True: idx, samples = key_queue.get() if idx is None: break batch = batchify_fn([dataset[i] for i in samples]) data_queue.put((idx, batch)) def fetcher_loop(data_queue, data_buffer, pin_memory=False): """Fetcher loop for fetching data from queue and put in reorder dict.""" while True: idx, batch = data_queue.get() if idx is None: break if pin_memory: batch = _as_in_context(batch, context.cpu_pinned()) else: batch = _as_in_context(batch, context.cpu()) data_buffer[idx] = batch class _MultiWorkerIter(object): """Interal multi-worker iterator for DataLoader.""" def __init__(self, num_workers, dataset, batchify_fn, batch_sampler, pin_memory=False, worker_fn=worker_loop): assert num_workers > 0, "_MultiWorkerIter is not for {} workers".format(num_workers) self._num_workers = num_workers self._dataset = dataset self._batchify_fn = batchify_fn self._batch_sampler = batch_sampler self._key_queue = Queue() self._data_queue = Queue() if sys.version_info[0] <= 2 else SimpleQueue() self._data_buffer = {} self._rcvd_idx = 0 self._sent_idx = 0 self._iter = iter(self._batch_sampler) self._shutdown = False workers = [] for _ in range(self._num_workers): worker = multiprocessing.Process( target=worker_fn, args=(self._dataset, self._key_queue, self._data_queue, self._batchify_fn)) worker.daemon = True worker.start() workers.append(worker) self._fetcher = threading.Thread( target=fetcher_loop, args=(self._data_queue, self._data_buffer, pin_memory)) self._fetcher.daemon = True self._fetcher.start() # pre-fetch for _ in range(2 * self._num_workers): self._push_next() def __len__(self): return len(self._batch_sampler) def __del__(self): self.shutdown() def _push_next(self): """Assign next batch workload to workers.""" r = next(self._iter, None) if r is None: return self._key_queue.put((self._sent_idx, r)) self._sent_idx += 1 def __next__(self): assert not self._shutdown, "call __next__ after shutdown is forbidden" if self._rcvd_idx == self._sent_idx: assert not self._data_buffer, "Data buffer should be empty at this moment" self.shutdown() raise StopIteration while True: if self._rcvd_idx in self._data_buffer: batch = self._data_buffer.pop(self._rcvd_idx) self._rcvd_idx += 1 self._push_next() return batch def next(self): return self.__next__() def __iter__(self): return self def shutdown(self): """Shutdown internal workers by pushing terminate signals.""" if not self._shutdown: for _ in range(self._num_workers): self._key_queue.put((None, None)) self._data_queue.put((None, None)) self._shutdown = True class DataLoader(object): """Loads data from a dataset and returns mini-batches of data. Parameters ---------- dataset : Dataset Source dataset. Note that numpy and mxnet arrays can be directly used as a Dataset. batch_size : int Size of mini-batch. shuffle : bool Whether to shuffle the samples. sampler : Sampler The sampler to use. Either specify sampler or shuffle, not both. last_batch : {'keep', 'discard', 'rollover'} How to handle the last batch if batch_size does not evenly divide `len(dataset)`. keep - A batch with less samples than previous batches is returned. discard - The last batch is discarded if its incomplete. rollover - The remaining samples are rolled over to the next epoch. batch_sampler : Sampler A sampler that returns mini-batches. Do not specify batch_size, shuffle, sampler, and last_batch if batch_sampler is specified. batchify_fn : callable Callback function to allow users to specify how to merge samples into a batch. Defaults to `default_batchify_fn`:: def default_batchify_fn(data): if isinstance(data[0], nd.NDArray): return nd.stack(*data) elif isinstance(data[0], tuple): data = zip(*data) return [default_batchify_fn(i) for i in data] else: data = np.asarray(data) return nd.array(data, dtype=data.dtype) num_workers : int, default 0 The number of multiprocessing workers to use for data preprocessing. pin_memory : boolean, default False If ``True``, the dataloader will copy NDArrays into pinned memory before returning them. Copying from CPU pinned memory to GPU is faster than from normal CPU memory. """ def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None, last_batch=None, batch_sampler=None, batchify_fn=None, num_workers=0, pin_memory=False): self._dataset = dataset self._pin_memory = pin_memory if batch_sampler is None: if batch_size is None: raise ValueError("batch_size must be specified unless " \ "batch_sampler is specified") if sampler is None: if shuffle: sampler = _sampler.RandomSampler(len(dataset)) else: sampler = _sampler.SequentialSampler(len(dataset)) elif shuffle: raise ValueError("shuffle must not be specified if sampler is specified") batch_sampler = _sampler.BatchSampler( sampler, batch_size, last_batch if last_batch else 'keep') elif batch_size is not None or shuffle or sampler is not None or \ last_batch is not None: raise ValueError("batch_size, shuffle, sampler and last_batch must " \ "not be specified if batch_sampler is specified.") self._batch_sampler = batch_sampler self._num_workers = num_workers if num_workers >= 0 else 0 if batchify_fn is None: if num_workers > 0: self._batchify_fn = default_mp_batchify_fn else: self._batchify_fn = default_batchify_fn else: self._batchify_fn = batchify_fn def __iter__(self): if self._num_workers == 0: def same_process_iter(): for batch in self._batch_sampler: ret = self._batchify_fn([self._dataset[idx] for idx in batch]) if self._pin_memory: ret = _as_in_context(ret, context.cpu_pinned()) yield ret return same_process_iter() # multi-worker return _MultiWorkerIter(self._num_workers, self._dataset, self._batchify_fn, self._batch_sampler, self._pin_memory) def __len__(self): return len(self._batch_sampler)