diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc index abf50d6d58..412c10ff40 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc @@ -1006,6 +1006,25 @@ bool AnfRuntimeAlgorithm::IsParameterWeight(const ParameterPtr &node) { return node->has_default(); } +bool AnfRuntimeAlgorithm::IsLabelIndexInNode(const AnfNodePtr &node, size_t label_index) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + return false; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::GetCNodeName(cnode) == kLabelGotoOpName && + (AnfAlgo::GetNodeAttr(cnode, kAttrLabelIndex) == label_index)) { + return true; + } else if (AnfAlgo::GetCNodeName(cnode) == kLabelSwitchOpName) { + auto label_list = AnfAlgo::GetNodeAttr>(cnode, kAttrLabelSwitchList); + if (std::find(label_list.begin(), label_list.end(), label_index) != label_list.end()) { + return true; + } + } + return false; +} + void AnfRuntimeAlgorithm::SetStreamId(uint32_t stream_id, AnfNode *node) { MS_EXCEPTION_IF_NULL(node); auto kernel_info = static_cast(node->kernel_info()); diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h index a6a9c78ebe..9b85a77896 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h @@ -188,6 +188,8 @@ class AnfRuntimeAlgorithm { static bool IsNodeInGraphKernel(const AnfNodePtr &node); // check parameter is weight or data static bool IsParameterWeight(const ParameterPtr &node); + // checkout whether the anf node is include the label_index. + static bool IsLabelIndexInNode(const AnfNodePtr &node, size_t label_index); // set stream id of kernel,which will be set in stream assign and be used in stream generate static void SetStreamId(uint32_t stream_id, AnfNode *node); // get stream id diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc index 9a1bf7afcb..bd93666610 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_session.cc @@ -1238,7 +1238,7 @@ void AscendSession::MultiCallGraphOptimize(NotNull root_graph) { MS_LOG(INFO) << "graph: " << graph->graph_id() << " has been called by more than two graphs"; int32_t index = 0; std::vector child_graphs; - auto start_label = graph->get_start_label(); + auto start_label_id = AnfAlgo::GetNodeAttr(graph->get_start_label(), kAttrLabelIndex); auto end_node = graph->get_end_goto(); ParameterPtr post_label_param = graph->AddExtraParamAndTensor("label_param", 0); std::vector new_inputs = {std::make_shared(std::make_shared(kLabelSwitchOpName)), @@ -1247,9 +1247,7 @@ void AscendSession::MultiCallGraphOptimize(NotNull root_graph) { auto kg = graphs_[graph_id]; auto nodes = kg->execution_order(); for (uint32_t i = 0; i < nodes.size(); i++) { - if (AnfAlgo::GetCNodeName(nodes[i]) == kLabelGotoOpName && - (AnfAlgo::GetNodeAttr(nodes[i], kAttrLabelIndex) == - AnfAlgo::GetNodeAttr(start_label, kAttrLabelIndex))) { + if (AnfAlgo::IsLabelIndexInNode(nodes[i], start_label_id)) { if (i < (nodes.size() - 1)) { new_inputs.push_back(nodes[i + 1]); } else {