|
|
|
@ -20,12 +20,8 @@ namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
|
namespace ir {
|
|
|
|
|
|
|
|
|
|
std::string GenNodeName(const std::string& prefix, const std::string& name) {
|
|
|
|
|
return prefix + "/" + name;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void BuildPattern(PDPattern* pattern, const std::string& name_scope,
|
|
|
|
|
bool with_fc_bias) {
|
|
|
|
|
static void BuildPattern(PDPattern* pattern, const std::string& name_scope,
|
|
|
|
|
bool with_fc_bias) {
|
|
|
|
|
PDNode* x = pattern->NewNode(name_scope, "x")
|
|
|
|
|
->assert_is_op_input("mul")
|
|
|
|
|
->assert_var_not_persistable();
|
|
|
|
@ -35,8 +31,8 @@ void BuildPattern(PDPattern* pattern, const std::string& name_scope,
|
|
|
|
|
VLOG(3) << "\n" << pattern->DotString();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
|
|
|
|
|
bool with_fc_bias) {
|
|
|
|
|
static int BuildFusion(Graph* graph, const std::string& name_scope,
|
|
|
|
|
Scope* scope, bool with_fc_bias) {
|
|
|
|
|
GraphPatternDetector gpd;
|
|
|
|
|
auto* pattern = gpd.mutable_pattern();
|
|
|
|
|
|
|
|
|
@ -108,7 +104,7 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
|
|
|
|
|
|
|
|
|
|
auto* op = graph->CreateOpNode(&op_desc);
|
|
|
|
|
PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
|
|
|
|
|
auto* scope = graph->Get<Scope*>(kParamScopeAttr);
|
|
|
|
|
// auto* scope = graph->Get<Scope*>(kParamScopeAttr);
|
|
|
|
|
|
|
|
|
|
IR_NODE_LINK_TO(x_n, op);
|
|
|
|
|
IR_NODE_LINK_TO(weight_x_n, op);
|
|
|
|
@ -189,5 +185,5 @@ std::unique_ptr<ir::Graph> FCGRUFusePass::ApplyImpl(
|
|
|
|
|
} // namespace framework
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
REGISTER_PASS(mul_lstm_fuse_pass, paddle::framework::ir::MulGRUFusePass);
|
|
|
|
|
REGISTER_PASS(fc_lstm_fuse_pass, paddle::framework::ir::FCGRUFusePass);
|
|
|
|
|
REGISTER_PASS(mul_gru_fuse_pass, paddle::framework::ir::MulGRUFusePass);
|
|
|
|
|
REGISTER_PASS(fc_gru_fuse_pass, paddle::framework::ir::FCGRUFusePass);
|
|
|
|
|