From 89f81da336ba11f536c45286544a1bc9854d5458 Mon Sep 17 00:00:00 2001 From: xiefangqi Date: Mon, 15 Mar 2021 14:36:59 +0800 Subject: [PATCH] fix map multi-process hang for ctrl+c issue --- mindspore/dataset/engine/datasets.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index d81441183f..515b7634fd 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -2195,16 +2195,16 @@ def _pyfunc_worker_exec(index, op_id, mapping, lock, record, *args): """ Internal function for call certain pyfunc in python process. """ - try: - if record: - pid = os.getpid() - with lock: - data = mapping[op_id] - data[1].add(pid) - mapping[op_id] = data - return _GLOBAL_PYFUNC_LIST[index](*args) - except KeyboardInterrupt: - raise Exception("Multiprocess MapOp worker receives KeyboardInterrupt") + # Some threads in multiprocess.pool can't process sigint signal, + # and will occur hang problem, so ctrl+c will pass to parent process. + signal.signal(signal.SIGINT, signal.SIG_IGN) + if record: + pid = os.getpid() + with lock: + data = mapping[op_id] + data[1].add(pid) + mapping[op_id] = data + return _GLOBAL_PYFUNC_LIST[index](*args) # PythonCallable wrapper for multiprocess pyfunc @@ -3261,6 +3261,10 @@ def _generator_worker_loop(dataset, idx_queue, result_queue, eof, is_multiproces try: idx = idx_queue.get(timeout=1) except KeyboardInterrupt: + if is_multiprocessing: + eof.set() + idx_queue.cancel_join_thread() + result_queue.cancel_join_thread() raise Exception("Generator worker receives KeyboardInterrupt.") except queue.Empty: if eof.is_set(): @@ -3287,6 +3291,10 @@ def _generator_worker_loop(dataset, idx_queue, result_queue, eof, is_multiproces try: result_queue.put(result, timeout=5) except KeyboardInterrupt: + if is_multiprocessing: + eof.set() + idx_queue.cancel_join_thread() + result_queue.cancel_join_thread() raise Exception("Generator worker receives KeyboardInterrupt.") except queue.Full: if eof.is_set():