|
|
|
@ -545,6 +545,39 @@ void BufferFusion::MatchBnupdateAddRelu(const CNodePtr &cnode, const AnfNodePtr
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void BufferFusion::MatchDepthwiseConvRelu(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
|
|
|
|
|
FusedNodeRecord *candidate_fusion, bool is_order) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(candidate_fusion);
|
|
|
|
|
auto manager = kernel_graph.manager();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(manager);
|
|
|
|
|
|
|
|
|
|
if (is_order) {
|
|
|
|
|
// DepthwiseConvolution--->Elemwise
|
|
|
|
|
auto depthwise_conv = cnode->input(1);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(depthwise_conv);
|
|
|
|
|
if (cnode->isa<CNode>() && AnfAlgo::GetCNodeName(depthwise_conv) == prim::kPrimDepthwiseConv2dNative->name()) {
|
|
|
|
|
std::vector<int> output_used_num{SizeToInt(manager->node_users()[depthwise_conv].size())};
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), depthwise_conv);
|
|
|
|
|
std::unordered_set<AnfNodePtr> record{cnode, depthwise_conv};
|
|
|
|
|
candidate_fusion->push_back(record);
|
|
|
|
|
SetRecordFusionId(record);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
// Elemwise-->DepthwiseConvolution
|
|
|
|
|
auto relu = cnode->input(1);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(relu);
|
|
|
|
|
if (cnode->isa<CNode>() && AnfAlgo::GetCNodeName(relu) == prim::kPrimRelu->name() ||
|
|
|
|
|
AnfAlgo::GetCNodeName() == kReluV2OpName) {
|
|
|
|
|
std::vector<int> output_used_num{SizeToInt(manager->node_users()[relu].size())};
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), relu);
|
|
|
|
|
std::unordered_set<AnfNodePtr> record{cnode, relu};
|
|
|
|
|
candidate_fusion->push_back(record);
|
|
|
|
|
SetRecordFusionId(record);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void BufferFusion::MatchOpNamePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(candidate_fusion);
|
|
|
|
|
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
|
|
|
|
@ -563,7 +596,11 @@ void BufferFusion::MatchOpNamePattern(const session::KernelGraph &kernel_graph,
|
|
|
|
|
MatchBnupdateAddRelu(cnode, relu_input, kernel_graph, candidate_fusion);
|
|
|
|
|
} else if (relu_input->isa<CNode>() && AnfAlgo::GetCNodeName(relu_input) == prim::kPrimTupleGetItem->name()) {
|
|
|
|
|
MatchBnupdateRelu(cnode, relu_input, kernel_graph, candidate_fusion);
|
|
|
|
|
} else if (relu_input->isa<CNode>() && AnfAlgo::GetCNodeName(relu_input) == prim::kPrimDepthwiseConv2dNative->name()) {
|
|
|
|
|
MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, true);
|
|
|
|
|
}
|
|
|
|
|
} else if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimDepthwiseConv2dNative->name()) {
|
|
|
|
|
MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, false);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|