diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index a7a8389c67..58ff041d35 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -3239,21 +3239,19 @@ def _cpp_sampler_fn(sampler, dataset): yield tuple([np.array(x, copy=False) for x in val]) -def _cpp_sampler_fn_mp(sampler, dataset, num_worker, multi_process): +def _cpp_sampler_fn_mp(sampler, sample_fn): """ Multiprocessing generator function wrapper for mappable dataset with cpp sampler. """ indices = sampler.get_indices() - sample_fn = SamplerFn(dataset, num_worker, multi_process) return sample_fn.process(indices) -def _py_sampler_fn_mp(sampler, num_samples, dataset, num_worker, multi_process): +def _py_sampler_fn_mp(sampler, num_samples, sample_fn): """ Multiprocessing generator function wrapper for mappable dataset with Python sampler. """ indices = _fetch_py_sampler_indices(sampler, num_samples) - sample_fn = SamplerFn(dataset, num_worker, multi_process) return sample_fn.process(indices) @@ -3299,17 +3297,21 @@ class SamplerFn: self.multi_process = multi_process # Event for end of epoch if multi_process is True: - self.eoe = multiprocessing.Event() + self.eof = multiprocessing.Event() else: - self.eoe = threading.Event() self.eof = threading.Event() # Create workers for _ in range(num_worker): if multi_process is True: - worker = _GeneratorWorkerMp(dataset, self.eoe) + worker = _GeneratorWorkerMp(dataset, self.eof) + worker.daemon = True + # When multi processes fork a subprocess, the lock of the main process is copied to the subprocess, + # which may cause deadlock. Therefore, the subprocess startup is performed in the initialization phase. + # In this phase, the main process is not locked. 
+ worker.start() else: - worker = _GeneratorWorkerMt(dataset, self.eoe, self.eof) - worker.daemon = True + worker = _GeneratorWorkerMt(dataset, self.eof) + worker.daemon = True self.workers.append(worker) def process(self, indices): @@ -3317,14 +3319,18 @@ class SamplerFn: The main process, start the child process or child thread, and fill the index queue. Get the result and return. """ + for w in self.workers: + # Check whether the queue of the subprocess is empty. + if not w.queue_empty(): + raise Exception("The queue of the subprocess is not empty.") + # Start all workers + if not w.is_alive(): + w.start() + # Fill initial index queues idx_cursor = 0 idx_cursor = _fill_worker_indices(self.workers, indices, idx_cursor) - # Start all workers - for w in self.workers: - w.start() - # Fetch results for i in range(len(indices)): # Fetch result and put index @@ -3340,64 +3346,31 @@ class SamplerFn: raise Exception("Generator worker receives KeyboardInterrupt") if idx_cursor < len(indices): idx_cursor = _fill_worker_indices(self.workers, indices, idx_cursor) - # Set end-of-epoch (eoe) event once all indices are sent - if idx_cursor == len(indices) and not self.eoe.is_set(): - self.eoe.set() yield tuple([np.array(x, copy=False) for x in result]) def __del__(self): - self.eoe.set() - if self.multi_process is False: - self.eof.set() - for w in self.workers: - w.join() - - -def _generator_worker_loop_mp(dataset, idx_queue, result_queue, eoe): - """ - Multiprocessing generator worker process loop - """ - while True: - # Fetch index, block - try: - idx = idx_queue.get() - except KeyboardInterrupt: - raise Exception("Generator worker receives KeyboardInterrupt") - if idx is None: - # When the queue is out of scope from master process, a None item can be fetched from the queue. - # Upon receiving None, worker process should check if EOE is set. 
- assert eoe.is_set(), "" - return - # Fetch data, any exception from __getitem__ will terminate worker and timeout master process - result = dataset[idx] - # Send data, block - try: - result_queue.put(result) - except KeyboardInterrupt: - raise Exception("Generator worker receives KeyboardInterrupt") - del result, idx + self.eof.set() -def _generator_worker_loop_mt(dataset, idx_queue, result_queue, eoe, eof): +def _generator_worker_loop(dataset, idx_queue, result_queue, eof): """ - Multithread generator worker process loop. + Multithread or multiprocess generator worker process loop. """ while True: # Fetch index, block try: - # Index is generated very fast, so the timeout is very short - idx = idx_queue.get(timeout=0.01) + idx = idx_queue.get(timeout=1) except KeyboardInterrupt: raise Exception("Generator worker receives KeyboardInterrupt") except queue.Empty: - if eof.is_set() or eoe.is_set(): + if eof.is_set(): return - # If end-of-epoch (eoe) or end-of-file (eof) is not set, continue to get data from idx_queue + # If end-of-file (eof) is not set, continue to get data from idx_queue continue if idx is None: # When the queue is out of scope from master process, a None item can be fetched from the queue. - # Upon receiving None, worker process should check if EOE is set. - assert eoe.is_set(), "" + # Upon receiving None, worker process should check if eof is set. + assert eof.is_set(), "" return if eof.is_set(): return @@ -3416,8 +3389,6 @@ def _generator_worker_loop_mt(dataset, idx_queue, result_queue, eoe, eof): continue break del result, idx - if eoe.is_set() and idx_queue.empty(): - return class _GeneratorWorkerMt(threading.Thread): @@ -3425,10 +3396,10 @@ class _GeneratorWorkerMt(threading.Thread): Worker process for multithread Generator. 
""" - def __init__(self, dataset, eoe, eof): + def __init__(self, dataset, eof): self.idx_queue = queue.Queue(16) self.res_queue = queue.Queue(16) - super().__init__(target=_generator_worker_loop_mt, args=(dataset, self.idx_queue, self.res_queue, eoe, eof)) + super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof)) def put(self, item): """ @@ -3442,16 +3413,25 @@ class _GeneratorWorkerMt(threading.Thread): """ return self.res_queue.get(timeout=30) + def queue_empty(self): + if not self.idx_queue.empty(): + logger.error("idx_queue is not empty") + return False + if not self.res_queue.empty(): + logger.error("res_queue is not empty") + return False + return True + class _GeneratorWorkerMp(multiprocessing.Process): """ Worker process for multiprocess Generator. """ - def __init__(self, dataset, eoe): + def __init__(self, dataset, eof): self.idx_queue = multiprocessing.Queue(16) self.res_queue = multiprocessing.Queue(16) - super().__init__(target=_generator_worker_loop_mp, args=(dataset, self.idx_queue, self.res_queue, eoe)) + super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof)) def put(self, item): """ @@ -3467,6 +3447,15 @@ class _GeneratorWorkerMp(multiprocessing.Process): # when we run too many iterators with infinite epoch(num_epoch=-1) return self.res_queue.get(timeout=30) + def queue_empty(self): + if not self.idx_queue.empty(): + logger.error("idx_queue is not empty") + return False + if not self.res_queue.empty(): + logger.error("res_queue is not empty") + return False + return True + def __del__(self): # Try to destruct here, sometimes the class itself will be destructed in advance, # so "self" will be a NoneType @@ -3657,16 +3646,14 @@ class GeneratorDataset(MappableDataset): sampler_instance.set_num_rows(len(self.source)) sampler_instance.initialize() if new_op.num_parallel_workers > 1: - new_op.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, self.source, - 
new_op.num_parallel_workers, - self.python_multiprocessing)) + sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing) + new_op.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, sample_fn)) else: new_op.source = (lambda: _cpp_sampler_fn(sampler_instance, self.source)) else: if new_op.num_parallel_workers > 1: - new_op.source = (lambda: _py_sampler_fn_mp(new_op.sampler, new_op.num_samples, self.source, - new_op.num_parallel_workers, - self.python_multiprocessing)) + sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing) + new_op.source = (lambda: _py_sampler_fn_mp(new_op.sampler, new_op.num_samples, sample_fn)) else: new_op.source = (lambda: _py_sampler_fn(new_op.sampler, new_op.num_samples, self.source)) else: