From edce230586fcd58725eb32d501c5a07c39cc8c14 Mon Sep 17 00:00:00 2001 From: xiefangqi Date: Tue, 9 Mar 2021 21:28:26 +0800 Subject: [PATCH] fix wrap twice problem --- mindspore/train/dataset_helper.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/mindspore/train/dataset_helper.py b/mindspore/train/dataset_helper.py index e42ca26715..6ca4098937 100644 --- a/mindspore/train/dataset_helper.py +++ b/mindspore/train/dataset_helper.py @@ -39,6 +39,22 @@ def _send_data_no_flag(dataset, epoch_num): exec_dataset.send(epoch_num) +def _dynamic_sink_scenario(dataset, dataset_iter): + """Special scenario with dynamic shape and sink_size=1.""" + flag = False + ms_role = os.getenv("MS_ROLE") + if hasattr(dataset_iter, "sink_size") and \ + dataset_iter.sink_size == 1 and \ + dataset.get_dataset_size() != 1 and \ + hasattr(dataset_iter, "sink_count") and \ + dataset_iter.sink_count == 1 and \ + context.get_context("device_target") == "Ascend" and \ + context.get_context("mode") == context.GRAPH_MODE and \ + ms_role != "MS_WORKER": + flag = True + return flag + + class _DataWrapper(nn.Cell): """ Wraps the input network with a dataset which automatically fetches data with 'GetNext' function from the @@ -107,14 +123,7 @@ def connect_network_with_dataset(network, dataset_helper): return network queue_name = dataset.__transfer_dataset__.queue_name - if hasattr(dataset_iter, "sink_size") and \ - dataset_iter.sink_size == 1 and \ - hasattr(dataset_iter, "sink_count") and \ - dataset_iter.sink_count == 1 and \ - context.get_context("device_target") == "Ascend" and \ - context.get_context("mode") == context.GRAPH_MODE and \ - ms_role != "MS_WORKER": - + if _dynamic_sink_scenario(dataset, dataset_iter): if not hasattr(dataset_iter, '__network__'): dataset_iter.__network__ = network network = dataset_iter.__network__