!858 Fix gpu issue

Merge pull request !858 from xiefangqi/md_fix_gpu_issue
pull/858/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 05676676e9

@ -17,6 +17,7 @@
from abc import abstractmethod from abc import abstractmethod
import copy import copy
import weakref import weakref
from importlib import import_module
from mindspore._c_dataengine import DEPipeline from mindspore._c_dataengine import DEPipeline
from mindspore._c_dataengine import OpName from mindspore._c_dataengine import OpName
@ -24,14 +25,29 @@ from mindspore._c_dataengine import OpName
from mindspore import log as logger from mindspore import log as logger
from . import datasets as de from . import datasets as de
try:
context = import_module("mindspore.context")
except ModuleNotFoundError:
context = None
ITERATORS_LIST = list() ITERATORS_LIST = list()
def _cleanup(): def _cleanup():
"""Release all the Iterator."""
for itr_ref in ITERATORS_LIST: for itr_ref in ITERATORS_LIST:
itr = itr_ref() if context:
if itr is not None: device_type = context.get_context("device_target")
itr.release() if device_type == "GPU":
itr_ref.release()
else:
itr = itr_ref()
if itr is not None:
itr.release()
else:
itr = itr_ref()
if itr is not None:
itr.release()
def alter_tree(node): def alter_tree(node):
@ -85,7 +101,14 @@ class Iterator:
""" """
def __init__(self, dataset): def __init__(self, dataset):
ITERATORS_LIST.append(weakref.ref(self)) if context:
device_type = context.get_context("device_target")
if device_type == "GPU":
ITERATORS_LIST.append(self)
else:
ITERATORS_LIST.append(weakref.ref(self))
else:
ITERATORS_LIST.append(weakref.ref(self))
# create a copy of tree and work on it. # create a copy of tree and work on it.
self.dataset = copy.deepcopy(dataset) self.dataset = copy.deepcopy(dataset)
self.dataset = alter_tree(self.dataset) self.dataset = alter_tree(self.dataset)

Loading…
Cancel
Save