refine fc and gru pattern

fix-develop-build.sh
tensor-tang 7 years ago
parent 7eebb90523
commit c9bd2d50f1

@@ -519,50 +519,41 @@ bool VarLinksFromOp(Node* node, const std::string& op_type) {
 PDNode* patterns::FC(PDPattern* pattern, const std::string& name_scope,
                      PDNode* x, bool with_bias) {
-  // Create Operators
-  PDNode* elementwise_add_op{nullptr};
+  // mul op
   auto* mul_op = pattern->NewNode(name_scope, "mul")->assert_is_op("mul");
-  if (with_bias) {
-    elementwise_add_op = pattern->NewNode(name_scope, "elementwise_add")
-                             ->assert_is_op("elementwise_add");
-  }
-  // Create variables
+  // w
   auto* mul_weight_var = pattern->NewNode(name_scope, "w")
                              ->AsInput()
                              ->assert_is_persistable_var()
-                             ->assert_is_op_nth_input("mul", "Y", 0);
-  PDNode* mul_out_var{nullptr};
+                             ->assert_is_op_input("mul", "Y");
+  PDNode* fc_out{nullptr};
   if (with_bias) {
+    PDNode* elementwise_add_op{nullptr};
+    PDNode *mul_out_var{nullptr}, *bias{nullptr};
+    elementwise_add_op = pattern->NewNode(name_scope, "elementwise_add")
+                             ->assert_is_op("elementwise_add");
     // intermediate variable, will be removed in the IR after fuse.
     mul_out_var = pattern->NewNode(name_scope, "mul_out")
                       ->AsIntermediate()
                       ->assert_is_only_output_of_op("mul")
-                      ->assert_is_op_input("elementwise_add");
-  }
-  PDNode *bias{nullptr}, *fc_out{nullptr};
-  if (with_bias) {
+                      ->assert_is_op_input("elementwise_add", "X");
     // bias
     bias = pattern->NewNode(name_scope, "fc_bias")
-               ->assert_is_op_input("elementwise_add")
-               ->AsInput();
+               ->AsInput()
+               ->assert_is_persistable_var()
+               ->assert_is_op_input("elementwise_add", "Y");
     // output
     fc_out = pattern->NewNode(name_scope, "fc_out")
                  ->AsOutput()
-                 ->assert_is_op_output("elementwise_add");
+                 ->assert_is_op_output("elementwise_add", "Out");
+    mul_op->LinksFrom({x, mul_weight_var}).LinksTo({mul_out_var});
+    elementwise_add_op->LinksFrom({mul_out_var, bias}).LinksTo({fc_out});
   } else {
     fc_out = pattern->NewNode(name_scope, "fc_out")
                  ->AsOutput()
-                 ->assert_is_op_output("mul");
-  }
-  if (with_bias) {
-    mul_op->LinksFrom({mul_weight_var, x}).LinksTo({mul_out_var});
-    elementwise_add_op->LinksFrom({mul_out_var, bias}).LinksTo({fc_out});
-  } else {
+                 ->assert_is_op_output("mul", "Out");
     mul_op->LinksFrom({mul_weight_var, x}).LinksTo({fc_out});
   }
   return fc_out;
 }
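
The hunk above moves every bias-only node (elementwise_add_op, mul_out_var, bias) into the single if (with_bias) branch and tightens each assertion to a named slot ("Y", "X", "Out") in place of assert_is_op_nth_input and the unnamed forms. A minimal sketch of how a fuse pass might drive the pattern follows; the GraphPatternDetector wrapper, the "fc_fuse" scope, and the x-node setup are illustrative assumptions, not part of this commit:

  // Sketch only: assumed caller of the refined FC pattern.
  GraphPatternDetector gpd;
  auto* pattern = gpd.mutable_pattern();
  // Input activation: whatever feeds the "X" slot of mul.
  auto* x = pattern->NewNode("fc_fuse", "x")
                ->AsInput()
                ->assert_is_op_input("mul", "X");
  // Builds the mul(+elementwise_add) subgraph, returns its output node.
  PDNode* fc_out = patterns::FC(pattern, "fc_fuse", x, true /*with_bias*/);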
@@ -609,6 +600,10 @@ PDNode* patterns::GRU(PDPattern* pattern, const std::string& name_scope,
   NEW_NODE(gru, BatchResetHiddenPrev, output);
   NEW_NODE(gru, BatchHidden, output);
 
+  BatchGate->AsIntermediate();
+  BatchResetHiddenPrev->AsIntermediate();
+  BatchHidden->AsIntermediate();
+
   gru_op->LinksFrom({x, Weight, Bias});
   gru_op->LinksTo({Hidden, BatchGate, BatchResetHiddenPrev, BatchHidden});
   return Hidden;
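
Marking BatchGate, BatchResetHiddenPrev, and BatchHidden as intermediate declares that these gru workspace outputs are consumed only inside the matched subgraph, so a fuse pass may erase them and keep Hidden as the pattern's only external output. An assumed call site, mirroring the FC sketch above; the third PDNode* x parameter and the "gru_fuse" scope are inferred, since the signature is truncated in the hunk header:

  // Sketch only: assumed caller of the GRU pattern.
  GraphPatternDetector gpd;
  auto* pattern = gpd.mutable_pattern();
  // Input sequence: whatever feeds the "Input" slot of gru.
  auto* x = pattern->NewNode("gru_fuse", "x")
                ->AsInput()
                ->assert_is_op_input("gru", "Input");
  // Returns Hidden; the three Batch* outputs stay internal to the match.
  PDNode* hidden = patterns::GRU(pattern, "gru_fuse", x);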
