fix bnupdate_eltwise_eltwise's cycle fusion

pull/6681/head
huanghui 4 years ago
parent fe934520e6
commit 90dfecfb00

@ -24,6 +24,7 @@
#include "base/core_ops.h" #include "base/core_ops.h"
#include "utils/ms_context.h" #include "utils/ms_context.h"
#include "backend/optimizer/common/fusion_id_allocator.h" #include "backend/optimizer/common/fusion_id_allocator.h"
#include "backend/optimizer/common/helper.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
@ -59,6 +60,10 @@ void BnupdateEltwiseEltwiseFusionPass::MatchBnupdateAddRelu(const CNodePtr &cnod
auto bnupdate = getitem->input(1); auto bnupdate = getitem->input(1);
MS_EXCEPTION_IF_NULL(bnupdate); MS_EXCEPTION_IF_NULL(bnupdate);
if (bnupdate->isa<CNode>() && AnfAlgo::GetCNodeName(bnupdate) == kBNTrainingUpdateOpName) { if (bnupdate->isa<CNode>() && AnfAlgo::GetCNodeName(bnupdate) == kBNTrainingUpdateOpName) {
if (cnode->size() == ELTWISE_DOUBLE_IN_INPUT_SIZE &&
IsDepend(kernel_graph, cnode->input(2), {relu_input, bnupdate})) {
return;
}
std::vector<int> output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0); std::vector<int> output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0);
for (auto out_getitem : manager->node_users()[bnupdate]) { for (auto out_getitem : manager->node_users()[bnupdate]) {
MS_EXCEPTION_IF_NULL(out_getitem.first); MS_EXCEPTION_IF_NULL(out_getitem.first);

@ -97,11 +97,11 @@ bool QuitFusion(const FuncGraphPtr &graph, const AnfNodePtr &mul0_anf, const Anf
auto mul0 = mul0_anf->cast<CNodePtr>(); auto mul0 = mul0_anf->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(mul0); MS_EXCEPTION_IF_NULL(mul0);
if (IsDepend(graph, mul0->input(1), reduce_sum)) { if (IsDepend(*graph, mul0->input(1), {reduce_sum})) {
MS_LOG(INFO) << "mul0->input(1) depends on reduce_sum, quit fusion"; MS_LOG(INFO) << "mul0->input(1) depends on reduce_sum, quit fusion";
return true; return true;
} }
if (IsDepend(graph, mul1->input(1), mul0)) { if (IsDepend(*graph, mul1->input(1), {mul0})) {
MS_LOG(INFO) << "mul1->input(1) depends on mul0, quit fusion"; MS_LOG(INFO) << "mul1->input(1) depends on mul0, quit fusion";
return true; return true;
} }

@ -39,11 +39,9 @@ std::vector<int> Convert2Int(const std::vector<size_t> &v) {
return result; return result;
} }
bool IsDepend(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) { bool IsDepend(const FuncGraph &graph, const AnfNodePtr &node, const std::vector<AnfNodePtr> &nodes) {
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(node1); std::vector<AnfNodePtr> node_list = TopoSort(graph.get_return());
MS_EXCEPTION_IF_NULL(node2);
std::vector<AnfNodePtr> node_list = TopoSort(graph->get_return());
std::map<AnfNodePtr, std::set<AnfNodePtr>> control_depend_map; std::map<AnfNodePtr, std::set<AnfNodePtr>> control_depend_map;
for (auto &nd : node_list) { for (auto &nd : node_list) {
MS_EXCEPTION_IF_NULL(nd); MS_EXCEPTION_IF_NULL(nd);
@ -60,29 +58,29 @@ bool IsDepend(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodeP
} }
} }
FuncGraphManagerPtr manager = graph->manager(); FuncGraphManagerPtr manager = graph.manager();
MS_EXCEPTION_IF_NULL(manager); MS_EXCEPTION_IF_NULL(manager);
std::unordered_set<AnfNodePtr> seen_node; std::unordered_set<AnfNodePtr> seen_node;
std::deque<AnfNodePtr> todo{node1}; std::deque<AnfNodePtr> todo{node};
while (!todo.empty()) { while (!todo.empty()) {
AnfNodePtr node = todo.front(); AnfNodePtr nd = todo.front();
todo.pop_front(); todo.pop_front();
if (seen_node.count(node) > 0 || !manager->all_nodes().contains(node)) { if (seen_node.count(nd) > 0 || !manager->all_nodes().contains(nd)) {
continue; continue;
} }
(void)seen_node.insert(node); (void)seen_node.insert(nd);
if (node == node2) { if (std::any_of(nodes.begin(), nodes.end(), [&nd](const AnfNodePtr &item) { return nd == item; })) {
return true; return true;
} }
if (node->isa<CNode>()) { if (nd->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>(); auto cnode = nd->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
auto inputs = cnode->inputs(); auto inputs = cnode->inputs();
(void)todo.insert(todo.end(), inputs.begin(), inputs.end()); (void)todo.insert(todo.end(), inputs.begin(), inputs.end());
} }
auto it = control_depend_map.find(node); auto it = control_depend_map.find(nd);
if (it != control_depend_map.end()) { if (it != control_depend_map.end()) {
(void)todo.insert(todo.end(), it->second.begin(), it->second.end()); (void)todo.insert(todo.end(), it->second.begin(), it->second.end());
} }

@ -119,8 +119,8 @@ enum ConvBn1Output {
std::vector<int> Convert2Int(const std::vector<size_t> &v); std::vector<int> Convert2Int(const std::vector<size_t> &v);
// check whether node1 depends on node2 or not // check whether node depends on either of nodes or not
bool IsDepend(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2); bool IsDepend(const FuncGraph &graph, const AnfNodePtr &node, const std::vector<AnfNodePtr> &nodes);
bool UnVisited(const BaseRef &n); bool UnVisited(const BaseRef &n);

Loading…
Cancel
Save