diff --git a/mindspore/ccsrc/session/kernel_graph.cc b/mindspore/ccsrc/session/kernel_graph.cc index 7e9bb62aab..5c05b3fdcc 100644 --- a/mindspore/ccsrc/session/kernel_graph.cc +++ b/mindspore/ccsrc/session/kernel_graph.cc @@ -538,11 +538,16 @@ void KernelGraph::UpdateControlDependRelations(const std::vector &de MS_EXCEPTION_IF_NULL(depend_node); std::vector prior_nodes = {prior_node}; std::vector depend_nodes = {depend_node}; - MS_LOG(INFO) << "Prior node[" << prior_node->DebugString() << "], depend node[" << depend_node->DebugString(); - if (prior_node->isa()) { + int depend_mode = 0; + if (AnfAlgo::HasNodeAttr(kControlDependMode, cnode)) { + depend_mode = AnfAlgo::GetNodeAttr(cnode, kControlDependMode); + } + MS_LOG(INFO) << "Prior node[" << prior_node->DebugString() << "], depend node[" << depend_node->DebugString() + << "], depend_mode :" << depend_mode << "."; + if (prior_node->isa() && depend_mode == 1) { prior_nodes = GetOutputNodes(prior_node); } - if (depend_node->isa()) { + if (depend_node->isa() && depend_mode == 1) { depend_nodes = GetOutputNodes(depend_node); } for (auto &first_node : prior_nodes) { diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 972d8df319..7380ef501f 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -246,6 +246,7 @@ constexpr auto kTupleGetItemInputSize = 3; // index define of control depend constexpr auto kControlDependPriorIndex = 1; constexpr auto kControlDependBehindIndex = 2; +constexpr auto kControlDependMode = "depend_mode"; // index define of depend constexpr auto kRealInputIndexInDepend = 1; constexpr auto kDependAttachNodeIndex = 2;