|
|
@ -46,14 +46,16 @@ std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
|
|
|
|
auto* scope = param_scope();
|
|
|
|
auto* scope = param_scope();
|
|
|
|
PADDLE_ENFORCE(scope);
|
|
|
|
PADDLE_ENFORCE(scope);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::string type = is_conv3d() ? "conv3d" : "conv2d";
|
|
|
|
|
|
|
|
|
|
|
|
GraphPatternDetector gpd;
|
|
|
|
GraphPatternDetector gpd;
|
|
|
|
auto* conv_input =
|
|
|
|
auto* conv_input =
|
|
|
|
gpd.mutable_pattern()
|
|
|
|
gpd.mutable_pattern()
|
|
|
|
->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
|
|
|
|
->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
|
|
|
|
->AsInput()
|
|
|
|
->AsInput()
|
|
|
|
->assert_is_op_input("conv2d", "Input");
|
|
|
|
->assert_is_op_input(type, "Input");
|
|
|
|
patterns::ConvBias conv_bias_pattern(gpd.mutable_pattern(), name_scope_);
|
|
|
|
patterns::ConvBias conv_bias_pattern(gpd.mutable_pattern(), name_scope_);
|
|
|
|
conv_bias_pattern(conv_input);
|
|
|
|
conv_bias_pattern(conv_input, is_conv3d());
|
|
|
|
int found_conv_bias_count = 0;
|
|
|
|
int found_conv_bias_count = 0;
|
|
|
|
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
|
|
|
|
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
|
|
|
|
Graph* g) {
|
|
|
|
Graph* g) {
|
|
|
@ -109,7 +111,7 @@ std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
|
|
|
|
desc.SetInput("Filter", std::vector<std::string>({conv_weight->Name()}));
|
|
|
|
desc.SetInput("Filter", std::vector<std::string>({conv_weight->Name()}));
|
|
|
|
desc.SetInput("Bias", std::vector<std::string>({eltwise_bias->Name()}));
|
|
|
|
desc.SetInput("Bias", std::vector<std::string>({eltwise_bias->Name()}));
|
|
|
|
desc.SetOutput("Output", std::vector<std::string>({eltwise_out->Name()}));
|
|
|
|
desc.SetOutput("Output", std::vector<std::string>({eltwise_out->Name()}));
|
|
|
|
desc.SetType("conv2d");
|
|
|
|
desc.SetType(type);
|
|
|
|
|
|
|
|
|
|
|
|
for (auto& attr : conv->Op()->GetAttrMap()) {
|
|
|
|
for (auto& attr : conv->Op()->GetAttrMap()) {
|
|
|
|
desc.SetAttr(attr.first, attr.second);
|
|
|
|
desc.SetAttr(attr.first, attr.second);
|
|
|
@ -135,3 +137,5 @@ std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
|
|
|
|
} // namespace paddle
|
|
|
|
} // namespace paddle
|
|
|
|
REGISTER_PASS(conv_bias_mkldnn_fuse_pass,
|
|
|
|
REGISTER_PASS(conv_bias_mkldnn_fuse_pass,
|
|
|
|
paddle::framework::ir::ConvBiasFusePass);
|
|
|
|
paddle::framework::ir::ConvBiasFusePass);
|
|
|
|
|
|
|
|
REGISTER_PASS(conv3d_bias_mkldnn_fuse_pass,
|
|
|
|
|
|
|
|
paddle::framework::ir::Conv3DBiasFusePass);
|
|
|
|