diff --git a/mindspore/dataset/engine/iterators.py b/mindspore/dataset/engine/iterators.py index ebee204b37..2cf95aa086 100644 --- a/mindspore/dataset/engine/iterators.py +++ b/mindspore/dataset/engine/iterators.py @@ -17,6 +17,7 @@ from abc import abstractmethod import copy import weakref +from importlib import import_module from mindspore._c_dataengine import DEPipeline from mindspore._c_dataengine import OpName @@ -24,14 +25,29 @@ from mindspore._c_dataengine import OpName from mindspore import log as logger from . import datasets as de +try: + context = import_module("mindspore.context") +except ModuleNotFoundError: + context = None + ITERATORS_LIST = list() def _cleanup(): + """Release all the Iterator.""" for itr_ref in ITERATORS_LIST: - itr = itr_ref() - if itr is not None: - itr.release() + if context: + device_type = context.get_context("device_target") + if device_type == "GPU": + itr_ref.release() + else: + itr = itr_ref() + if itr is not None: + itr.release() + else: + itr = itr_ref() + if itr is not None: + itr.release() def alter_tree(node): @@ -85,7 +101,14 @@ class Iterator: """ def __init__(self, dataset): - ITERATORS_LIST.append(weakref.ref(self)) + if context: + device_type = context.get_context("device_target") + if device_type == "GPU": + ITERATORS_LIST.append(self) + else: + ITERATORS_LIST.append(weakref.ref(self)) + else: + ITERATORS_LIST.append(weakref.ref(self)) # create a copy of tree and work on it. self.dataset = copy.deepcopy(dataset) self.dataset = alter_tree(self.dataset)