|
|
|
@ -536,22 +536,21 @@ PDNode* patterns::FC(PDPattern* pattern, const std::string& name_scope,
|
|
|
|
|
mul_out_var = pattern->NewNode(name_scope, "mul_out")
|
|
|
|
|
->AsIntermediate()
|
|
|
|
|
->assert_is_only_output_of_op("mul")
|
|
|
|
|
->assert_is_op_input("elementwise_add", "X");
|
|
|
|
|
->assert_is_op_input("elementwise_add");
|
|
|
|
|
// bias
|
|
|
|
|
bias = pattern->NewNode(name_scope, "fc_bias")
|
|
|
|
|
->AsInput()
|
|
|
|
|
->assert_is_persistable_var()
|
|
|
|
|
->assert_is_op_input("elementwise_add", "Y");
|
|
|
|
|
->assert_is_op_input("elementwise_add");
|
|
|
|
|
// output
|
|
|
|
|
fc_out = pattern->NewNode(name_scope, "fc_out")
|
|
|
|
|
->AsOutput()
|
|
|
|
|
->assert_is_op_output("elementwise_add", "Out");
|
|
|
|
|
->assert_is_op_output("elementwise_add");
|
|
|
|
|
mul_op->LinksFrom({x, mul_weight_var}).LinksTo({mul_out_var});
|
|
|
|
|
elementwise_add_op->LinksFrom({mul_out_var, bias}).LinksTo({fc_out});
|
|
|
|
|
} else {
|
|
|
|
|
fc_out = pattern->NewNode(name_scope, "fc_out")
|
|
|
|
|
->AsOutput()
|
|
|
|
|
->assert_is_op_output("mul", "Out");
|
|
|
|
|
->assert_is_op_output("mul");
|
|
|
|
|
mul_op->LinksFrom({mul_weight_var, x}).LinksTo({fc_out});
|
|
|
|
|
}
|
|
|
|
|
return fc_out;
|
|
|
|
|