|
|
|
@ -3279,14 +3279,13 @@ class SamplerFn:
|
|
|
|
|
# Event for end of epoch
|
|
|
|
|
if multi_process is True:
|
|
|
|
|
self.eoe = multiprocessing.Event()
|
|
|
|
|
self.eof = multiprocessing.Event()
|
|
|
|
|
else:
|
|
|
|
|
self.eoe = threading.Event()
|
|
|
|
|
self.eof = threading.Event()
|
|
|
|
|
# Create workers
|
|
|
|
|
for _ in range(num_worker):
|
|
|
|
|
if multi_process is True:
|
|
|
|
|
worker = _GeneratorWorkerMp(dataset, self.eoe, self.eof)
|
|
|
|
|
worker = _GeneratorWorkerMp(dataset, self.eoe)
|
|
|
|
|
else:
|
|
|
|
|
worker = _GeneratorWorkerMt(dataset, self.eoe, self.eof)
|
|
|
|
|
worker.daemon = True
|
|
|
|
@ -3327,15 +3326,40 @@ class SamplerFn:
|
|
|
|
|
|
|
|
|
|
def __del__(self):
|
|
|
|
|
self.eoe.set()
|
|
|
|
|
self.eof.set()
|
|
|
|
|
if self.multi_process is False:
|
|
|
|
|
self.eof.set()
|
|
|
|
|
for w in self.workers:
|
|
|
|
|
w.join()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _generator_worker_loop(dataset, idx_queue, result_queue, eoe, eof):
|
|
|
|
|
def _generator_worker_loop_mp(dataset, idx_queue, result_queue, eoe):
|
|
|
|
|
"""
|
|
|
|
|
Multiprocessing or multithread generator worker process loop.
|
|
|
|
|
Multiprocessing generator worker process loop
|
|
|
|
|
"""
|
|
|
|
|
while True:
|
|
|
|
|
# Fetch index, block
|
|
|
|
|
try:
|
|
|
|
|
idx = idx_queue.get()
|
|
|
|
|
except KeyboardInterrupt:
|
|
|
|
|
raise Exception("Generator worker receives KeyboardInterrupt")
|
|
|
|
|
if idx is None:
|
|
|
|
|
# When the queue is out of scope from master process, a None item can be fetched from the queue.
|
|
|
|
|
# Upon receiving None, worker process should check if EOE is set.
|
|
|
|
|
assert eoe.is_set(), ""
|
|
|
|
|
return
|
|
|
|
|
# Fetch data, any exception from __getitem__ will terminate worker and timeout master process
|
|
|
|
|
result = dataset[idx]
|
|
|
|
|
# Send data, block
|
|
|
|
|
try:
|
|
|
|
|
result_queue.put(result)
|
|
|
|
|
except KeyboardInterrupt:
|
|
|
|
|
raise Exception("Generator worker receives KeyboardInterrupt")
|
|
|
|
|
del result, idx
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _generator_worker_loop_mt(dataset, idx_queue, result_queue, eoe, eof):
|
|
|
|
|
"""
|
|
|
|
|
Multithread generator worker process loop.
|
|
|
|
|
"""
|
|
|
|
|
while True:
|
|
|
|
|
# Fetch index, block
|
|
|
|
@ -3383,7 +3407,7 @@ class _GeneratorWorkerMt(threading.Thread):
|
|
|
|
|
def __init__(self, dataset, eoe, eof):
|
|
|
|
|
self.idx_queue = queue.Queue(16)
|
|
|
|
|
self.res_queue = queue.Queue(16)
|
|
|
|
|
super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eoe, eof))
|
|
|
|
|
super().__init__(target=_generator_worker_loop_mt, args=(dataset, self.idx_queue, self.res_queue, eoe, eof))
|
|
|
|
|
|
|
|
|
|
def put(self, item):
|
|
|
|
|
"""
|
|
|
|
@ -3403,10 +3427,10 @@ class _GeneratorWorkerMp(multiprocessing.Process):
|
|
|
|
|
Worker process for multiprocess Generator.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, dataset, eoe, eof):
|
|
|
|
|
def __init__(self, dataset, eoe):
|
|
|
|
|
self.idx_queue = multiprocessing.Queue(16)
|
|
|
|
|
self.res_queue = multiprocessing.Queue(16)
|
|
|
|
|
super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eoe, eof))
|
|
|
|
|
super().__init__(target=_generator_worker_loop_mp, args=(dataset, self.idx_queue, self.res_queue, eoe))
|
|
|
|
|
|
|
|
|
|
def put(self, item):
|
|
|
|
|
"""
|
|
|
|
|