|
|
|
@ -566,25 +566,26 @@ PDNode* patterns::FC(PDPattern* pattern, const std::string& name_scope,
|
|
|
|
|
return fc_out;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#define NEW_NODE(op__, arg__, io__) \
|
|
|
|
|
auto* arg__ = pattern->NewNode(name_scope, #arg__) \
|
|
|
|
|
->assert_is_op_##io__(#op__, #arg__);
|
|
|
|
|
|
|
|
|
|
PDNode* patterns::LSTM(PDPattern* pattern, const std::string& name_scope,
|
|
|
|
|
PDNode* x) {
|
|
|
|
|
x->assert_is_op_input("lstm", "Input");
|
|
|
|
|
auto* lstm_op = pattern->NewNode(name_scope, "lstm")->assert_is_op("lstm");
|
|
|
|
|
#define NEW_NODE(arg__, io__) \
|
|
|
|
|
auto* arg__ = pattern->NewNode(name_scope, #arg__) \
|
|
|
|
|
->assert_is_op_##io__("lstm", #arg__);
|
|
|
|
|
|
|
|
|
|
// Currently, the H0 and C0 are optional
|
|
|
|
|
// TODO(Superjomn) upgrade the fuse framework to support optional.
|
|
|
|
|
// NEW_NODE(H0, input);
|
|
|
|
|
// NEW_NODE(C0, input);
|
|
|
|
|
NEW_NODE(Weight, input);
|
|
|
|
|
NEW_NODE(Bias, input);
|
|
|
|
|
NEW_NODE(lstm, Weight, input);
|
|
|
|
|
NEW_NODE(lstm, Bias, input);
|
|
|
|
|
|
|
|
|
|
NEW_NODE(Hidden, output);
|
|
|
|
|
NEW_NODE(Cell, output);
|
|
|
|
|
NEW_NODE(BatchGate, output);
|
|
|
|
|
NEW_NODE(BatchCellPreAct, output);
|
|
|
|
|
NEW_NODE(lstm, Hidden, output);
|
|
|
|
|
NEW_NODE(lstm, Cell, output);
|
|
|
|
|
NEW_NODE(lstm, BatchGate, output);
|
|
|
|
|
NEW_NODE(lstm, BatchCellPreAct, output);
|
|
|
|
|
|
|
|
|
|
lstm_op->LinksFrom({x, Weight, Bias});
|
|
|
|
|
lstm_op->LinksTo({Hidden, Cell, BatchGate, BatchCellPreAct});
|
|
|
|
@ -595,26 +596,24 @@ PDNode* patterns::GRU(PDPattern* pattern, const std::string& name_scope,
|
|
|
|
|
PDNode* x) {
|
|
|
|
|
x->assert_is_op_input("gru", "Input");
|
|
|
|
|
auto* gru_op = pattern->NewNode(name_scope, "gru")->assert_is_op("gru");
|
|
|
|
|
#define NEW_NODE(arg__, io__) \
|
|
|
|
|
auto* arg__ = pattern->NewNode(name_scope, #arg__) \
|
|
|
|
|
->assert_is_op_##io__("gru", #arg__);
|
|
|
|
|
|
|
|
|
|
NEW_NODE(Weight, input);
|
|
|
|
|
NEW_NODE(gru, Weight, input);
|
|
|
|
|
// TODO(Superjomn): upgrade the fuse framework to support optional.
|
|
|
|
|
// H0 and bias are optional
|
|
|
|
|
NEW_NODE(Bias, input); // also optional
|
|
|
|
|
NEW_NODE(gru, Bias, input); // also optional
|
|
|
|
|
// NEW_NODE(H0, input);
|
|
|
|
|
|
|
|
|
|
NEW_NODE(Hidden, output);
|
|
|
|
|
NEW_NODE(gru, Hidden, output);
|
|
|
|
|
// below are intermediate
|
|
|
|
|
NEW_NODE(BatchGate, output);
|
|
|
|
|
NEW_NODE(BatchResetHiddenPrev, output);
|
|
|
|
|
NEW_NODE(BatchHidden, output);
|
|
|
|
|
NEW_NODE(gru, BatchGate, output);
|
|
|
|
|
NEW_NODE(gru, BatchResetHiddenPrev, output);
|
|
|
|
|
NEW_NODE(gru, BatchHidden, output);
|
|
|
|
|
|
|
|
|
|
gru_op->LinksFrom({x, Weight, Bias});
|
|
|
|
|
gru_op->LinksTo({Hidden, BatchGate, BatchResetHiddenPrev, BatchHidden});
|
|
|
|
|
return Hidden;
|
|
|
|
|
}
|
|
|
|
|
#undef NEW_NODE
|
|
|
|
|
|
|
|
|
|
} // namespace ir
|
|
|
|
|
} // namespace framework
|
|
|
|
|