|
|
|
@ -73,13 +73,16 @@ AnfNodePtr GetMul0(const FuncGraphPtr &graph, const AnfNodePtr &input2, const An
|
|
|
|
|
return mul0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool QuitFusion(const FuncGraphPtr &graph, const AnfNodePtr &mul0_anf, const AnfNodePtr &reduce_sum) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
bool QuitFusion(const FuncGraphPtr &graph, const AnfNodePtr &mul0_anf, const AnfNodePtr &mul1_anf,
|
|
|
|
|
const AnfNodePtr &reduce_sum) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(mul0_anf);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(mul1_anf);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(reduce_sum);
|
|
|
|
|
if (!mul0_anf->isa<CNode>()) {
|
|
|
|
|
if (!mul0_anf->isa<CNode>() || !mul1_anf->isa<CNode>()) {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
auto mul1 = mul1_anf->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(mul1);
|
|
|
|
|
auto mul0 = mul0_anf->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(mul0);
|
|
|
|
|
|
|
|
|
@ -88,20 +91,14 @@ bool QuitFusion(const FuncGraphPtr &graph, const AnfNodePtr &mul0_anf, const Anf
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto manager = graph->manager();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(manager);
|
|
|
|
|
if (manager->node_users().find(reduce_sum) == manager->node_users().end()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "node has no output in manager";
|
|
|
|
|
if (IsDepend(graph, mul0->input(1), reduce_sum)) {
|
|
|
|
|
MS_LOG(INFO) << "mul0->input(1) depends on reduce_sum, quit fusion";
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
const AnfNodeIndexSet &outputs_set = manager->node_users()[reduce_sum];
|
|
|
|
|
auto it = std::find_if(outputs_set.begin(), outputs_set.end(), [&mul0](const std::pair<AnfNodePtr, int> &node_index) {
|
|
|
|
|
return node_index.first == mul0->input(1) || node_index.first == mul0;
|
|
|
|
|
});
|
|
|
|
|
if (it != outputs_set.end()) {
|
|
|
|
|
MS_LOG(INFO) << "ReduceSum's output node is mul0's input or mul0! If do fusion, graph will exist a circle";
|
|
|
|
|
if (IsDepend(graph, mul1->input(1), mul0)) {
|
|
|
|
|
MS_LOG(INFO) << "mul1->input(1) depends on mul0, quit fusion";
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
@ -131,7 +128,7 @@ const AnfNodePtr ConfusionMulGradFusion::Process(const FuncGraphPtr &graph, cons
|
|
|
|
|
MS_LOG(INFO) << "Mul0 do not exist, quit fusion";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
if (QuitFusion(graph, mul0, node)) {
|
|
|
|
|
if (QuitFusion(graph, mul0, mul1, node)) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|