From 977c41db01c65648a4ed0bb40e59dda5540138e4 Mon Sep 17 00:00:00 2001 From: YangLuo Date: Fri, 25 Sep 2020 09:57:08 +0800 Subject: [PATCH] Fix timeout of GeneratorDataset multiprocessing --- mindspore/dataset/engine/datasets.py | 40 ++++++++++++++++++++++------ 1 file changed, 32 insertions(+), 8 deletions(-) diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index cb6fa72867..9b9c2c967d 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -3279,14 +3279,13 @@ class SamplerFn: # Event for end of epoch if multi_process is True: self.eoe = multiprocessing.Event() - self.eof = multiprocessing.Event() else: self.eoe = threading.Event() self.eof = threading.Event() # Create workers for _ in range(num_worker): if multi_process is True: - worker = _GeneratorWorkerMp(dataset, self.eoe, self.eof) + worker = _GeneratorWorkerMp(dataset, self.eoe) else: worker = _GeneratorWorkerMt(dataset, self.eoe, self.eof) worker.daemon = True @@ -3327,15 +3326,40 @@ class SamplerFn: def __del__(self): self.eoe.set() - self.eof.set() if self.multi_process is False: + self.eof.set() for w in self.workers: w.join() -def _generator_worker_loop(dataset, idx_queue, result_queue, eoe, eof): +def _generator_worker_loop_mp(dataset, idx_queue, result_queue, eoe): """ - Multiprocessing or multithread generator worker process loop. + Multiprocessing generator worker process loop + """ + while True: + # Fetch index, block + try: + idx = idx_queue.get() + except KeyboardInterrupt: + raise Exception("Generator worker receives KeyboardInterrupt") + if idx is None: + # When the queue is out of scope from master process, a None item can be fetched from the queue. + # Upon receiving None, worker process should check if EOE is set. + assert eoe.is_set(), "" + return + # Fetch data, any exception from __getitem__ will terminate worker and timeout master process + result = dataset[idx] + # Send data, block + try: + result_queue.put(result) + except KeyboardInterrupt: + raise Exception("Generator worker receives KeyboardInterrupt") + del result, idx + + +def _generator_worker_loop_mt(dataset, idx_queue, result_queue, eoe, eof): + """ + Multithread generator worker process loop. """ while True: # Fetch index, block @@ -3383,7 +3407,7 @@ class _GeneratorWorkerMt(threading.Thread): def __init__(self, dataset, eoe, eof): self.idx_queue = queue.Queue(16) self.res_queue = queue.Queue(16) - super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eoe, eof)) + super().__init__(target=_generator_worker_loop_mt, args=(dataset, self.idx_queue, self.res_queue, eoe, eof)) def put(self, item): """ @@ -3403,10 +3427,10 @@ class _GeneratorWorkerMp(multiprocessing.Process): Worker process for multiprocess Generator. """ - def __init__(self, dataset, eoe, eof): + def __init__(self, dataset, eoe): self.idx_queue = multiprocessing.Queue(16) self.res_queue = multiprocessing.Queue(16) - super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eoe, eof)) + super().__init__(target=_generator_worker_loop_mp, args=(dataset, self.idx_queue, self.res_queue, eoe)) def put(self, item): """