|
|
|
@ -17,6 +17,7 @@
|
|
|
|
|
from abc import abstractmethod
|
|
|
|
|
import copy
|
|
|
|
|
import weakref
|
|
|
|
|
from importlib import import_module
|
|
|
|
|
|
|
|
|
|
from mindspore._c_dataengine import DEPipeline
|
|
|
|
|
from mindspore._c_dataengine import OpName
|
|
|
|
@ -24,11 +25,26 @@ from mindspore._c_dataengine import OpName
|
|
|
|
|
from mindspore import log as logger
|
|
|
|
|
from . import datasets as de
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
context = import_module("mindspore.context")
|
|
|
|
|
except ModuleNotFoundError:
|
|
|
|
|
context = None
|
|
|
|
|
|
|
|
|
|
ITERATORS_LIST = list()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _cleanup():
|
|
|
|
|
"""Release all the Iterator."""
|
|
|
|
|
for itr_ref in ITERATORS_LIST:
|
|
|
|
|
if context:
|
|
|
|
|
device_type = context.get_context("device_target")
|
|
|
|
|
if device_type == "GPU":
|
|
|
|
|
itr_ref.release()
|
|
|
|
|
else:
|
|
|
|
|
itr = itr_ref()
|
|
|
|
|
if itr is not None:
|
|
|
|
|
itr.release()
|
|
|
|
|
else:
|
|
|
|
|
itr = itr_ref()
|
|
|
|
|
if itr is not None:
|
|
|
|
|
itr.release()
|
|
|
|
@ -85,6 +101,13 @@ class Iterator:
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, dataset):
|
|
|
|
|
if context:
|
|
|
|
|
device_type = context.get_context("device_target")
|
|
|
|
|
if device_type == "GPU":
|
|
|
|
|
ITERATORS_LIST.append(self)
|
|
|
|
|
else:
|
|
|
|
|
ITERATORS_LIST.append(weakref.ref(self))
|
|
|
|
|
else:
|
|
|
|
|
ITERATORS_LIST.append(weakref.ref(self))
|
|
|
|
|
# create a copy of tree and work on it.
|
|
|
|
|
self.dataset = copy.deepcopy(dataset)
|
|
|
|
|