|
|
@ -42,16 +42,28 @@ void Conv2DBackpropEltwiseEltwiseFusionPass::MatchConv2DBackpropInputEltwiseEltw
|
|
|
|
}
|
|
|
|
}
|
|
|
|
auto input_cnode = eltwise_input->cast<CNodePtr>();
|
|
|
|
auto input_cnode = eltwise_input->cast<CNodePtr>();
|
|
|
|
MS_EXCEPTION_IF_NULL(input_cnode);
|
|
|
|
MS_EXCEPTION_IF_NULL(input_cnode);
|
|
|
|
auto double_in_eltwise_input = input_cnode->input(1);
|
|
|
|
auto double_in_eltwise_input = input_cnode->input(2);
|
|
|
|
MS_EXCEPTION_IF_NULL(double_in_eltwise_input);
|
|
|
|
MS_EXCEPTION_IF_NULL(double_in_eltwise_input);
|
|
|
|
if (!double_in_eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(double_in_eltwise_input) ||
|
|
|
|
if (!double_in_eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(double_in_eltwise_input)) {
|
|
|
|
fusion_id_allocator->HasFusionIdAttr(double_in_eltwise_input)) {
|
|
|
|
|
|
|
|
return;
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (AnfAlgo::CheckPrimitiveType(double_in_eltwise_input, prim::kPrimConv2DBackpropInput)) {
|
|
|
|
if (AnfAlgo::CheckPrimitiveType(double_in_eltwise_input, prim::kPrimConv2DBackpropInput) &&
|
|
|
|
|
|
|
|
!fusion_id_allocator->HasFusionIdAttr(double_in_eltwise_input)) {
|
|
|
|
(void)record.insert(double_in_eltwise_input);
|
|
|
|
(void)record.insert(double_in_eltwise_input);
|
|
|
|
candidate_fusion->push_back(record);
|
|
|
|
candidate_fusion->push_back(record);
|
|
|
|
SetRecordFusionId(record);
|
|
|
|
SetRecordFusionId(record);
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
auto double_in_eltwise_input_1 = input_cnode->input(1);
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(double_in_eltwise_input_1);
|
|
|
|
|
|
|
|
if (!double_in_eltwise_input_1->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(double_in_eltwise_input_1)) {
|
|
|
|
|
|
|
|
return;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (AnfAlgo::CheckPrimitiveType(double_in_eltwise_input_1, prim::kPrimConv2DBackpropInput) &&
|
|
|
|
|
|
|
|
!fusion_id_allocator->HasFusionIdAttr(double_in_eltwise_input_1)) {
|
|
|
|
|
|
|
|
(void)record.insert(double_in_eltwise_input_1);
|
|
|
|
|
|
|
|
candidate_fusion->push_back(record);
|
|
|
|
|
|
|
|
SetRecordFusionId(record);
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|