From dc95c63c03942fcbf204709efbd922b68105e7f0 Mon Sep 17 00:00:00 2001 From: lingyunli63 Date: Mon, 9 Nov 2020 21:12:00 +0800 Subject: [PATCH] remove multiple circles --- .../graph_kernel/composite_ops_fusion.cc | 51 ++++++++++++++----- 1 file changed, 39 insertions(+), 12 deletions(-) diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/composite_ops_fusion.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/composite_ops_fusion.cc index 60771d759b..27d5cd0ead 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/composite_ops_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/composite_ops_fusion.cc @@ -39,6 +39,27 @@ namespace mindspore { namespace opt { +namespace { +std::vector DeepLinkedGraphSearch(const std::vector &roots, const IncludeFunc &include) { + std::vector inputs; + for (auto &root : roots) { + auto tmp = DeepLinkedGraphSearch(root, include); + inputs.insert(inputs.end(), tmp.begin(), tmp.end()); + } + return inputs; +} + +std::vector DeepUsersSearch(const std::vector &roots, const IncludeFunc &include, + const FuncGraphManagerPtr &mng) { + std::vector users; + for (auto &root : roots) { + auto tmp = DeepUsersSearch(root, include, mng); + users.insert(users.end(), tmp.begin(), tmp.end()); + } + return users; +} +} // namespace + bool IsBasicFuseOp(const AnfNodePtr &node, bool is_before_kernel_select) { #if ENABLE_D std::vector basic_ops = { @@ -186,31 +207,33 @@ IncludeType IncludeFusedBasicOpBackward(const AnfNodePtr &cur_node, GraphKernelI } bool CheckCircle(const std::set &fused_op_set, const AnfNodePtr &check_node, - std::set *cached_unconnected_set, AnfNodePtr *circle_node) { + std::set *cached_unconnected_set, std::vector *circle_nodes) { if (!check_node->isa() || !fused_op_set.count(check_node)) { return false; } + circle_nodes->clear(); + std::set cached_done_set; auto cnode = check_node->cast(); const auto &inputs = cnode->inputs(); // there is a input not in fused_op_set, but the input depends on the fused_op_set - bool has_circle = false; for (auto input : inputs) { if (input->isa() && !fused_op_set.count(input)) { + bool has_circle = false; std::set done; std::vector todos = {input}; while (!todos.empty()) { auto node = todos.back(); todos.pop_back(); - if (done.count(node) || cached_unconnected_set->count(node)) { + if (done.count(node) || cached_unconnected_set->count(node) || cached_done_set.count(node)) { continue; } done.insert(node); if (fused_op_set.count(node)) { has_circle = true; - *circle_node = node; - break; + circle_nodes->push_back(node); + continue; } if (node->isa()) { @@ -224,13 +247,15 @@ bool CheckCircle(const std::set &fused_op_set, const AnfNodePtr &che } if (has_circle) { - return true; + cached_done_set.insert(done.begin(), done.end()); + } else { + cached_unconnected_set->insert(done.begin(), done.end()); } - cached_unconnected_set->insert(done.begin(), done.end()); + done.clear(); } } - return false; + return !circle_nodes->empty(); } std::vector RemoveCircle(const std::vector &fused_op, bool is_backward) { @@ -242,17 +267,19 @@ std::vector RemoveCircle(const std::vector &fused_op, bo } return EXCLUDE; }; + + std::vector circle_nodes; for (auto iter = fused_op.rbegin(); iter != fused_op.rend(); ++iter) { - AnfNodePtr circle_node; - bool has_circle = CheckCircle(fused_op_set, *iter, &cached_unconnected_set, &circle_node); + circle_nodes.clear(); + bool has_circle = CheckCircle(fused_op_set, *iter, &cached_unconnected_set, &circle_nodes); // delete the circle node and the node which depend on the circle node in fused op if (has_circle) { auto mng = (*iter)->func_graph()->manager(); std::vector erase_nodes; if (is_backward) { - erase_nodes = DeepUsersSearch(circle_node, include, mng); + erase_nodes = DeepUsersSearch(circle_nodes, include, mng); } else { - erase_nodes = DeepLinkedGraphSearch(circle_node, include); + erase_nodes = DeepLinkedGraphSearch(circle_nodes, include); } for (auto erase_node : erase_nodes) { fused_op_set.erase(erase_node);