fix bnupdate_eltwise_eltwise's cycle fusion

pull/6681/head
huanghui 4 years ago
parent fe934520e6
commit 90dfecfb00

@ -24,6 +24,7 @@
#include "base/core_ops.h"
#include "utils/ms_context.h"
#include "backend/optimizer/common/fusion_id_allocator.h"
#include "backend/optimizer/common/helper.h"
namespace mindspore {
namespace opt {
@ -59,6 +60,10 @@ void BnupdateEltwiseEltwiseFusionPass::MatchBnupdateAddRelu(const CNodePtr &cnod
auto bnupdate = getitem->input(1);
MS_EXCEPTION_IF_NULL(bnupdate);
if (bnupdate->isa<CNode>() && AnfAlgo::GetCNodeName(bnupdate) == kBNTrainingUpdateOpName) {
if (cnode->size() == ELTWISE_DOUBLE_IN_INPUT_SIZE &&
IsDepend(kernel_graph, cnode->input(2), {relu_input, bnupdate})) {
return;
}
std::vector<int> output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0);
for (auto out_getitem : manager->node_users()[bnupdate]) {
MS_EXCEPTION_IF_NULL(out_getitem.first);

@ -97,11 +97,11 @@ bool QuitFusion(const FuncGraphPtr &graph, const AnfNodePtr &mul0_anf, const Anf
auto mul0 = mul0_anf->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(mul0);
if (IsDepend(graph, mul0->input(1), reduce_sum)) {
if (IsDepend(*graph, mul0->input(1), {reduce_sum})) {
MS_LOG(INFO) << "mul0->input(1) depends on reduce_sum, quit fusion";
return true;
}
if (IsDepend(graph, mul1->input(1), mul0)) {
if (IsDepend(*graph, mul1->input(1), {mul0})) {
MS_LOG(INFO) << "mul1->input(1) depends on mul0, quit fusion";
return true;
}

@ -39,11 +39,9 @@ std::vector<int> Convert2Int(const std::vector<size_t> &v) {
return result;
}
bool IsDepend(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node1);
MS_EXCEPTION_IF_NULL(node2);
std::vector<AnfNodePtr> node_list = TopoSort(graph->get_return());
bool IsDepend(const FuncGraph &graph, const AnfNodePtr &node, const std::vector<AnfNodePtr> &nodes) {
MS_EXCEPTION_IF_NULL(node);
std::vector<AnfNodePtr> node_list = TopoSort(graph.get_return());
std::map<AnfNodePtr, std::set<AnfNodePtr>> control_depend_map;
for (auto &nd : node_list) {
MS_EXCEPTION_IF_NULL(nd);
@ -60,29 +58,29 @@ bool IsDepend(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodeP
}
}
FuncGraphManagerPtr manager = graph->manager();
FuncGraphManagerPtr manager = graph.manager();
MS_EXCEPTION_IF_NULL(manager);
std::unordered_set<AnfNodePtr> seen_node;
std::deque<AnfNodePtr> todo{node1};
std::deque<AnfNodePtr> todo{node};
while (!todo.empty()) {
AnfNodePtr node = todo.front();
AnfNodePtr nd = todo.front();
todo.pop_front();
if (seen_node.count(node) > 0 || !manager->all_nodes().contains(node)) {
if (seen_node.count(nd) > 0 || !manager->all_nodes().contains(nd)) {
continue;
}
(void)seen_node.insert(node);
(void)seen_node.insert(nd);
if (node == node2) {
if (std::any_of(nodes.begin(), nodes.end(), [&nd](const AnfNodePtr &item) { return nd == item; })) {
return true;
}
if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
if (nd->isa<CNode>()) {
auto cnode = nd->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto inputs = cnode->inputs();
(void)todo.insert(todo.end(), inputs.begin(), inputs.end());
}
auto it = control_depend_map.find(node);
auto it = control_depend_map.find(nd);
if (it != control_depend_map.end()) {
(void)todo.insert(todo.end(), it->second.begin(), it->second.end());
}

@ -119,8 +119,8 @@ enum ConvBn1Output {
std::vector<int> Convert2Int(const std::vector<size_t> &v);
// check whether node1 depends on node2 or not
bool IsDepend(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2);
// check whether node depends on either of nodes or not
bool IsDepend(const FuncGraph &graph, const AnfNodePtr &node, const std::vector<AnfNodePtr> &nodes);
bool UnVisited(const BaseRef &n);

Loading…
Cancel
Save