|
|
|
@ -99,7 +99,8 @@ def connect_network_with_dataset(network, dataset_helper):
|
|
|
|
|
if (hasattr(dataset_iter, "sink_size") and dataset_iter.sink_size == 1) \
|
|
|
|
|
and (hasattr(dataset_iter, "sink_count") and dataset_iter.sink_count == 1) \
|
|
|
|
|
and context.get_context("device_target") == "Ascend" \
|
|
|
|
|
and context.get_context("mode") == context.GRAPH_MODE:
|
|
|
|
|
and context.get_context("mode") == context.GRAPH_MODE \
|
|
|
|
|
and ms_role != "MS_WORKER":
|
|
|
|
|
|
|
|
|
|
if not hasattr(dataset, '__network__'):
|
|
|
|
|
dataset.__network__ = network
|
|
|
|
|