|
|
|
@ -27,13 +27,13 @@ from ..ops import operations as P
|
|
|
|
|
def _send_data(dataset, epoch_num):
|
|
|
|
|
"""Engine dataset to write data to tdt queue."""
|
|
|
|
|
if not hasattr(dataset, '__has_sent__'):
|
|
|
|
|
exec_dataset = dataset.__TRANSFER_DATASET__
|
|
|
|
|
exec_dataset = dataset.__transfer_dataset__
|
|
|
|
|
exec_dataset.send(epoch_num)
|
|
|
|
|
dataset.__has_sent__ = True
|
|
|
|
|
|
|
|
|
|
def _send_data_no_flag(dataset, epoch_num):
|
|
|
|
|
"""Engine dataset to write data to tdt queue directly."""
|
|
|
|
|
exec_dataset = dataset.__TRANSFER_DATASET__
|
|
|
|
|
exec_dataset = dataset.__transfer_dataset__
|
|
|
|
|
exec_dataset.send(epoch_num)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -88,11 +88,13 @@ def connect_network_with_dataset(network, dataset_helper):
|
|
|
|
|
if isinstance(dataset_iter, _DatasetIterNormal):
|
|
|
|
|
raise RuntimeError("Dataset should be connected with network only in sink mode.")
|
|
|
|
|
|
|
|
|
|
if not hasattr(dataset, '__ME_INITED__') and (context.get_context("device_target") == "Ascend" \
|
|
|
|
|
or context.get_context("device_target") == "GPU") and not context.get_context("enable_ge"):
|
|
|
|
|
dataset.__ME_INITED__ = True
|
|
|
|
|
if not hasattr(dataset, '__me_inited__') and (context.get_context("device_target") == "Ascend"
|
|
|
|
|
or context.get_context("device_target") == "GPU") and not \
|
|
|
|
|
context.get_context("enable_ge"):
|
|
|
|
|
dataset.__me_inited__ = True
|
|
|
|
|
|
|
|
|
|
dataset_types, dataset_shapes = dataset_helper.types_shapes()
|
|
|
|
|
queue_name = dataset.__TRANSFER_DATASET__.queue_name
|
|
|
|
|
queue_name = dataset.__transfer_dataset__.queue_name
|
|
|
|
|
|
|
|
|
|
network = _DataWrapper(network, dataset_types, dataset_shapes, queue_name)
|
|
|
|
|
return network
|
|
|
|
@ -175,18 +177,18 @@ class _DatasetIter:
|
|
|
|
|
self.sink_size = sink_size
|
|
|
|
|
self.sink_count = 1
|
|
|
|
|
|
|
|
|
|
if not hasattr(dataset, '__TRANSFER_DATASET__'):
|
|
|
|
|
if not hasattr(dataset, '__transfer_dataset__'):
|
|
|
|
|
if hasattr(dataset, '__loop_size__'):
|
|
|
|
|
self.sink_size = dataset.__loop_size__
|
|
|
|
|
dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.sink_size)
|
|
|
|
|
dataset.__transfer_dataset__ = _exec_datagraph(dataset, self.sink_size)
|
|
|
|
|
|
|
|
|
|
if not hasattr(dataset, '__no_send__'):
|
|
|
|
|
_send_data(dataset, epoch_num)
|
|
|
|
|
else:
|
|
|
|
|
_send_data_no_flag(dataset, epoch_num)
|
|
|
|
|
|
|
|
|
|
self.stop_send = dataset.__TRANSFER_DATASET__.stop_send
|
|
|
|
|
self.continue_send = dataset.__TRANSFER_DATASET__.continue_send
|
|
|
|
|
self.stop_send = dataset.__transfer_dataset__.stop_send
|
|
|
|
|
self.continue_send = dataset.__transfer_dataset__.continue_send
|
|
|
|
|
self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset)
|
|
|
|
|
|
|
|
|
|
def __iter__(self):
|
|
|
|
@ -273,7 +275,7 @@ class _DatasetIterMS(_DatasetIter):
|
|
|
|
|
else:
|
|
|
|
|
self.sink_count = dataset.get_dataset_size()
|
|
|
|
|
|
|
|
|
|
queue_name = dataset.__TRANSFER_DATASET__.queue_name
|
|
|
|
|
queue_name = dataset.__transfer_dataset__.queue_name
|
|
|
|
|
self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, queue_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|