|
|
@ -13,13 +13,13 @@
|
|
|
|
# limitations under the License.
|
|
|
|
# limitations under the License.
|
|
|
|
# ============================================================================
|
|
|
|
# ============================================================================
|
|
|
|
"""Dataset help for minddata dataset"""
|
|
|
|
"""Dataset help for minddata dataset"""
|
|
|
|
from mindspore._checkparam import check_bool
|
|
|
|
|
|
|
|
from mindspore import context
|
|
|
|
from mindspore import context
|
|
|
|
from mindspore.train.parallel_utils import ParallelMode
|
|
|
|
from mindspore._checkparam import check_bool
|
|
|
|
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, _to_tensor, \
|
|
|
|
|
|
|
|
_construct_tensor_list, _to_full_shapes, _to_full_tensor
|
|
|
|
|
|
|
|
from mindspore.nn.wrap import GetNextSingleOp
|
|
|
|
from mindspore.nn.wrap import GetNextSingleOp
|
|
|
|
from mindspore.parallel._utils import _get_device_num, _get_global_rank, _get_parallel_mode
|
|
|
|
from mindspore.parallel._utils import _get_device_num, _get_global_rank, _get_parallel_mode
|
|
|
|
|
|
|
|
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, _to_tensor, \
|
|
|
|
|
|
|
|
_construct_tensor_list, _to_full_shapes, _to_full_tensor
|
|
|
|
|
|
|
|
from mindspore.train.parallel_utils import ParallelMode
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DatasetHelper:
|
|
|
|
class DatasetHelper:
|
|
|
@ -41,6 +41,7 @@ class DatasetHelper:
|
|
|
|
>>> for inputs in dataset_helper:
|
|
|
|
>>> for inputs in dataset_helper:
|
|
|
|
>>> outputs = network(*inputs)
|
|
|
|
>>> outputs = network(*inputs)
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, dataset, first_order_iter=0, dataset_sink_mode=True):
|
|
|
|
def __init__(self, dataset, first_order_iter=0, dataset_sink_mode=True):
|
|
|
|
check_bool(dataset_sink_mode)
|
|
|
|
check_bool(dataset_sink_mode)
|
|
|
|
|
|
|
|
|
|
|
@ -70,6 +71,7 @@ class DatasetHelper:
|
|
|
|
|
|
|
|
|
|
|
|
class _DatasetIter:
|
|
|
|
class _DatasetIter:
|
|
|
|
"""Base iter for dataset help"""
|
|
|
|
"""Base iter for dataset help"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, dataset):
|
|
|
|
def __init__(self, dataset):
|
|
|
|
self.loop_size = 1
|
|
|
|
self.loop_size = 1
|
|
|
|
if not hasattr(dataset, '__ME_INITED__'):
|
|
|
|
if not hasattr(dataset, '__ME_INITED__'):
|
|
|
@ -107,25 +109,28 @@ class _DatasetIter:
|
|
|
|
loop_count = 1
|
|
|
|
loop_count = 1
|
|
|
|
if hasattr(dataset, '__loop_size__'):
|
|
|
|
if hasattr(dataset, '__loop_size__'):
|
|
|
|
loop_size = dataset.__loop_size__
|
|
|
|
loop_size = dataset.__loop_size__
|
|
|
|
loop_count = int(dataset.get_dataset_size()/loop_size)
|
|
|
|
loop_count = int(dataset.get_dataset_size() / loop_size)
|
|
|
|
return loop_count
|
|
|
|
return loop_count
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _DatasetIterMSLoopSink(_DatasetIter):
|
|
|
|
class _DatasetIterMSLoopSink(_DatasetIter):
|
|
|
|
"""Iter for context (enable_loop_sink=True)"""
|
|
|
|
"""Iter for context (enable_loop_sink=True)"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, dataset, first_order_iter):
|
|
|
|
def __init__(self, dataset, first_order_iter):
|
|
|
|
super(_DatasetIterMSLoopSink, self).__init__(dataset)
|
|
|
|
super(_DatasetIterMSLoopSink, self).__init__(dataset)
|
|
|
|
# self.loop_count = self.get_loop_count(dataset)
|
|
|
|
# self.loop_count = self.get_loop_count(dataset)
|
|
|
|
loop_size = dataset.__loop_size__ + first_order_iter
|
|
|
|
loop_size = dataset.__loop_size__ + first_order_iter
|
|
|
|
self.loop_count = int(dataset.get_dataset_size()/loop_size) * 2
|
|
|
|
self.loop_count = int(dataset.get_dataset_size() / loop_size) * 2
|
|
|
|
|
|
|
|
|
|
|
|
def op():
|
|
|
|
def op():
|
|
|
|
return tuple()
|
|
|
|
return tuple()
|
|
|
|
|
|
|
|
|
|
|
|
self.op = op
|
|
|
|
self.op = op
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _DatasetIterMS(_DatasetIter):
|
|
|
|
class _DatasetIterMS(_DatasetIter):
|
|
|
|
"""Iter for context (enable_loop_sink=False)"""
|
|
|
|
"""Iter for context (enable_loop_sink=False)"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, dataset, first_order_order):
|
|
|
|
def __init__(self, dataset, first_order_order):
|
|
|
|
super(_DatasetIterMS, self).__init__(dataset)
|
|
|
|
super(_DatasetIterMS, self).__init__(dataset)
|
|
|
|
self.loop_count = dataset.get_dataset_size()
|
|
|
|
self.loop_count = dataset.get_dataset_size()
|
|
|
@ -136,6 +141,7 @@ class _DatasetIterMS(_DatasetIter):
|
|
|
|
|
|
|
|
|
|
|
|
class _DatasetIterGE(_DatasetIter):
|
|
|
|
class _DatasetIterGE(_DatasetIter):
|
|
|
|
"""Iter for ge"""
|
|
|
|
"""Iter for ge"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, dataset):
|
|
|
|
def __init__(self, dataset):
|
|
|
|
super(_DatasetIterGE, self).__init__(dataset)
|
|
|
|
super(_DatasetIterGE, self).__init__(dataset)
|
|
|
|
self.loop_count = self.get_loop_count(dataset)
|
|
|
|
self.loop_count = self.get_loop_count(dataset)
|
|
|
@ -148,11 +154,13 @@ class _DatasetIterGE(_DatasetIter):
|
|
|
|
|
|
|
|
|
|
|
|
def op():
|
|
|
|
def op():
|
|
|
|
return tensor_list_run
|
|
|
|
return tensor_list_run
|
|
|
|
|
|
|
|
|
|
|
|
self.op = op
|
|
|
|
self.op = op
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _DatasetIterFeed:
|
|
|
|
class _DatasetIterFeed:
|
|
|
|
"""Iter for feed data"""
|
|
|
|
"""Iter for feed data"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, dataset, first_order_order):
|
|
|
|
def __init__(self, dataset, first_order_order):
|
|
|
|
self.dataset = dataset
|
|
|
|
self.dataset = dataset
|
|
|
|
self.device_num = _get_device_num()
|
|
|
|
self.device_num = _get_device_num()
|