From 0d8b6a4bf6830e207d9e7a07da9f8695234fb9af Mon Sep 17 00:00:00 2001
From: ms_yan
Date: Tue, 30 Mar 2021 21:27:38 +0800
Subject: [PATCH] remove manager

---
 mindspore/dataset/engine/datasets.py | 75 ++++++++++------------------
 1 file changed, 27 insertions(+), 48 deletions(-)

diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py
index 1a00ee7d50..259aa19be1 100644
--- a/mindspore/dataset/engine/datasets.py
+++ b/mindspore/dataset/engine/datasets.py
@@ -1982,17 +1982,21 @@ class BatchDataset(Dataset):
             # The callable list and _pyfunc_worker_init are used to pass lambda function in to subprocesses
             self.process_pool = multiprocessing.Pool(processes=self.num_parallel_workers,
                                                      initializer=_pyfunc_worker_init,
                                                      initargs=([self.per_batch_map],))
+
+            idx = 0
-            global _OP_NAME
+            global _OP_NAME, _OP_PROCESS, _LOCK
             op_id = _OP_NAME[str(self)]
-            _manager = multiprocessing.Manager()
-            _op_process = _manager.dict()
-            _process_lock = _manager.Lock()
+            process_id = {op_id: [self.num_parallel_workers, set()]}
+            # obtain process id from multiprocessing.pool
+            for pool in self.process_pool._pool:  # pylint: disable=W0212
+                process_id[op_id][1].add(pool.pid)
+            with _LOCK:
+                _OP_PROCESS.update(process_id)
             # Wrap per_batch_map into _PythonCallable
-            self.per_batch_map = _PythonCallable(self.per_batch_map, idx, op_id, _op_process, _process_lock,
-                                                 self.num_parallel_workers, self.process_pool)
+            self.per_batch_map = _PythonCallable(self.per_batch_map, idx, self.process_pool)
             self.hook = _ExceptHookHandler()
-            atexit.register(_mp_pool_exit_preprocess, _manager)
+            atexit.register(_mp_pool_exit_preprocess)
 
     def __del__(self):
         if hasattr(self, 'process_pool') and self.process_pool is not None:
@@ -2201,19 +2205,13 @@ def _pyfunc_worker_init(pyfunc_list):
 
 # Pyfunc worker execution function
 # All exceptions will be raised to main processes
-def _pyfunc_worker_exec(index, op_id, mapping, lock, record, *args):
+def _pyfunc_worker_exec(index, *args):
     """
     Internal function for call certain pyfunc in python process.
     """
     # Some threads in multiprocess.pool can't process sigint signal,
     # and will occur hang problem, so ctrl+c will pass to parent process.
     signal.signal(signal.SIGINT, signal.SIG_IGN)
-    if record:
-        pid = os.getpid()
-        with lock:
-            data = mapping[op_id]
-            data[1].add(pid)
-            mapping[op_id] = data
     return _GLOBAL_PYFUNC_LIST[index](*args)
 
 
@@ -2223,40 +2221,20 @@ class _PythonCallable:
     """
     Internal Python function wrapper for multiprocessing pyfunc.
    """
-    def __init__(self, py_callable, idx, op_id, mapping, lock, worker_num, pool=None):
+    def __init__(self, py_callable, idx, pool=None):
         # Original Python callable from user.
         self.py_callable = py_callable
         # Process pool created for current iterator.
         self.pool = pool
         # Python callable index for subprocess _GLOBAL_PYFUNC_LIST
         self.idx = idx
-        self.op_id = op_id
-        self.mapping = mapping
-        self.lock = lock
-        self.worker_num = worker_num
-        self.record = True
-        self.mapping[op_id] = [self.worker_num, set()]
-        global _OP_PROCESS, _LOCK
-        with _LOCK:
-            _OP_PROCESS.update(self.mapping)
 
     def __call__(self, *args):
         if self.pool is not None and self.pool._state == 0 and check_iterator_cleanup() is False:  # pylint: disable=W0212
             # This call will send the tensors along with Python callable index to the process pool.
             # Block, yield GIL. Current thread will reacquire GIL once result is returned.
-            if self.record:
-                result = self.pool.apply_async(_pyfunc_worker_exec, [self.idx, self.op_id, self.mapping, self.lock,
-                                                                     self.record, *args])
-            else:
-                result = self.pool.apply_async(_pyfunc_worker_exec, [self.idx, self.op_id, None, None, self.record,
-                                                                     *args])
-            if self.record:
-                data = self.mapping
-                if len(data[self.op_id][1]) == self.worker_num:
-                    self.record = False
-                    global _OP_PROCESS, _LOCK
-                    with _LOCK:
-                        _OP_PROCESS.update(data)
+            result = self.pool.apply_async(_pyfunc_worker_exec, [self.idx, *args])
+
             # todo this check might be wrong
             while check_iterator_cleanup() is False:
                 try:
@@ -2273,15 +2251,13 @@ class _PythonCallable:
             return self.py_callable(*args)
 
 
-def _mp_pool_exit_preprocess(manager=None):
+def _mp_pool_exit_preprocess():
     if check_iterator_cleanup() is False:
         logger.info("Execution preprocessing process before map exit.")
         # Set the iterator_cleanup flag to True before exiting, and wait 3s for all apply_async
         # applied to the multiprocessing task to prevent multiprocessing from hang when exiting
         _set_iterator_cleanup()
         time.sleep(3)
-    if manager is not None:
-        manager.shutdown()
 
 
 class _ExceptHookHandler:
@@ -2385,26 +2361,29 @@ class MapDataset(Dataset):
             # The callable list and _pyfunc_worker_init are used to pass lambda function in to subprocesses
             self.process_pool = multiprocessing.Pool(processes=self.num_parallel_workers,
                                                      initializer=_pyfunc_worker_init,
                                                      initargs=(callable_list,))
+
             # Pass #2
-            global _OP_NAME
-            op_id = _OP_NAME[str(self)]
             idx = 0
-            _manager = multiprocessing.Manager()
-            _op_process = _manager.dict()
-            _process_lock = _manager.Lock()
+            global _OP_NAME, _OP_PROCESS, _LOCK
+            op_id = _OP_NAME[str(self)]
+            # obtain process id from multiprocessing.pool
+            process_id = {op_id: [self.num_parallel_workers, set()]}
+            for pool in self.process_pool._pool:  # pylint: disable=W0212
+                process_id[op_id][1].add(pool.pid)
+            with _LOCK:
+                _OP_PROCESS.update(process_id)
             for op in self.operations:
                 # our c transforms is now callable and should not be run in python multithreading
                 if callable(op) and str(op).find("c_transform") < 0:
                     # Wrap Python callable into _PythonCallable
-                    iter_specific_operations.append(_PythonCallable(op, idx, op_id, _op_process, _process_lock,
-                                                                    self.num_parallel_workers, self.process_pool))
+                    iter_specific_operations.append(_PythonCallable(op, idx, self.process_pool))
                     idx += 1
                 else:
                     # CPP ops remain the same
                     iter_specific_operations.append(op)
             self.operations = iter_specific_operations
             self.hook = _ExceptHookHandler()
-            atexit.register(_mp_pool_exit_preprocess, _manager)
+            atexit.register(_mp_pool_exit_preprocess)
 
     def __del__(self):
         if hasattr(self, 'process_pool') and self.process_pool is not None:
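
Note (editorial sketch, not part of the patch): the standalone script below reproduces the bookkeeping pattern this commit switches to. Instead of creating a multiprocessing.Manager() (an extra server process whose proxied dict and lock route every access through IPC) and having each pool worker report its own pid from inside _pyfunc_worker_exec, the parent reads the worker pids straight off the pool's private _pool list as soon as the Pool is constructed, fills a module-level registry under an ordinary in-process lock, and registers a plain atexit hook. Pool._pool is the same private CPython attribute the patch reads; the names _OP_PROCESS, _LOCK, _worker_init, _mp_pool_exit_preprocess and build_pool here are illustrative stand-ins mirroring the patch, not the MindSpore module itself.

import atexit
import multiprocessing
import threading

_OP_PROCESS = {}          # op_id -> [num_workers, {worker pids}], mirrors the patch's registry
_LOCK = threading.Lock()  # ordinary in-process lock; no Manager server process needed


def _worker_init():
    # stand-in for _pyfunc_worker_init; workers no longer write any shared state
    pass


def _mp_pool_exit_preprocess():
    # stand-in for the patch's exit hook: it only flags cleanup and waits briefly
    # for in-flight apply_async calls; with the Manager gone there is nothing to
    # shut down here
    pass


def build_pool(op_id, num_workers):
    pool = multiprocessing.Pool(processes=num_workers, initializer=_worker_init)
    process_id = {op_id: [num_workers, set()]}
    # obtain process ids from the pool's private worker list, as the patch does
    for worker in pool._pool:  # pylint: disable=W0212
        process_id[op_id][1].add(worker.pid)
    with _LOCK:
        _OP_PROCESS.update(process_id)
    atexit.register(_mp_pool_exit_preprocess)
    return pool


if __name__ == "__main__":
    pool = build_pool("batch_op_0", 4)
    print(_OP_PROCESS)  # e.g. {'batch_op_0': [4, {1201, 1202, 1203, 1204}]}
    pool.close()
    pool.join()

Because the parent already knows every worker's pid the moment the Pool starts them, the registry is filled once, synchronously, and the workers never touch shared state. That is what lets the patch delete the record/mapping/lock plumbing from _pyfunc_worker_exec and _PythonCallable, reduce the dispatch to apply_async(_pyfunc_worker_exec, [self.idx, *args]), and drop the manager.shutdown() step from the exit hook.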