From 90dfecfb002fba8ad012a2ae2a3c192684727d08 Mon Sep 17 00:00:00 2001 From: huanghui Date: Thu, 24 Sep 2020 11:35:57 +0800 Subject: [PATCH] fix bnupdate_eltwise_eltwise's cycle fusion --- .../bnupdate_eltwise_eltwise_fusion_pass.cc | 5 ++++ .../ir_fusion/confusion_mul_grad_fusion.cc | 4 +-- .../ccsrc/backend/optimizer/common/helper.cc | 26 +++++++++---------- .../ccsrc/backend/optimizer/common/helper.h | 4 +-- 4 files changed, 21 insertions(+), 18 deletions(-) diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc index 715fed7d79..dd9462059a 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc @@ -24,6 +24,7 @@ #include "base/core_ops.h" #include "utils/ms_context.h" #include "backend/optimizer/common/fusion_id_allocator.h" +#include "backend/optimizer/common/helper.h" namespace mindspore { namespace opt { @@ -59,6 +60,10 @@ void BnupdateEltwiseEltwiseFusionPass::MatchBnupdateAddRelu(const CNodePtr &cnod auto bnupdate = getitem->input(1); MS_EXCEPTION_IF_NULL(bnupdate); if (bnupdate->isa() && AnfAlgo::GetCNodeName(bnupdate) == kBNTrainingUpdateOpName) { + if (cnode->size() == ELTWISE_DOUBLE_IN_INPUT_SIZE && + IsDepend(kernel_graph, cnode->input(2), {relu_input, bnupdate})) { + return; + } std::vector output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0); for (auto out_getitem : manager->node_users()[bnupdate]) { MS_EXCEPTION_IF_NULL(out_getitem.first); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/confusion_mul_grad_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/confusion_mul_grad_fusion.cc index 6ccf3e29bd..37243fbeeb 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/confusion_mul_grad_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/confusion_mul_grad_fusion.cc @@ -97,11 +97,11 @@ bool QuitFusion(const FuncGraphPtr &graph, const AnfNodePtr &mul0_anf, const Anf auto mul0 = mul0_anf->cast(); MS_EXCEPTION_IF_NULL(mul0); - if (IsDepend(graph, mul0->input(1), reduce_sum)) { + if (IsDepend(*graph, mul0->input(1), {reduce_sum})) { MS_LOG(INFO) << "mul0->input(1) depends on reduce_sum, quit fusion"; return true; } - if (IsDepend(graph, mul1->input(1), mul0)) { + if (IsDepend(*graph, mul1->input(1), {mul0})) { MS_LOG(INFO) << "mul1->input(1) depends on mul0, quit fusion"; return true; } diff --git a/mindspore/ccsrc/backend/optimizer/common/helper.cc b/mindspore/ccsrc/backend/optimizer/common/helper.cc index 8a7df6e9e3..f1b02728ad 100644 --- a/mindspore/ccsrc/backend/optimizer/common/helper.cc +++ b/mindspore/ccsrc/backend/optimizer/common/helper.cc @@ -39,11 +39,9 @@ std::vector Convert2Int(const std::vector &v) { return result; } -bool IsDepend(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node1); - MS_EXCEPTION_IF_NULL(node2); - std::vector node_list = TopoSort(graph->get_return()); +bool IsDepend(const FuncGraph &graph, const AnfNodePtr &node, const std::vector &nodes) { + MS_EXCEPTION_IF_NULL(node); + std::vector node_list = TopoSort(graph.get_return()); std::map> control_depend_map; for (auto &nd : node_list) { MS_EXCEPTION_IF_NULL(nd); @@ -60,29 +58,29 @@ bool IsDepend(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodeP } } - FuncGraphManagerPtr manager = graph->manager(); + FuncGraphManagerPtr manager = graph.manager(); MS_EXCEPTION_IF_NULL(manager); std::unordered_set seen_node; - std::deque todo{node1}; + std::deque todo{node}; while (!todo.empty()) { - AnfNodePtr node = todo.front(); + AnfNodePtr nd = todo.front(); todo.pop_front(); - if (seen_node.count(node) > 0 || !manager->all_nodes().contains(node)) { + if (seen_node.count(nd) > 0 || !manager->all_nodes().contains(nd)) { continue; } - (void)seen_node.insert(node); + (void)seen_node.insert(nd); - if (node == node2) { + if (std::any_of(nodes.begin(), nodes.end(), [&nd](const AnfNodePtr &item) { return nd == item; })) { return true; } - if (node->isa()) { - auto cnode = node->cast(); + if (nd->isa()) { + auto cnode = nd->cast(); MS_EXCEPTION_IF_NULL(cnode); auto inputs = cnode->inputs(); (void)todo.insert(todo.end(), inputs.begin(), inputs.end()); } - auto it = control_depend_map.find(node); + auto it = control_depend_map.find(nd); if (it != control_depend_map.end()) { (void)todo.insert(todo.end(), it->second.begin(), it->second.end()); } diff --git a/mindspore/ccsrc/backend/optimizer/common/helper.h b/mindspore/ccsrc/backend/optimizer/common/helper.h index 7865050a35..bd008f68bd 100644 --- a/mindspore/ccsrc/backend/optimizer/common/helper.h +++ b/mindspore/ccsrc/backend/optimizer/common/helper.h @@ -119,8 +119,8 @@ enum ConvBn1Output { std::vector Convert2Int(const std::vector &v); -// check whether node1 depends on node2 or not -bool IsDepend(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2); +// check whether node depends on either of nodes or not +bool IsDepend(const FuncGraph &graph, const AnfNodePtr &node, const std::vector &nodes); bool UnVisited(const BaseRef &n);