|
|
|
@ -74,11 +74,8 @@ void DepthwiseConvEltwiseFusionPass::MatchSingleFusionPattern(const session::Ker
|
|
|
|
|
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
|
|
|
|
|
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) {
|
|
|
|
|
auto eltwise_input = cnode->input(1);
|
|
|
|
|
if (AnfAlgo::GetCNodeName(cnode) == kReluV2OpName || AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimRelu)) {
|
|
|
|
|
if (eltwise_input->isa<CNode>() &&
|
|
|
|
|
AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimDepthwiseConv2dNative)) {
|
|
|
|
|
MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, true);
|
|
|
|
|
}
|
|
|
|
|
if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimDepthwiseConv2dNative)) {
|
|
|
|
|
MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, true);
|
|
|
|
|
}
|
|
|
|
|
} else if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimDepthwiseConv2dNative->name()) {
|
|
|
|
|
MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, false);
|
|
|
|
|