diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc index 255fd74972..5d68aaaec6 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc @@ -429,7 +429,7 @@ bool IsParallelCareNode(const CNodePtr &cnode) { return false; } if (IsInBlackList(prim)) { - MS_LOG(INFO) << "Parallel don't care node: " << prim->name(); + MS_LOG(DEBUG) << "Parallel don't care node: " << prim->name(); return false; } // get_next is not in the forward graph, we need mark the get_next as the forward node @@ -1286,7 +1286,11 @@ std::vector ExtractShape(const CNodePtr &node) { return shape_all; } -std::pair FindParallelCareNode(const AnfNodePtr &node) { +std::pair FindParallelCareNode(const AnfNodePtr &node, int32_t recursion_num) { + if (recursion_num >= RECURSION_LIMIT) { + return std::make_pair(nullptr, 0); + } + MS_EXCEPTION_IF_NULL(node); FuncGraphPtr func_graph = node->func_graph(); MS_EXCEPTION_IF_NULL(func_graph); @@ -1308,8 +1312,11 @@ std::pair FindParallelCareNode(const AnfNodePtr &node) { } if (IsParallelCareNode(cnode) && cnode->has_user_data()) { return node_pair; - } else if (FindParallelCareNode(node_pair.first).first != nullptr) { - return FindParallelCareNode(node_pair.first); + } else { + auto tmp_pair = FindParallelCareNode(node_pair.first, recursion_num + 1); + if (tmp_pair.first != nullptr) { + return tmp_pair; + } } } return std::make_pair(nullptr, 0); @@ -1320,7 +1327,7 @@ std::pair FindSubGraph(const FuncGraphPtr &graph, const AnfNode MS_EXCEPTION_IF_NULL(parameter); FuncGraphManagerPtr manager = graph->manager(); MS_EXCEPTION_IF_NULL(manager); - std::pair prim_anf_node_pair = FindParallelCareNode(parameter); + std::pair prim_anf_node_pair = FindParallelCareNode(parameter, 0); if (prim_anf_node_pair.first != nullptr) { return prim_anf_node_pair; } else { diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.h b/mindspore/ccsrc/frontend/parallel/step_parallel.h index 47fb8e78c2..84a9aeb5fb 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.h +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.h @@ -36,6 +36,7 @@ using OperatorInfoPtr = std::shared_ptr; namespace mindspore { namespace parallel { const uint64_t kUSecondInSecond = 1000000; +const int32_t RECURSION_LIMIT = 3; struct LossNodeInfo { bool has_tuple_getitem = false; @@ -104,8 +105,6 @@ std::vector FindParameterByRefKeyNode(const AnfNodePtr &node, const // Extract shape from anfnode std::vector ExtractShape(const CNodePtr &node); -std::pair FindParallelCareNode(const AnfNodePtr &node); - // Find finally sub graph std::pair FindSubGraph(const FuncGraphPtr &func_graph, const AnfNodePtr ¶meter);