updated conv bn fuse pass to make it compatible with latest batch_norm op (#31272)

test_model_benchmark_ci
alncat 5 years ago committed by GitHub
parent a37658daff
commit bfb8a64234
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -790,27 +790,31 @@ PDNode *patterns::ConvBN::operator()(paddle::framework::ir::PDNode *conv_input,
auto *bn_scale_var = pattern->NewNode(bn_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("batch_norm", "Scale");
->assert_is_op_input("batch_norm", "Scale")
->assert_has_n_outputs(1);
// BN Bias
auto *bn_bias_var = pattern->NewNode(bn_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("batch_norm", "Bias");
->assert_is_op_input("batch_norm", "Bias")
->assert_has_n_outputs(1);
// BN Mean
auto *bn_mean_var = pattern->NewNode(bn_mean_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("batch_norm", "Mean");
->assert_is_op_input("batch_norm", "Mean")
->assert_has_n_outputs(1);
// BN Variance
auto *bn_variance_var = pattern->NewNode(bn_variance_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("batch_norm", "Variance");
->assert_is_op_input("batch_norm", "Variance")
->assert_has_n_outputs(1);
// BN output
auto *bn_out_var = pattern->NewNode(bn_out_repr())
->AsOutput()
->assert_is_op_output("batch_norm");
->assert_is_op_output("batch_norm", "Y");
auto *bn_mean_out_var = pattern->NewNode(bn_mean_out_repr())
->AsOutput()

Loading…
Cancel
Save