|
|
|
@ -36,6 +36,7 @@ from .. import core, layers
|
|
|
|
|
from ..framework import in_dygraph_mode
|
|
|
|
|
from ..multiprocess_utils import CleanupFuncRegistrar, _cleanup_mmap, _set_SIGCHLD_handler
|
|
|
|
|
from .fetcher import _IterableDatasetFetcher, _MapDatasetFetcher
|
|
|
|
|
from .batch_sampler import _InfiniteIterableSampler
|
|
|
|
|
|
|
|
|
|
__all__ = ['get_worker_info']
|
|
|
|
|
|
|
|
|
@ -100,11 +101,13 @@ class _DatasetKind(object):
|
|
|
|
|
ITER = 1
|
|
|
|
|
|
|
|
|
|
@staticmethod
def create_fetcher(kind, dataset, auto_collate_batch, collate_fn, drop_last):
    """Build the dataset fetcher that matches ``kind``.

    Args:
        kind: one of the ``_DatasetKind`` constants (``MAP`` or ``ITER``).
        dataset: the dataset object handed to the fetcher.
        auto_collate_batch: forwarded unchanged to the fetcher.
        collate_fn: forwarded unchanged to the fetcher.
        drop_last: forwarded unchanged to the fetcher.

    Returns:
        A ``_MapDatasetFetcher`` for ``_DatasetKind.MAP`` or an
        ``_IterableDatasetFetcher`` for ``_DatasetKind.ITER``.

    Raises:
        NotImplementedError: if ``kind`` is neither ``MAP`` nor ``ITER``.
    """
    # Both fetchers take the same argument list; only the class differs.
    if kind == _DatasetKind.MAP:
        return _MapDatasetFetcher(dataset, auto_collate_batch,
                                  collate_fn, drop_last)
    elif kind == _DatasetKind.ITER:
        return _IterableDatasetFetcher(dataset, auto_collate_batch,
                                       collate_fn, drop_last)
    else:
        raise NotImplementedError("unknown Dataset kind {}".format(kind))
|
|
|
|
|
|
|
|
|
@ -221,8 +224,7 @@ class _DataLoaderIterBase(object):
|
|
|
|
|
self._places = loader.places
|
|
|
|
|
self._return_list = loader.return_list
|
|
|
|
|
self._batch_sampler = loader.batch_sampler
|
|
|
|
|
self._sampler_iter = iter(loader.batch_sampler)
|
|
|
|
|
self._collate_fn = loader.collate_fn or default_collate_fn
|
|
|
|
|
self._auto_collate_batch = loader.auto_collate_batch
|
|
|
|
|
self._num_workers = loader.num_workers
|
|
|
|
|
self._use_buffer_reader = loader.use_buffer_reader
|
|
|
|
|
self._use_shared_memory = loader.use_shared_memory
|
|
|
|
@ -231,6 +233,16 @@ class _DataLoaderIterBase(object):
|
|
|
|
|
self._dataset_kind = loader.dataset_kind
|
|
|
|
|
self._pin_memory = loader.pin_memory
|
|
|
|
|
|
|
|
|
|
if self._auto_collate_batch:
|
|
|
|
|
self._sampler_iter = iter(loader.batch_sampler)
|
|
|
|
|
self._collate_fn = loader.collate_fn or default_collate_fn
|
|
|
|
|
else:
|
|
|
|
|
if self._dataset_kind == _DatasetKind.MAP:
|
|
|
|
|
self._sampler_iter = iter(list(range(len(self._dataset))))
|
|
|
|
|
else:
|
|
|
|
|
self._sampler_iter = iter(_InfiniteIterableSampler(self._dataset, 1))
|
|
|
|
|
self._collate_fn = loader.collate_fn
|
|
|
|
|
|
|
|
|
|
# LoDTensorBlockingQueue instance for create_py_reader and a thread
|
|
|
|
|
# to put mini-batch data to self._blocking_queue, mini-batch data
|
|
|
|
|
# will be fetched from:
|
|
|
|
@ -257,7 +269,8 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
|
|
|
|
|
super(_DataLoaderIterSingleProcess, self).__init__(loader)
|
|
|
|
|
|
|
|
|
|
self._dataset_fetcher = _DatasetKind.create_fetcher(
|
|
|
|
|
self._dataset_kind, self._dataset, self._collate_fn, True)
|
|
|
|
|
self._dataset_kind, self._dataset, self._auto_collate_batch,
|
|
|
|
|
self._collate_fn, True)
|
|
|
|
|
|
|
|
|
|
# NOTE: len(self._places) batches of data compose one output
|
|
|
|
|
# iteration; the blocking_queue is sized to cache 2 iterations of data
|
|
|
|
@ -367,7 +380,7 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
|
|
|
|
|
|
|
|
|
|
# NOTE(chenweihang): _worker_loop must be a top-level function so that it can be pickled
|
|
|
|
|
def _worker_loop(dataset, dataset_kind, indices_queue, out_queue, done_event,
|
|
|
|
|
collate_fn, init_fn, worker_id, num_workers,
|
|
|
|
|
auto_collate_batch, collate_fn, init_fn, worker_id, num_workers,
|
|
|
|
|
use_shared_memory):
|
|
|
|
|
try:
|
|
|
|
|
# NOTE: [ mmap files clear ] When the child process exits unexpectedly,
|
|
|
|
@ -388,7 +401,7 @@ def _worker_loop(dataset, dataset_kind, indices_queue, out_queue, done_event,
|
|
|
|
|
if init_fn is not None:
|
|
|
|
|
init_fn(worker_id)
|
|
|
|
|
fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset,
|
|
|
|
|
collate_fn, True)
|
|
|
|
|
auto_collate_batch, collate_fn, True)
|
|
|
|
|
except:
|
|
|
|
|
init_exception = Exception("init_fn failed in worker {}: " \
|
|
|
|
|
"{}".format(worker_id, sys.exc_info()))
|
|
|
|
@ -511,8 +524,9 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
|
|
|
|
|
target=_worker_loop,
|
|
|
|
|
args=(self._dataset, self._dataset_kind, indices_queue,
|
|
|
|
|
self._data_queue, self._workers_done_event,
|
|
|
|
|
self._collate_fn, self._worker_init_fn, i,
|
|
|
|
|
self._num_workers, self._use_shared_memory))
|
|
|
|
|
self._auto_collate_batch, self._collate_fn,
|
|
|
|
|
self._worker_init_fn, i, self._num_workers,
|
|
|
|
|
self._use_shared_memory))
|
|
|
|
|
worker.daemon = True
|
|
|
|
|
worker.start()
|
|
|
|
|
self._workers.append(worker)
|
|
|
|
|