|
|
|
@ -24,13 +24,18 @@ from ..nn.wrap import GetNextSingleOp
|
|
|
|
|
from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _send_data(dataset):
|
|
|
|
|
def _send_data(dataset, epoch_num):
|
|
|
|
|
"""Engine dataset to write data to tdt queue."""
|
|
|
|
|
if not hasattr(dataset, '__has_sent__'):
|
|
|
|
|
exec_dataset = dataset.__TRANSFER_DATASET__
|
|
|
|
|
exec_dataset.send()
|
|
|
|
|
exec_dataset.send(epoch_num)
|
|
|
|
|
dataset.__has_sent__ = True
|
|
|
|
|
|
|
|
|
|
def _send_data_no_flag(dataset, epoch_num):
|
|
|
|
|
"""Engine dataset to write data to tdt queue directly."""
|
|
|
|
|
exec_dataset = dataset.__TRANSFER_DATASET__
|
|
|
|
|
exec_dataset.send(epoch_num)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DatasetHelper:
|
|
|
|
|
"""
|
|
|
|
@ -54,7 +59,7 @@ class DatasetHelper:
|
|
|
|
|
>>> outputs = network(*inputs)
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1):
|
|
|
|
|
def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1, epoch_num=1):
|
|
|
|
|
check_bool(dataset_sink_mode)
|
|
|
|
|
check_int(sink_size)
|
|
|
|
|
if sink_size < -1 or sink_size == 0:
|
|
|
|
@ -74,7 +79,7 @@ class DatasetHelper:
|
|
|
|
|
iterclass = _DatasetIterMS
|
|
|
|
|
elif context.get_context("device_target") == "CPU":
|
|
|
|
|
raise RuntimeError("Currently dataset sink mode is not supported when the device target is CPU.")
|
|
|
|
|
self.iter = iterclass(dataset, sink_size)
|
|
|
|
|
self.iter = iterclass(dataset, sink_size, epoch_num)
|
|
|
|
|
else:
|
|
|
|
|
iterclass = _DatasetIterNormal
|
|
|
|
|
self.iter = iterclass(dataset)
|
|
|
|
@ -98,7 +103,7 @@ class DatasetHelper:
|
|
|
|
|
|
|
|
|
|
class _DatasetIter:
|
|
|
|
|
"""Base iter for dataset helper"""
|
|
|
|
|
def __init__(self, dataset, sink_size):
|
|
|
|
|
def __init__(self, dataset, sink_size, epoch_num):
|
|
|
|
|
self.dataset = dataset
|
|
|
|
|
self.sink_size = sink_size
|
|
|
|
|
self.sink_count = 1
|
|
|
|
@ -110,9 +115,9 @@ class _DatasetIter:
|
|
|
|
|
dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name
|
|
|
|
|
|
|
|
|
|
if not hasattr(dataset, '__no_send__'):
|
|
|
|
|
_send_data(dataset)
|
|
|
|
|
_send_data(dataset, epoch_num)
|
|
|
|
|
else:
|
|
|
|
|
_send_data(dataset)
|
|
|
|
|
_send_data_no_flag(dataset, epoch_num)
|
|
|
|
|
|
|
|
|
|
self.stop_send = dataset.__TRANSFER_DATASET__.stop_send
|
|
|
|
|
self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset)
|
|
|
|
@ -156,8 +161,8 @@ class _DatasetIter:
|
|
|
|
|
|
|
|
|
|
class _DatasetIterGE(_DatasetIter):
|
|
|
|
|
"""Iter for GE."""
|
|
|
|
|
def __init__(self, dataset, sink_size):
|
|
|
|
|
super().__init__(dataset, sink_size)
|
|
|
|
|
def __init__(self, dataset, sink_size, epoch_num):
|
|
|
|
|
super().__init__(dataset, sink_size, epoch_num)
|
|
|
|
|
self.sink_count = self.get_sink_count(dataset)
|
|
|
|
|
batch_expand_num = 1
|
|
|
|
|
if _need_to_full():
|
|
|
|
@ -172,8 +177,8 @@ class _DatasetIterGE(_DatasetIter):
|
|
|
|
|
|
|
|
|
|
class _DatasetIterMSLoopSink(_DatasetIter):
|
|
|
|
|
"""Iter for context (device_target=Ascend)"""
|
|
|
|
|
def __init__(self, dataset, sink_size):
|
|
|
|
|
super().__init__(dataset, sink_size)
|
|
|
|
|
def __init__(self, dataset, sink_size, epoch_num):
|
|
|
|
|
super().__init__(dataset, sink_size, epoch_num)
|
|
|
|
|
self.sink_count = self.get_sink_count(dataset)
|
|
|
|
|
ms_role = os.getenv("MS_ROLE")
|
|
|
|
|
if ms_role in ("MS_PSERVER", "MS_SCHED"):
|
|
|
|
@ -193,8 +198,8 @@ class _DatasetIterMSLoopSink(_DatasetIter):
|
|
|
|
|
|
|
|
|
|
class _DatasetIterMS(_DatasetIter):
|
|
|
|
|
"""Iter for MS(enable_loop_sink=False)."""
|
|
|
|
|
def __init__(self, dataset, sink_size):
|
|
|
|
|
super().__init__(dataset, sink_size)
|
|
|
|
|
def __init__(self, dataset, sink_size, epoch_num):
|
|
|
|
|
super().__init__(dataset, sink_size, epoch_num)
|
|
|
|
|
if sink_size > 0:
|
|
|
|
|
self.sink_count = sink_size
|
|
|
|
|
else:
|
|
|
|
@ -206,8 +211,8 @@ class _DatasetIterMS(_DatasetIter):
|
|
|
|
|
|
|
|
|
|
class _DatasetIterPSLite(_DatasetIter):
|
|
|
|
|
"""Iter for context (device_target=GPU) on MS_PSERVER or MS_SCHED"""
|
|
|
|
|
def __init__(self, dataset, sink_size):
|
|
|
|
|
super().__init__(dataset, sink_size)
|
|
|
|
|
def __init__(self, dataset, sink_size, epoch_num):
|
|
|
|
|
super().__init__(dataset, sink_size, epoch_num)
|
|
|
|
|
self.sink_count = 1
|
|
|
|
|
self.sink_size = 1
|
|
|
|
|
self.op = None
|
|
|
|
|