|
|
|
@ -1989,10 +1989,16 @@ class BatchDataset(Dataset):
|
|
|
|
|
self.process_pool = multiprocessing.Pool(processes=self.num_parallel_workers,
|
|
|
|
|
initializer=_pyfunc_worker_init, initargs=([self.per_batch_map],))
|
|
|
|
|
idx = 0
|
|
|
|
|
global _OP_NAME
|
|
|
|
|
op_id = _OP_NAME[str(self)]
|
|
|
|
|
_manager = multiprocessing.Manager()
|
|
|
|
|
_op_process = _manager.dict()
|
|
|
|
|
_process_lock = _manager.Lock()
|
|
|
|
|
# Wrap per_batch_map into _PythonCallable
|
|
|
|
|
self.per_batch_map = _PythonCallable(self.per_batch_map, idx, self.process_pool)
|
|
|
|
|
self.per_batch_map = _PythonCallable(self.per_batch_map, idx, op_id, _op_process, _process_lock,
|
|
|
|
|
self.num_parallel_workers, self.process_pool)
|
|
|
|
|
self.hook = _ExceptHookHandler()
|
|
|
|
|
atexit.register(_mp_pool_exit_preprocess)
|
|
|
|
|
atexit.register(_mp_pool_exit_preprocess, _manager)
|
|
|
|
|
|
|
|
|
|
def __del__(self):
|
|
|
|
|
if hasattr(self, 'process_pool') and self.process_pool is not None:
|
|
|
|
@ -2188,6 +2194,7 @@ class ShuffleDataset(Dataset):
|
|
|
|
|
# Callables available to pool workers; a worker selects one by its integer index.
_GLOBAL_PYFUNC_LIST = []
# Looked up with str(dataset_op) to obtain the operation id passed to _PythonCallable.
_OP_NAME = {}
# Main-process record per operation, in the form {op_id: [worker_num, set_of_worker_pids]}.
_OP_PROCESS = {}
# Held while _OP_PROCESS is updated.
_LOCK = multiprocessing.Lock()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Pyfunc worker init function
|
|
|
|
@ -2200,11 +2207,17 @@ def _pyfunc_worker_init(pyfunc_list):
|
|
|
|
|
|
|
|
|
|
# Pyfunc worker execution function
|
|
|
|
|
# All exceptions will be raised to main processes
|
|
|
|
|
def _pyfunc_worker_exec(index, op_id, mapping, lock, record, *args):
    """
    Internal function that calls a registered pyfunc inside a Python worker process.

    Args:
        index (int): Position of the callable in _GLOBAL_PYFUNC_LIST.
        op_id: Key of the calling operation's entry in `mapping`.
        mapping: Dict-like object; `mapping[op_id][1]` is the set of worker process ids.
        lock: Context manager held while `mapping[op_id]` is read and written back.
        record (bool): If true, register the current process id before running the pyfunc.
        *args: Positional arguments forwarded to the pyfunc.

    Returns:
        The value returned by the selected pyfunc.

    Raises:
        Exception: If the worker receives a KeyboardInterrupt (chained as the cause).
    """
    try:
        if record:
            pid = os.getpid()
            with lock:
                data = mapping[op_id]
                # Skip the write-back when this worker has already been registered.
                if pid not in data[1]:
                    data[1].add(pid)
                    # Callers pass a manager dict: mutating the fetched value in place
                    # is not propagated, so the entry has to be assigned back.
                    mapping[op_id] = data
        return _GLOBAL_PYFUNC_LIST[index](*args)
    except KeyboardInterrupt as err:
        raise Exception("Multiprocess MapOp worker receives KeyboardInterrupt") from err
|
|
|
|
@ -2216,19 +2229,36 @@ class _PythonCallable:
|
|
|
|
|
Internal Python function wrapper for multiprocessing pyfunc.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, py_callable, idx, op_id, mapping, lock, worker_num, pool=None):
    """
    Keep the user callable together with the worker bookkeeping of its operation.

    Args:
        py_callable: Original Python callable from the user.
        idx (int): Index of the callable in the subprocess-side _GLOBAL_PYFUNC_LIST.
        op_id: Key under which this operation is tracked in `mapping`.
        mapping: Shared dict-like object; its `op_id` entry is reset here.
        lock: Lock stored for later use together with `mapping`.
        worker_num (int): Number of parallel workers of the operation.
        pool: Process pool created for the current iterator, or None.
    """
    # Bookkeeping of the operation that owns this callable.
    self.op_id = op_id
    self.mapping = mapping
    self.lock = lock
    self.worker_num = worker_num
    # Starts out enabled; forwarded to the worker on every dispatch.
    self.record = True

    # User callable, its slot in _GLOBAL_PYFUNC_LIST and the pool that runs it.
    self.py_callable = py_callable
    self.idx = idx
    self.pool = pool

    # Begin with an empty pid set for this operation, then mirror the shared
    # mapping into the module-level registry.
    mapping[op_id] = [worker_num, set()]
    with _LOCK:
        _OP_PROCESS.update(mapping)
|
|
|
|
|
|
|
|
|
|
def __call__(self, *args):
|
|
|
|
|
if self.pool is not None and self.pool._state == 0 and check_iterator_cleanup() is False: # pylint: disable=W0212
|
|
|
|
|
# This call will send the tensors along with Python callable index to the process pool.
|
|
|
|
|
# Block, yield GIL. Current thread will reacquire GIL once result is returned.
|
|
|
|
|
result = self.pool.apply_async(_pyfunc_worker_exec, [self.idx, *args])
|
|
|
|
|
result = self.pool.apply_async(_pyfunc_worker_exec, [self.idx, self.op_id, self.mapping, self.lock,
|
|
|
|
|
self.record, *args])
|
|
|
|
|
if self.record:
|
|
|
|
|
data = self.mapping
|
|
|
|
|
if len(data[self.op_id][1]) == self.worker_num:
|
|
|
|
|
self.record = False
|
|
|
|
|
global _OP_PROCESS, _LOCK
|
|
|
|
|
with _LOCK:
|
|
|
|
|
_OP_PROCESS.update(data)
|
|
|
|
|
# todo this check might be wrong
|
|
|
|
|
while check_iterator_cleanup() is False:
|
|
|
|
|
try:
|
|
|
|
@ -2245,13 +2275,15 @@ class _PythonCallable:
|
|
|
|
|
return self.py_callable(*args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _mp_pool_exit_preprocess(manager=None):
    """
    Exit-time preprocessing for multiprocessing map/batch operations.

    Does nothing unless check_iterator_cleanup() returns False. Otherwise it sets the
    iterator cleanup flag, waits 3 seconds and then shuts down `manager` if one was given.

    Args:
        manager: Optional object with a `shutdown()` method, None by default.
    """
    # NOTE(review): the manager is only shut down on the pass that finds the cleanup
    # flag unset - confirm this is intended when several managers are registered.
    if check_iterator_cleanup() is not False:
        return

    logger.info("Execution preprocessing process before map exit.")
    # Raise the iterator_cleanup flag first, then give pending apply_async tasks
    # 3 seconds so that multiprocessing does not hang while exiting.
    _set_iterator_cleanup()
    time.sleep(3)
    if manager is None:
        return
    manager.shutdown()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _ExceptHookHandler:
|
|
|
|
@ -2339,6 +2371,7 @@ class MapDataset(Dataset):
|
|
|
|
|
"""
|
|
|
|
|
Per iterator bootstrap callback.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
if self.python_multiprocessing:
|
|
|
|
|
iter_specific_operations = []
|
|
|
|
|
callable_list = []
|
|
|
|
@ -2355,19 +2388,25 @@ class MapDataset(Dataset):
|
|
|
|
|
self.process_pool = multiprocessing.Pool(processes=self.num_parallel_workers,
|
|
|
|
|
initializer=_pyfunc_worker_init, initargs=(callable_list,))
|
|
|
|
|
# Pass #2
|
|
|
|
|
global _OP_NAME
|
|
|
|
|
op_id = _OP_NAME[str(self)]
|
|
|
|
|
idx = 0
|
|
|
|
|
_manager = multiprocessing.Manager()
|
|
|
|
|
_op_process = _manager.dict()
|
|
|
|
|
_process_lock = _manager.Lock()
|
|
|
|
|
for op in self.operations:
|
|
|
|
|
# our c transforms is now callable and should not be run in python multithreading
|
|
|
|
|
if callable(op) and str(op).find("c_transform") < 0:
|
|
|
|
|
# Wrap Python callable into _PythonCallable
|
|
|
|
|
iter_specific_operations.append(_PythonCallable(op, idx, self.process_pool))
|
|
|
|
|
iter_specific_operations.append(_PythonCallable(op, idx, op_id, _op_process, _process_lock,
|
|
|
|
|
self.num_parallel_workers, self.process_pool))
|
|
|
|
|
idx += 1
|
|
|
|
|
else:
|
|
|
|
|
# CPP ops remain the same
|
|
|
|
|
iter_specific_operations.append(op)
|
|
|
|
|
self.operations = iter_specific_operations
|
|
|
|
|
self.hook = _ExceptHookHandler()
|
|
|
|
|
atexit.register(_mp_pool_exit_preprocess)
|
|
|
|
|
atexit.register(_mp_pool_exit_preprocess, _manager)
|
|
|
|
|
|
|
|
|
|
def __del__(self):
|
|
|
|
|
if hasattr(self, 'process_pool') and self.process_pool is not None:
|
|
|
|
|