diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py
index a86cf602ff..869e67ac41 100644
--- a/mindspore/dataset/engine/datasets.py
+++ b/mindspore/dataset/engine/datasets.py
@@ -27,6 +27,7 @@ import multiprocessing
 import queue
 from enum import Enum
 from importlib import import_module
+import sys
 import threading
 import copy
 
@@ -42,7 +43,8 @@ from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched
 import mindspore.dataset.transforms.py_transforms as py_transforms
 
 from . import samplers
-from .iterators import DictIterator, TupleIterator, DummyIterator, SaveOp, Iterator, check_iterator_cleanup
+from .iterators import DictIterator, TupleIterator, DummyIterator, SaveOp, Iterator, check_iterator_cleanup, \
+    _set_iterator_cleanup
 from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \
     check_rename, check_numpyslicesdataset, check_device_send, \
     check_take, check_project, check_imagefolderdataset, check_mnist_cifar_dataset, check_manifestdataset, \
@@ -2040,7 +2042,8 @@ class _PythonCallable:
                 except multiprocessing.TimeoutError:
                     continue
                 except KeyboardInterrupt:
-                    self.pool.terminate()
+                    _set_iterator_cleanup()
+                    self.pool.close()
                     self.pool.join()
                     raise Exception("Multiprocess MapOp worker receives KeyboardInterrupt")
             return (None,)
@@ -2048,6 +2051,18 @@ class _PythonCallable:
         return self.py_callable(*args)
 
 
+class _ExceptHookHandler:
+    def __init__(self, pool):
+        self.__pool = pool
+        sys.excepthook = self.__handler_exception
+
+    def __handler_exception(self, type, value, tb):
+        logger.error("Uncaught exception: ", exc_info=(type, value, tb))
+        if self.__pool is not None:
+            _set_iterator_cleanup()
+            self.__pool.terminate()
+
+
 class MapDataset(DatasetOp):
     """
     The result of applying the Map operator to the input Dataset.
@@ -2124,6 +2139,7 @@ class MapDataset(DatasetOp):
             callbacks = [callbacks]
 
         self.callbacks = callbacks
+        self.hook = None
 
     def get_args(self):
         args = super().get_args()
@@ -2163,6 +2179,7 @@ class MapDataset(DatasetOp):
         new_op.input_indexs = copy.deepcopy(self._input_indexs, memodict)
         new_op.python_multiprocessing = copy.deepcopy(self.python_multiprocessing, memodict)
         new_op.cache = copy.deepcopy(self.cache, memodict)
+        new_op.hook = copy.deepcopy(self.hook, memodict)
         new_op.operations = self.operations
         new_op.dataset_size = self.dataset_size
         new_op.callbacks = self.callbacks
@@ -2203,10 +2220,11 @@ class MapDataset(DatasetOp):
                     # CPP ops remain the same
                     iter_specific_operations.append(op)
             self.operations = iter_specific_operations
+            self.hook = _ExceptHookHandler(self.process_pool)
 
     def __del__(self):
         if hasattr(self, 'process_pool') and self.process_pool is not None:
-            self.process_pool.terminate()
+            self.process_pool.close()
 
 
 class FilterDataset(DatasetOp):
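The core idea of the patch is to register a `sys.excepthook` handler that tears down the multiprocessing pool when an uncaught exception escapes the main process, instead of relying solely on `__del__`. Below is a minimal, standalone sketch of that pattern (not part of the patch): the names `PoolGuard` and `_square` are hypothetical, and unlike `_ExceptHookHandler` it chains to the previous hook. Note that `sys.excepthook` only fires for uncaught exceptions in the main thread.

```python
# Illustrative sketch of the sys.excepthook-based pool cleanup pattern.
import multiprocessing
import sys


def _square(x):
    return x * x


class PoolGuard:
    """Terminate the given pool if an uncaught exception reaches sys.excepthook."""

    def __init__(self, pool):
        self._pool = pool
        self._previous_hook = sys.excepthook
        sys.excepthook = self._handle

    def _handle(self, exc_type, exc_value, exc_tb):
        if self._pool is not None:
            # terminate() stops workers immediately; close() would let pending tasks finish.
            self._pool.terminate()
            self._pool.join()
        # Delegate to the original hook so the traceback is still printed.
        self._previous_hook(exc_type, exc_value, exc_tb)


if __name__ == "__main__":
    pool = multiprocessing.Pool(2)
    guard = PoolGuard(pool)
    print(pool.map(_square, range(4)))
    pool.close()
    pool.join()
```

The same reasoning explains the `terminate()` -> `close()` changes elsewhere in the diff: on the graceful paths (`KeyboardInterrupt` handling and `__del__`) the pool is closed and joined so workers can exit cleanly, while the hard `terminate()` is reserved for the uncaught-exception path in the hook.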