@@ -347,6 +347,92 @@ class _DataLoaderIterSingleProcess(_DataLoaderIterBase):
         return self.__next__()
+
+
+# NOTE(chenweihang): _worker_loop must be a top-level function to be pickled
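+# (with the spawn start method, multiprocessing pickles the Process target;
+# a bound method would require pickling the whole iterator object as well)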
+def _worker_loop(dataset, dataset_kind, indices_queue, out_queue, done_event,
+                 collate_fn, init_fn, worker_id, num_workers,
+                 use_shared_memory):
+    try:
+        # NOTE: [ mmap files clear ] When the child process exits unexpectedly,
+        # some shared memory objects may have been allocated but not yet put
+        # into the inter-process Queue; these objects need to be cleaned up
+        # when the process ends.
+        CleanupFuncRegistrar.register(_cleanup_mmap)
+
+        # set signal handler
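+        # (so that fatal signals in the worker can be caught and surfaced
+        # instead of the worker dying silently)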
+        core._set_process_signal_handler()
+
+        global _worker_info
+        _worker_info = WorkerInfo(
+            id=worker_id, num_workers=num_workers, dataset=dataset)
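+        # _worker_info is module-level state; dataset code running in this
+        # worker (e.g. init_fn) can read it through get_worker_info()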
+
+        init_exception = None
+        try:
+            if init_fn is not None:
+                init_fn(worker_id)
+            fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset,
+                                                  collate_fn, True)
+        except:
+            init_exception = Exception("init_fn failed in worker {}: " \
+                                       "{}".format(worker_id, sys.exc_info()))
+
+        iterator_drained = False
+        parent_watch_dog = ParentWatchDog()
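+        # ParentWatchDog checks whether the parent process is still alive; the
+        # indices_queue poll below uses a timeout so this check runs even when
+        # no indices arrive, letting the worker exit if the parent dies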
+
+        while parent_watch_dog.is_alive():
+            try:
+                data = indices_queue.get(timeout=MP_INDICES_CHECK_INTERVAL)
+            except queue.Empty:
+                continue
+
+            # None is the poison pill, so the worker done_event should be set
+            if data is None:
+                assert done_event.is_set() or iterator_drained, \
+                    "got None from indices_queue, but done_event is not set"
+                break
+            # If the worker done_event is set but data still arrives from
+            # indices_queue, the remaining data should be fetched and skipped.
+            if done_event.is_set() or iterator_drained:
+                continue
+
+            idx, indices = data
+            try:
+                if init_exception is not None:
+                    batch = init_exception
+                    init_exception = None
+                else:
+                    batch = fetcher.fetch(indices)
+            except Exception as e:
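+                # for an iterable dataset, StopIteration means this worker is
+                # exhausted; send a sentinel so the main process stops
+                # dispatching indices to this worker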
+                if isinstance(
+                        e, StopIteration) and dataset_kind == _DatasetKind.ITER:
+                    out_queue.put(_IterableDatasetStopIteration(worker_id))
+                    iterator_drained = True
+                else:
+                    out_queue.put((idx, e))
+            else:
+                if use_shared_memory:
+                    # FIXME(dkp): _convert_to_tensor_list only supports lists
+                    # of np.array now; it should also support paddle.Tensor lists
+                    if isinstance(batch[0][0], paddle.Tensor):
+                        np_batch = []
+                        for sample in batch:
+                            np_batch.append([s.numpy() for s in sample])
+                        batch = np_batch
+
+                    tensor_list = core._convert_to_tensor_list(batch)
+                    out_queue.put((idx, tensor_list))
+                    core._remove_tensor_list_mmap_fds(tensor_list)
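+                    # once the data is queued, the worker drops its references
+                    # to the shared-memory file descriptors; the receiving
+                    # process is now responsible for those blocks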
+                else:
+                    out_queue.put((idx, batch))
+    except KeyboardInterrupt:
+        # NOTE: the main process will raise KeyboardInterrupt anyway; ignore
+        # it in the child process
+        pass
+    except:
+        six.reraise(*sys.exc_info())
+    finally:
+        if use_shared_memory:
+            _cleanup_mmap()
+
+
 class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
     def __init__(self, loader):
         super(_DataLoaderIterMultiProcess, self).__init__(loader)
@@ -404,11 +490,11 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
indices_queue = multiprocessing.Queue()
|
|
|
|
|
self._indices_queues.append(indices_queue)
|
|
|
|
|
worker = multiprocessing.Process(
|
|
|
|
|
target=self._worker_loop,
|
|
|
|
|
target=_worker_loop,
|
|
|
|
|
args=(self._dataset, self._dataset_kind, indices_queue,
|
|
|
|
|
self._data_queue, self._workers_done_event,
|
|
|
|
|
self._collate_fn, self._worker_init_fn, i,
|
|
|
|
|
self._num_workers))
|
|
|
|
|
self._num_workers, self._use_shared_memory))
|
|
|
|
|
worker.daemon = True
|
|
|
|
|
worker.start()
|
|
|
|
|
self._workers.append(worker)
|
|
|
|
@@ -483,90 +569,6 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
             self._blocking_queue.kill()
             logging.error("DataLoader reader thread raised an exception!")
 
-    def _worker_loop(self, dataset, dataset_kind, indices_queue, out_queue,
-                     done_event, collate_fn, init_fn, worker_id, num_workers):
-        try:
-            # NOTE: [ mmap files clear ] When the child process exits unexpectedly,
-            # some shared memory objects may have been applied for but have not yet
-            # been put into the inter-process Queue. This part of the object needs
-            # to be cleaned up when the process ends.
-            CleanupFuncRegistrar.register(_cleanup_mmap)
-
-            # set signal handler
-            core._set_process_signal_handler()
-
-            global _worker_info
-            _worker_info = WorkerInfo(
-                id=worker_id, num_workers=num_workers, dataset=dataset)
-
-            init_exception = None
-            try:
-                if init_fn is not None:
-                    init_fn(worker_id)
-                fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset,
-                                                      collate_fn, True)
-            except:
-                init_exception = Exception("init_fn failed in worker {}: " \
-                                           "{}".format(worker_id, sys.exc_info()))
-
-            iterator_drained = False
-            parent_watch_dog = ParentWatchDog()
-
-            while parent_watch_dog.is_alive():
-                try:
-                    data = indices_queue.get(MP_INDICES_CHECK_INTERVAL)
-                except queue.Empty:
-                    continue
-
-                # None as poison piil, so worker event should be set
-                if data is None:
-                    assert done_event.is_set() or iterator_drained, \
-                        "get None when worker done_event set"
-                    break
-                # If worker done event is set but get still get data in
-                # indices_queue, remaining data should be get and skipped.
-                if done_event.is_set() or iterator_drained:
-                    continue
-
-                idx, indices = data
-                try:
-                    if init_exception is not None:
-                        batch = init_exception
-                        init_exception = None
-                    else:
-                        batch = fetcher.fetch(indices)
-                except Exception as e:
-                    if isinstance(
-                            e,
-                            StopIteration) and dataset_kind == _DatasetKind.ITER:
-                        out_queue.put(_IterableDatasetStopIteration(worker_id))
-                        iterator_drained = True
-                    else:
-                        out_queue.put((idx, e))
-                else:
-                    if self._use_shared_memory:
-                        # FIXME(dkp): _convert_to_tensor_list only support np.array
-                        # list now, should support paddle.Tensor list
-                        if isinstance(batch[0][0], paddle.Tensor):
-                            np_batch = []
-                            for sample in batch:
-                                np_batch.append([s.numpy() for s in sample])
-                            batch = np_batch
-
-                        tensor_list = core._convert_to_tensor_list(batch)
-                        out_queue.put((idx, tensor_list))
-                        core._remove_tensor_list_mmap_fds(tensor_list)
-                    else:
-                        out_queue.put((idx, batch))
-        except KeyboardInterrupt:
-            # NOTE: Main process will raise KeyboardInterrupt anyways, ignore it in child process
-            pass
-        except:
-            six.reraise(*sys.exc_info())
-        finally:
-            if self._use_shared_memory:
-                _cleanup_mmap()
-
     def _thread_loop(self):
         while not self._thread_done_event.is_set():
             batch = self._get_data()