remove manager

pull/14456/head
ms_yan 4 years ago
parent f9d1575c5f
commit 0d8b6a4bf6

@ -1982,17 +1982,21 @@ class BatchDataset(Dataset):
# The callable list and _pyfunc_worker_init are used to pass lambda functions into subprocesses # The callable list and _pyfunc_worker_init are used to pass lambda functions into subprocesses
self.process_pool = multiprocessing.Pool(processes=self.num_parallel_workers, self.process_pool = multiprocessing.Pool(processes=self.num_parallel_workers,
initializer=_pyfunc_worker_init, initargs=([self.per_batch_map],)) initializer=_pyfunc_worker_init, initargs=([self.per_batch_map],))
idx = 0 idx = 0
global _OP_NAME global _OP_NAME, _OP_PROCESS, _LOCK
op_id = _OP_NAME[str(self)] op_id = _OP_NAME[str(self)]
_manager = multiprocessing.Manager() process_id = {op_id: [self.num_parallel_workers, set()]}
_op_process = _manager.dict() # obtain process id from multiprocessing.pool
_process_lock = _manager.Lock() for pool in self.process_pool._pool: # pylint: disable=W0212
process_id[op_id][1].add(pool.pid)
with _LOCK:
_OP_PROCESS.update(process_id)
# Wrap per_batch_map into _PythonCallable # Wrap per_batch_map into _PythonCallable
self.per_batch_map = _PythonCallable(self.per_batch_map, idx, op_id, _op_process, _process_lock, self.per_batch_map = _PythonCallable(self.per_batch_map, idx, self.process_pool)
self.num_parallel_workers, self.process_pool)
self.hook = _ExceptHookHandler() self.hook = _ExceptHookHandler()
atexit.register(_mp_pool_exit_preprocess, _manager) atexit.register(_mp_pool_exit_preprocess)
def __del__(self): def __del__(self):
if hasattr(self, 'process_pool') and self.process_pool is not None: if hasattr(self, 'process_pool') and self.process_pool is not None:
@ -2201,19 +2205,13 @@ def _pyfunc_worker_init(pyfunc_list):
# Pyfunc worker execution function # Pyfunc worker execution function
# All exceptions will be raised to main processes # All exceptions will be raised to main processes
def _pyfunc_worker_exec(index, op_id, mapping, lock, record, *args): def _pyfunc_worker_exec(index, *args):
""" """
Internal function for calling a certain pyfunc in a Python process. Internal function for calling a certain pyfunc in a Python process.
""" """
# Some threads in multiprocessing.Pool can't process the SIGINT signal, # Some threads in multiprocessing.Pool can't process the SIGINT signal,
# which causes a hang, so Ctrl+C is left to the parent process. # which causes a hang, so Ctrl+C is left to the parent process.
signal.signal(signal.SIGINT, signal.SIG_IGN) signal.signal(signal.SIGINT, signal.SIG_IGN)
if record:
pid = os.getpid()
with lock:
data = mapping[op_id]
data[1].add(pid)
mapping[op_id] = data
return _GLOBAL_PYFUNC_LIST[index](*args) return _GLOBAL_PYFUNC_LIST[index](*args)
@ -2223,40 +2221,20 @@ class _PythonCallable:
Internal Python function wrapper for multiprocessing pyfunc. Internal Python function wrapper for multiprocessing pyfunc.
""" """
def __init__(self, py_callable, idx, op_id, mapping, lock, worker_num, pool=None): def __init__(self, py_callable, idx, pool=None):
# Original Python callable from user. # Original Python callable from user.
self.py_callable = py_callable self.py_callable = py_callable
# Process pool created for current iterator. # Process pool created for current iterator.
self.pool = pool self.pool = pool
# Python callable index for subprocess _GLOBAL_PYFUNC_LIST # Python callable index for subprocess _GLOBAL_PYFUNC_LIST
self.idx = idx self.idx = idx
self.op_id = op_id
self.mapping = mapping
self.lock = lock
self.worker_num = worker_num
self.record = True
self.mapping[op_id] = [self.worker_num, set()]
global _OP_PROCESS, _LOCK
with _LOCK:
_OP_PROCESS.update(self.mapping)
def __call__(self, *args): def __call__(self, *args):
if self.pool is not None and self.pool._state == 0 and check_iterator_cleanup() is False: # pylint: disable=W0212 if self.pool is not None and self.pool._state == 0 and check_iterator_cleanup() is False: # pylint: disable=W0212
# This call will send the tensors along with Python callable index to the process pool. # This call will send the tensors along with Python callable index to the process pool.
# Block, yield GIL. Current thread will reacquire GIL once result is returned. # Block, yield GIL. Current thread will reacquire GIL once result is returned.
if self.record: result = self.pool.apply_async(_pyfunc_worker_exec, [self.idx, *args])
result = self.pool.apply_async(_pyfunc_worker_exec, [self.idx, self.op_id, self.mapping, self.lock,
self.record, *args])
else:
result = self.pool.apply_async(_pyfunc_worker_exec, [self.idx, self.op_id, None, None, self.record,
*args])
if self.record:
data = self.mapping
if len(data[self.op_id][1]) == self.worker_num:
self.record = False
global _OP_PROCESS, _LOCK
with _LOCK:
_OP_PROCESS.update(data)
# TODO: this check might be wrong # TODO: this check might be wrong
while check_iterator_cleanup() is False: while check_iterator_cleanup() is False:
try: try:
@ -2273,15 +2251,13 @@ class _PythonCallable:
return self.py_callable(*args) return self.py_callable(*args)
def _mp_pool_exit_preprocess(manager=None): def _mp_pool_exit_preprocess():
if check_iterator_cleanup() is False: if check_iterator_cleanup() is False:
logger.info("Execution preprocessing process before map exit.") logger.info("Execution preprocessing process before map exit.")
# Set the iterator_cleanup flag to True before exiting, and wait 3s for all apply_async # Set the iterator_cleanup flag to True before exiting, and wait 3s for all apply_async
# applied to the multiprocessing task to prevent multiprocessing from hanging when exiting # applied to the multiprocessing task to prevent multiprocessing from hanging when exiting
_set_iterator_cleanup() _set_iterator_cleanup()
time.sleep(3) time.sleep(3)
if manager is not None:
manager.shutdown()
class _ExceptHookHandler: class _ExceptHookHandler:
@ -2385,26 +2361,29 @@ class MapDataset(Dataset):
# The callable list and _pyfunc_worker_init are used to pass lambda functions into subprocesses # The callable list and _pyfunc_worker_init are used to pass lambda functions into subprocesses
self.process_pool = multiprocessing.Pool(processes=self.num_parallel_workers, self.process_pool = multiprocessing.Pool(processes=self.num_parallel_workers,
initializer=_pyfunc_worker_init, initargs=(callable_list,)) initializer=_pyfunc_worker_init, initargs=(callable_list,))
# Pass #2 # Pass #2
global _OP_NAME
op_id = _OP_NAME[str(self)]
idx = 0 idx = 0
_manager = multiprocessing.Manager() global _OP_NAME, _OP_PROCESS, _LOCK
_op_process = _manager.dict() op_id = _OP_NAME[str(self)]
_process_lock = _manager.Lock() # obtain process id from multiprocessing.pool
process_id = {op_id: [self.num_parallel_workers, set()]}
for pool in self.process_pool._pool: # pylint: disable=W0212
process_id[op_id][1].add(pool.pid)
with _LOCK:
_OP_PROCESS.update(process_id)
for op in self.operations: for op in self.operations:
# our C transforms are now callable and should not be run in Python multithreading # our C transforms are now callable and should not be run in Python multithreading
if callable(op) and str(op).find("c_transform") < 0: if callable(op) and str(op).find("c_transform") < 0:
# Wrap Python callable into _PythonCallable # Wrap Python callable into _PythonCallable
iter_specific_operations.append(_PythonCallable(op, idx, op_id, _op_process, _process_lock, iter_specific_operations.append(_PythonCallable(op, idx, self.process_pool))
self.num_parallel_workers, self.process_pool))
idx += 1 idx += 1
else: else:
# CPP ops remain the same # CPP ops remain the same
iter_specific_operations.append(op) iter_specific_operations.append(op)
self.operations = iter_specific_operations self.operations = iter_specific_operations
self.hook = _ExceptHookHandler() self.hook = _ExceptHookHandler()
atexit.register(_mp_pool_exit_preprocess, _manager) atexit.register(_mp_pool_exit_preprocess)
def __del__(self): def __del__(self):
if hasattr(self, 'process_pool') and self.process_pool is not None: if hasattr(self, 'process_pool') and self.process_pool is not None:

Loading…
Cancel
Save