|
|
|
@ -39,6 +39,22 @@ def _send_data_no_flag(dataset, epoch_num):
|
|
|
|
|
exec_dataset.send(epoch_num)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _dynamic_sink_scenario(dataset, dataset_iter):
|
|
|
|
|
"""Special scenario with dynamic shape and sink_size=1."""
|
|
|
|
|
flag = False
|
|
|
|
|
ms_role = os.getenv("MS_ROLE")
|
|
|
|
|
if hasattr(dataset_iter, "sink_size") and \
|
|
|
|
|
dataset_iter.sink_size == 1 and \
|
|
|
|
|
dataset.get_dataset_size() != 1 and \
|
|
|
|
|
hasattr(dataset_iter, "sink_count") and \
|
|
|
|
|
dataset_iter.sink_count == 1 and \
|
|
|
|
|
context.get_context("device_target") == "Ascend" and \
|
|
|
|
|
context.get_context("mode") == context.GRAPH_MODE and \
|
|
|
|
|
ms_role != "MS_WORKER":
|
|
|
|
|
flag = True
|
|
|
|
|
return flag
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _DataWrapper(nn.Cell):
|
|
|
|
|
"""
|
|
|
|
|
Wraps the input network with a dataset which automatically fetches data with 'GetNext' function from the
|
|
|
|
@ -107,14 +123,7 @@ def connect_network_with_dataset(network, dataset_helper):
|
|
|
|
|
return network
|
|
|
|
|
|
|
|
|
|
queue_name = dataset.__transfer_dataset__.queue_name
|
|
|
|
|
if hasattr(dataset_iter, "sink_size") and \
|
|
|
|
|
dataset_iter.sink_size == 1 and \
|
|
|
|
|
hasattr(dataset_iter, "sink_count") and \
|
|
|
|
|
dataset_iter.sink_count == 1 and \
|
|
|
|
|
context.get_context("device_target") == "Ascend" and \
|
|
|
|
|
context.get_context("mode") == context.GRAPH_MODE and \
|
|
|
|
|
ms_role != "MS_WORKER":
|
|
|
|
|
|
|
|
|
|
if _dynamic_sink_scenario(dataset, dataset_iter):
|
|
|
|
|
if not hasattr(dataset_iter, '__network__'):
|
|
|
|
|
dataset_iter.__network__ = network
|
|
|
|
|
network = dataset_iter.__network__
|
|
|
|
|