diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/composite_ops_fusion.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/composite_ops_fusion.cc index 60d3d560e5..60771d759b 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/composite_ops_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/composite_ops_fusion.cc @@ -186,8 +186,8 @@ IncludeType IncludeFusedBasicOpBackward(const AnfNodePtr &cur_node, GraphKernelI } bool CheckCircle(const std::set &fused_op_set, const AnfNodePtr &check_node, - std::set *cached_unconnected_set) { - if (!check_node->isa()) { + std::set *cached_unconnected_set, AnfNodePtr *circle_node) { + if (!check_node->isa() || !fused_op_set.count(check_node)) { return false; } @@ -209,6 +209,7 @@ bool CheckCircle(const std::set &fused_op_set, const AnfNodePtr &che done.insert(node); if (fused_op_set.count(node)) { has_circle = true; + *circle_node = node; break; } @@ -242,15 +243,16 @@ std::vector RemoveCircle(const std::vector &fused_op, bo return EXCLUDE; }; for (auto iter = fused_op.rbegin(); iter != fused_op.rend(); ++iter) { - bool has_circle = CheckCircle(fused_op_set, *iter, &cached_unconnected_set); + AnfNodePtr circle_node; + bool has_circle = CheckCircle(fused_op_set, *iter, &cached_unconnected_set, &circle_node); // delete the circle node and the node which depend on the circle node in fused op if (has_circle) { auto mng = (*iter)->func_graph()->manager(); std::vector erase_nodes; if (is_backward) { - erase_nodes = DeepUsersSearch(*iter, include, mng); + erase_nodes = DeepUsersSearch(circle_node, include, mng); } else { - erase_nodes = DeepLinkedGraphSearch(*iter, include); + erase_nodes = DeepLinkedGraphSearch(circle_node, include); } for (auto erase_node : erase_nodes) { fused_op_set.erase(erase_node);