|
|
|
|
@ -790,27 +790,31 @@ PDNode *patterns::ConvBN::operator()(paddle::framework::ir::PDNode *conv_input,
|
|
|
|
|
auto *bn_scale_var = pattern->NewNode(bn_scale_repr())
|
|
|
|
|
->AsInput()
|
|
|
|
|
->assert_is_persistable_var()
|
|
|
|
|
->assert_is_op_input("batch_norm", "Scale");
|
|
|
|
|
->assert_is_op_input("batch_norm", "Scale")
|
|
|
|
|
->assert_has_n_outputs(1);
|
|
|
|
|
// BN Bias
|
|
|
|
|
auto *bn_bias_var = pattern->NewNode(bn_bias_repr())
|
|
|
|
|
->AsInput()
|
|
|
|
|
->assert_is_persistable_var()
|
|
|
|
|
->assert_is_op_input("batch_norm", "Bias");
|
|
|
|
|
->assert_is_op_input("batch_norm", "Bias")
|
|
|
|
|
->assert_has_n_outputs(1);
|
|
|
|
|
// BN Mean
|
|
|
|
|
auto *bn_mean_var = pattern->NewNode(bn_mean_repr())
|
|
|
|
|
->AsInput()
|
|
|
|
|
->assert_is_persistable_var()
|
|
|
|
|
->assert_is_op_input("batch_norm", "Mean");
|
|
|
|
|
->assert_is_op_input("batch_norm", "Mean")
|
|
|
|
|
->assert_has_n_outputs(1);
|
|
|
|
|
// BN Variance
|
|
|
|
|
auto *bn_variance_var = pattern->NewNode(bn_variance_repr())
|
|
|
|
|
->AsInput()
|
|
|
|
|
->assert_is_persistable_var()
|
|
|
|
|
->assert_is_op_input("batch_norm", "Variance");
|
|
|
|
|
->assert_is_op_input("batch_norm", "Variance")
|
|
|
|
|
->assert_has_n_outputs(1);
|
|
|
|
|
|
|
|
|
|
// BN output
|
|
|
|
|
auto *bn_out_var = pattern->NewNode(bn_out_repr())
|
|
|
|
|
->AsOutput()
|
|
|
|
|
->assert_is_op_output("batch_norm");
|
|
|
|
|
->assert_is_op_output("batch_norm", "Y");
|
|
|
|
|
|
|
|
|
|
auto *bn_mean_out_var = pattern->NewNode(bn_mean_out_repr())
|
|
|
|
|
->AsOutput()
|
|
|
|
|
|