|
|
|
@ -519,50 +519,41 @@ bool VarLinksFromOp(Node* node, const std::string& op_type) {
|
|
|
|
|
|
|
|
|
|
PDNode* patterns::FC(PDPattern* pattern, const std::string& name_scope,
|
|
|
|
|
PDNode* x, bool with_bias) {
|
|
|
|
|
// Create Operators
|
|
|
|
|
PDNode* elementwise_add_op{nullptr};
|
|
|
|
|
// mul op
|
|
|
|
|
auto* mul_op = pattern->NewNode(name_scope, "mul")->assert_is_op("mul");
|
|
|
|
|
if (with_bias) {
|
|
|
|
|
elementwise_add_op = pattern->NewNode(name_scope, "elementwise_add")
|
|
|
|
|
->assert_is_op("elementwise_add");
|
|
|
|
|
}
|
|
|
|
|
// Create variables
|
|
|
|
|
// w
|
|
|
|
|
auto* mul_weight_var = pattern->NewNode(name_scope, "w")
|
|
|
|
|
->AsInput()
|
|
|
|
|
->assert_is_persistable_var()
|
|
|
|
|
->assert_is_op_nth_input("mul", "Y", 0);
|
|
|
|
|
PDNode* mul_out_var{nullptr};
|
|
|
|
|
->assert_is_op_input("mul", "Y");
|
|
|
|
|
|
|
|
|
|
PDNode* fc_out{nullptr};
|
|
|
|
|
if (with_bias) {
|
|
|
|
|
PDNode* elementwise_add_op{nullptr};
|
|
|
|
|
PDNode *mul_out_var{nullptr}, *bias{nullptr};
|
|
|
|
|
elementwise_add_op = pattern->NewNode(name_scope, "elementwise_add")
|
|
|
|
|
->assert_is_op("elementwise_add");
|
|
|
|
|
// intermediate variable, will be removed in the IR after fuse.
|
|
|
|
|
mul_out_var = pattern->NewNode(name_scope, "mul_out")
|
|
|
|
|
->AsIntermediate()
|
|
|
|
|
->assert_is_only_output_of_op("mul")
|
|
|
|
|
->assert_is_op_input("elementwise_add");
|
|
|
|
|
}
|
|
|
|
|
PDNode *bias{nullptr}, *fc_out{nullptr};
|
|
|
|
|
if (with_bias) {
|
|
|
|
|
->assert_is_op_input("elementwise_add", "X");
|
|
|
|
|
// bias
|
|
|
|
|
bias = pattern->NewNode(name_scope, "fc_bias")
|
|
|
|
|
->assert_is_op_input("elementwise_add")
|
|
|
|
|
->AsInput();
|
|
|
|
|
->AsInput()
|
|
|
|
|
->assert_is_persistable_var()
|
|
|
|
|
->assert_is_op_input("elementwise_add", "Y");
|
|
|
|
|
// output
|
|
|
|
|
fc_out = pattern->NewNode(name_scope, "fc_out")
|
|
|
|
|
->AsOutput()
|
|
|
|
|
->assert_is_op_output("elementwise_add");
|
|
|
|
|
->assert_is_op_output("elementwise_add", "Out");
|
|
|
|
|
mul_op->LinksFrom({x, mul_weight_var}).LinksTo({mul_out_var});
|
|
|
|
|
elementwise_add_op->LinksFrom({mul_out_var, bias}).LinksTo({fc_out});
|
|
|
|
|
} else {
|
|
|
|
|
fc_out = pattern->NewNode(name_scope, "fc_out")
|
|
|
|
|
->AsOutput()
|
|
|
|
|
->assert_is_op_output("mul");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (with_bias) {
|
|
|
|
|
mul_op->LinksFrom({mul_weight_var, x}).LinksTo({mul_out_var});
|
|
|
|
|
elementwise_add_op->LinksFrom({mul_out_var, bias}).LinksTo({fc_out});
|
|
|
|
|
} else {
|
|
|
|
|
->assert_is_op_output("mul", "Out");
|
|
|
|
|
mul_op->LinksFrom({mul_weight_var, x}).LinksTo({fc_out});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return fc_out;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -609,6 +600,10 @@ PDNode* patterns::GRU(PDPattern* pattern, const std::string& name_scope,
|
|
|
|
|
NEW_NODE(gru, BatchResetHiddenPrev, output);
|
|
|
|
|
NEW_NODE(gru, BatchHidden, output);
|
|
|
|
|
|
|
|
|
|
BatchGate->AsIntermediate();
|
|
|
|
|
BatchResetHiddenPrev->AsIntermediate();
|
|
|
|
|
BatchHidden->AsIntermediate();
|
|
|
|
|
|
|
|
|
|
gru_op->LinksFrom({x, Weight, Bias});
|
|
|
|
|
gru_op->LinksTo({Hidden, BatchGate, BatchResetHiddenPrev, BatchHidden});
|
|
|
|
|
return Hidden;
|
|
|
|
|