|
|
|
@ -13,15 +13,15 @@
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
# ============================================================================
|
|
|
|
|
"""Dataset help for minddata dataset"""
|
|
|
|
|
from mindspore._checkparam import check_bool
|
|
|
|
|
from mindspore import context
|
|
|
|
|
from mindspore.train.parallel_utils import ParallelMode
|
|
|
|
|
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, _to_tensor, \
|
|
|
|
|
_construct_tensor_list, _to_full_shapes, _to_full_tensor
|
|
|
|
|
from mindspore._checkparam import check_bool
|
|
|
|
|
from mindspore.nn.wrap import GetNextSingleOp
|
|
|
|
|
from mindspore.parallel._utils import _get_device_num, _get_global_rank, _get_parallel_mode
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, _to_tensor, \
|
|
|
|
|
_construct_tensor_list, _to_full_shapes, _to_full_tensor
|
|
|
|
|
from mindspore.train.parallel_utils import ParallelMode
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DatasetHelper:
|
|
|
|
|
"""
|
|
|
|
|
Help function to use the Minddata dataset.
|
|
|
|
@ -41,9 +41,10 @@ class DatasetHelper:
|
|
|
|
|
>>> for inputs in dataset_helper:
|
|
|
|
|
>>> outputs = network(*inputs)
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, dataset, first_order_iter=0, dataset_sink_mode=True):
|
|
|
|
|
check_bool(dataset_sink_mode)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
iterclass = _DatasetIterGE
|
|
|
|
|
if not dataset_sink_mode:
|
|
|
|
|
iterclass = _DatasetIterFeed
|
|
|
|
@ -52,24 +53,25 @@ class DatasetHelper:
|
|
|
|
|
iterclass = _DatasetIterMSLoopSink
|
|
|
|
|
else:
|
|
|
|
|
iterclass = _DatasetIterMS
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.iter = iterclass(dataset, first_order_iter)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __iter__(self):
    """Delegate iteration to the wrapped dataset iterator."""
    return iter(self.iter)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# A temp solution for loop sink. Delete later
def types_shapes(self):
    """Return the (types, shapes) pair reported by the wrapped iterator
    for the current configuration."""
    inner = self.iter
    return inner.types_shapes()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def loop_size(self):
    """Return the number of steps executed per iteration (sink size)."""
    inner = self.iter
    return inner.loop_size
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _DatasetIter:
|
|
|
|
|
"""Base iter for dataset help"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, dataset):
|
|
|
|
|
self.loop_size = 1
|
|
|
|
|
if not hasattr(dataset, '__ME_INITED__'):
|
|
|
|
@ -78,7 +80,7 @@ class _DatasetIter:
|
|
|
|
|
else:
|
|
|
|
|
self.loop_size = dataset.__loop_size__
|
|
|
|
|
dataset.__ME_INITED__ = _exec_datagraph(dataset, self.loop_size).queue_name
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.ind = 0
|
|
|
|
|
self.dataset = dataset
|
|
|
|
|
dataset_types, dataset_shapes = _get_types_and_shapes(dataset)
|
|
|
|
@ -89,53 +91,57 @@ class _DatasetIter:
|
|
|
|
|
if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
|
|
|
|
device_num = _get_device_num()
|
|
|
|
|
self.dataset_shapes = _to_full_shapes(dataset_shapes, device_num)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __iter__(self):
    """Reset the step counter and hand back self as the iterator."""
    self.ind = 0
    return self
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __next__(self):
    """Return the next batch via ``self.op``; stop after ``loop_count`` steps."""
    if self.ind < self.loop_count:
        self.ind += 1
        return self.op()
    raise StopIteration()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def types_shapes(self):
    """Return the element types and shapes cached at construction time."""
    return (self.dataset_types, self.dataset_shapes)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_loop_count(self, dataset):
    """Compute how many sink loops make up one pass over *dataset*.

    If the dataset carries a ``__loop_size__`` attribute (set up by the
    sink-mode initialisation), one pass is ``dataset_size / loop_size``
    loops; otherwise a single loop covers the whole pass.
    """
    # The mangled source assigned this result twice (a leftover diff
    # old/new pair); a single assignment suffices.
    loop_count = 1
    if hasattr(dataset, '__loop_size__'):
        loop_size = dataset.__loop_size__
        # NOTE(review): truncating division silently drops a partial tail
        # loop when dataset_size is not a multiple of loop_size — confirm
        # that is intended.
        loop_count = int(dataset.get_dataset_size() / loop_size)
    return loop_count
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _DatasetIterMSLoopSink(_DatasetIter):
    """Iter for context (enable_loop_sink=True).

    Each ``__next__`` call sinks a whole loop of steps on device, so ``op``
    returns no host-side tensors.
    """

    def __init__(self, dataset, first_order_iter):
        super(_DatasetIterMSLoopSink, self).__init__(dataset)
        # The mangled source assigned loop_count twice (a leftover diff
        # old/new pair) and kept a commented-out alternative; one
        # assignment is enough.
        # NOTE(review): the * 2 presumably accounts for running two phases
        # (first- and second-order) per pass — confirm against callers.
        loop_size = dataset.__loop_size__ + first_order_iter
        self.loop_count = int(dataset.get_dataset_size() / loop_size) * 2

        def op():
            # Data is delivered on device through the sink queue; nothing
            # needs to be produced on the host side.
            return tuple()

        self.op = op
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _DatasetIterMS(_DatasetIter):
    """Iter for context (enable_loop_sink=False)."""

    def __init__(self, dataset, first_order_order):
        super(_DatasetIterMS, self).__init__(dataset)
        # Without loop sink every dataset element is fetched one step at a
        # time, so the loop size is 1 and the count equals the dataset size.
        self.loop_size = 1
        self.loop_count = dataset.get_dataset_size()
        # Pull single batches from the queue registered on the dataset by
        # the base-class initialisation.
        self.op = GetNextSingleOp(
            self.dataset_types, self.dataset_shapes, dataset.__ME_INITED__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _DatasetIterGE(_DatasetIter):
|
|
|
|
|
"""Iter for ge"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, dataset):
|
|
|
|
|
super(_DatasetIterGE, self).__init__(dataset)
|
|
|
|
|
self.loop_count = self.get_loop_count(dataset)
|
|
|
|
@ -145,14 +151,16 @@ class _DatasetIterGE(_DatasetIter):
|
|
|
|
|
if self.need_to_full:
|
|
|
|
|
batch_expand_num = _get_device_num()
|
|
|
|
|
tensor_list_run = _construct_tensor_list(self.dataset_types, self.dataset_shapes, batch_expand_num)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def op():
|
|
|
|
|
return tensor_list_run
|
|
|
|
|
|
|
|
|
|
self.op = op
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _DatasetIterFeed:
|
|
|
|
|
"""Iter for feed data"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, dataset, first_order_order):
|
|
|
|
|
self.dataset = dataset
|
|
|
|
|
self.device_num = _get_device_num()
|
|
|
|
@ -161,18 +169,18 @@ class _DatasetIterFeed:
|
|
|
|
|
self.repeat_ind = 0
|
|
|
|
|
self.loop_count = dataset.get_dataset_size()
|
|
|
|
|
self.ind = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
|
|
|
|
self.need_to_full = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __iter__(self):
    """Begin an epoch, creating a fresh dataset iterator when needed."""
    # A new underlying iterator is created only at repeat boundaries so the
    # dataset's own repeat behaviour carries across epochs.
    at_repeat_boundary = self.repeat_ind % self.repeat_count == 0
    if at_repeat_boundary:
        self.iter = iter(self.dataset)
    self.repeat_ind += 1
    self.ind = 0
    return self
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __next__(self):
|
|
|
|
|
if self.ind >= self.loop_count:
|
|
|
|
|
raise StopIteration()
|