|
|
|
@ -45,16 +45,14 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const {
|
|
|
|
|
auto* scope = param_scope();
|
|
|
|
|
PADDLE_ENFORCE(scope);
|
|
|
|
|
|
|
|
|
|
std::string type = is_conv3d() ? "conv3d" : "conv2d";
|
|
|
|
|
|
|
|
|
|
GraphPatternDetector gpd;
|
|
|
|
|
auto* conv_input =
|
|
|
|
|
gpd.mutable_pattern()
|
|
|
|
|
->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
|
|
|
|
|
->AsInput()
|
|
|
|
|
->assert_is_op_input(type, "Input");
|
|
|
|
|
->assert_is_op_input(type(), "Input");
|
|
|
|
|
patterns::ConvBias conv_bias_pattern(gpd.mutable_pattern(), name_scope_);
|
|
|
|
|
conv_bias_pattern(conv_input, is_conv3d());
|
|
|
|
|
conv_bias_pattern(conv_input, type());
|
|
|
|
|
int found_conv_bias_count = 0;
|
|
|
|
|
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
|
|
|
|
|
Graph* g) {
|
|
|
|
@ -75,7 +73,7 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const {
|
|
|
|
|
// check if fuse can be done and if MKL-DNN should be used
|
|
|
|
|
FuseOptions fuse_option = FindFuseOption(*conv, *eltwise);
|
|
|
|
|
if (fuse_option == DO_NOT_FUSE || fuse_option == FUSE_NATIVE) {
|
|
|
|
|
VLOG(3) << "do not perform conv+bias fuse";
|
|
|
|
|
VLOG(3) << "do not perform " + type() + "+bias fuse";
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -110,7 +108,7 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const {
|
|
|
|
|
desc.SetInput("Filter", std::vector<std::string>({conv_weight->Name()}));
|
|
|
|
|
desc.SetInput("Bias", std::vector<std::string>({eltwise_bias->Name()}));
|
|
|
|
|
desc.SetOutput("Output", std::vector<std::string>({eltwise_out->Name()}));
|
|
|
|
|
desc.SetType(type);
|
|
|
|
|
desc.SetType(type());
|
|
|
|
|
|
|
|
|
|
for (auto& attr : conv->Op()->GetAttrMap()) {
|
|
|
|
|
desc.SetAttr(attr.first, attr.second);
|
|
|
|
@ -135,5 +133,7 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const {
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
REGISTER_PASS(conv_bias_mkldnn_fuse_pass,
|
|
|
|
|
paddle::framework::ir::ConvBiasFusePass);
|
|
|
|
|
REGISTER_PASS(conv_transpose_bias_mkldnn_fuse_pass,
|
|
|
|
|
paddle::framework::ir::Conv2DTransposeBiasFusePass);
|
|
|
|
|
REGISTER_PASS(conv3d_bias_mkldnn_fuse_pass,
|
|
|
|
|
paddle::framework::ir::Conv3DBiasFusePass);
|
|
|
|
|