Modify the processing method of multiple processes in the GeneratorDataset:

1. Start the child process in the init phase
2. At the beginning of each epoch, the child process is not recreated, but the child process created at the beginning is used
pull/7600/head
heleiwang 4 years ago
parent ee1865605d
commit 4f946bc54b

@ -3239,21 +3239,19 @@ def _cpp_sampler_fn(sampler, dataset):
yield tuple([np.array(x, copy=False) for x in val])
def _cpp_sampler_fn_mp(sampler, dataset, num_worker, multi_process):
def _cpp_sampler_fn_mp(sampler, sample_fn):
"""
Multiprocessing generator function wrapper for mappable dataset with cpp sampler.
"""
indices = sampler.get_indices()
sample_fn = SamplerFn(dataset, num_worker, multi_process)
return sample_fn.process(indices)
def _py_sampler_fn_mp(sampler, num_samples, dataset, num_worker, multi_process):
def _py_sampler_fn_mp(sampler, num_samples, sample_fn):
"""
Multiprocessing generator function wrapper for mappable dataset with Python sampler.
"""
indices = _fetch_py_sampler_indices(sampler, num_samples)
sample_fn = SamplerFn(dataset, num_worker, multi_process)
return sample_fn.process(indices)
@ -3299,17 +3297,21 @@ class SamplerFn:
self.multi_process = multi_process
# Event for end of epoch
if multi_process is True:
self.eoe = multiprocessing.Event()
self.eof = multiprocessing.Event()
else:
self.eoe = threading.Event()
self.eof = threading.Event()
# Create workers
for _ in range(num_worker):
if multi_process is True:
worker = _GeneratorWorkerMp(dataset, self.eoe)
worker = _GeneratorWorkerMp(dataset, self.eof)
worker.daemon = True
# When multi processes fork a subprocess, the lock of the main process is copied to the subprocess,
# which may cause deadlock. Therefore, the subprocess startup is performed in che initialization phase.
# In this phase, the main process is not locked.
worker.start()
else:
worker = _GeneratorWorkerMt(dataset, self.eoe, self.eof)
worker.daemon = True
worker = _GeneratorWorkerMt(dataset, self.eof)
worker.daemon = True
self.workers.append(worker)
def process(self, indices):
@ -3317,14 +3319,18 @@ class SamplerFn:
The main process, start the child process or child thread, and fill the index queue.
Get the result and return.
"""
for w in self.workers:
# Check whether the queue of the subprocess is empty.
if not w.queue_empty():
raise Exception("The queue of the subprocess is not empty.")
# Start all workers
if not w.is_alive():
w.start()
# Fill initial index queues
idx_cursor = 0
idx_cursor = _fill_worker_indices(self.workers, indices, idx_cursor)
# Start all workers
for w in self.workers:
w.start()
# Fetch results
for i in range(len(indices)):
# Fetch result and put index
@ -3340,64 +3346,31 @@ class SamplerFn:
raise Exception("Generator worker receives KeyboardInterrupt")
if idx_cursor < len(indices):
idx_cursor = _fill_worker_indices(self.workers, indices, idx_cursor)
# Set end-of-epoch (eoe) event once all indices are sent
if idx_cursor == len(indices) and not self.eoe.is_set():
self.eoe.set()
yield tuple([np.array(x, copy=False) for x in result])
def __del__(self):
self.eoe.set()
if self.multi_process is False:
self.eof.set()
for w in self.workers:
w.join()
def _generator_worker_loop_mp(dataset, idx_queue, result_queue, eoe):
"""
Multiprocessing generator worker process loop
"""
while True:
# Fetch index, block
try:
idx = idx_queue.get()
except KeyboardInterrupt:
raise Exception("Generator worker receives KeyboardInterrupt")
if idx is None:
# When the queue is out of scope from master process, a None item can be fetched from the queue.
# Upon receiving None, worker process should check if EOE is set.
assert eoe.is_set(), ""
return
# Fetch data, any exception from __getitem__ will terminate worker and timeout master process
result = dataset[idx]
# Send data, block
try:
result_queue.put(result)
except KeyboardInterrupt:
raise Exception("Generator worker receives KeyboardInterrupt")
del result, idx
self.eof.set()
def _generator_worker_loop_mt(dataset, idx_queue, result_queue, eoe, eof):
def _generator_worker_loop(dataset, idx_queue, result_queue, eof):
"""
Multithread generator worker process loop.
Multithread or multiprocess generator worker process loop.
"""
while True:
# Fetch index, block
try:
# Index is generated very fast, so the timeout is very short
idx = idx_queue.get(timeout=0.01)
idx = idx_queue.get(timeout=1)
except KeyboardInterrupt:
raise Exception("Generator worker receives KeyboardInterrupt")
except queue.Empty:
if eof.is_set() or eoe.is_set():
if eof.is_set():
return
# If end-of-epoch (eoe) or end-of-file (eof) is not set, continue to get data from idx_queue
# If end-of-file (eof) is not set, continue to get data from idx_queue
continue
if idx is None:
# When the queue is out of scope from master process, a None item can be fetched from the queue.
# Upon receiving None, worker process should check if EOE is set.
assert eoe.is_set(), ""
# Upon receiving None, worker process should check if eof is set.
assert eof.is_set(), ""
return
if eof.is_set():
return
@ -3416,8 +3389,6 @@ def _generator_worker_loop_mt(dataset, idx_queue, result_queue, eoe, eof):
continue
break
del result, idx
if eoe.is_set() and idx_queue.empty():
return
class _GeneratorWorkerMt(threading.Thread):
@ -3425,10 +3396,10 @@ class _GeneratorWorkerMt(threading.Thread):
Worker process for multithread Generator.
"""
def __init__(self, dataset, eoe, eof):
def __init__(self, dataset, eof):
self.idx_queue = queue.Queue(16)
self.res_queue = queue.Queue(16)
super().__init__(target=_generator_worker_loop_mt, args=(dataset, self.idx_queue, self.res_queue, eoe, eof))
super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof))
def put(self, item):
"""
@ -3442,16 +3413,25 @@ class _GeneratorWorkerMt(threading.Thread):
"""
return self.res_queue.get(timeout=30)
def queue_empty(self):
if not self.idx_queue.empty():
logger.error("idx_queue is not empty")
return False
if not self.res_queue.empty():
logger.error("res_queue is not empty")
return False
return True
class _GeneratorWorkerMp(multiprocessing.Process):
"""
Worker process for multiprocess Generator.
"""
def __init__(self, dataset, eoe):
def __init__(self, dataset, eof):
self.idx_queue = multiprocessing.Queue(16)
self.res_queue = multiprocessing.Queue(16)
super().__init__(target=_generator_worker_loop_mp, args=(dataset, self.idx_queue, self.res_queue, eoe))
super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof))
def put(self, item):
"""
@ -3467,6 +3447,15 @@ class _GeneratorWorkerMp(multiprocessing.Process):
# when we run too many iterators with infinite epoch(num_epoch=-1)
return self.res_queue.get(timeout=30)
def queue_empty(self):
if not self.idx_queue.empty():
logger.error("idx_queue is not empty")
return False
if not self.res_queue.empty():
logger.error("res_queue is not empty")
return False
return True
def __del__(self):
# Try to destruct here, sometimes the class itself will be destructed in advance,
# so "self" will be a NoneType
@ -3657,16 +3646,14 @@ class GeneratorDataset(MappableDataset):
sampler_instance.set_num_rows(len(self.source))
sampler_instance.initialize()
if new_op.num_parallel_workers > 1:
new_op.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, self.source,
new_op.num_parallel_workers,
self.python_multiprocessing))
sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing)
new_op.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, sample_fn))
else:
new_op.source = (lambda: _cpp_sampler_fn(sampler_instance, self.source))
else:
if new_op.num_parallel_workers > 1:
new_op.source = (lambda: _py_sampler_fn_mp(new_op.sampler, new_op.num_samples, self.source,
new_op.num_parallel_workers,
self.python_multiprocessing))
sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing)
new_op.source = (lambda: _py_sampler_fn_mp(new_op.sampler, new_op.num_samples, sample_fn))
else:
new_op.source = (lambda: _py_sampler_fn(new_op.sampler, new_op.num_samples, self.source))
else:

Loading…
Cancel
Save