|
|
|
@ -74,38 +74,25 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
|
|
|
|
|
op_desc.SetInput("Bias", {new_bias_var});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Create temp variables.
|
|
|
|
|
const std::string BatchedInput = patterns::UniqueKey("BatchedInput");
|
|
|
|
|
const std::string BatchedCellPreAct =
|
|
|
|
|
patterns::UniqueKey("BatchedCellPreAct");
|
|
|
|
|
const std::string BatchedGate = patterns::UniqueKey("BatchedGate");
|
|
|
|
|
const std::string CheckedCell = patterns::UniqueKey("CheckedCell");
|
|
|
|
|
|
|
|
|
|
scope->Var(BatchedInput)->GetMutable<framework::LoDTensor>();
|
|
|
|
|
scope->Var(BatchedCellPreAct)->GetMutable<framework::LoDTensor>();
|
|
|
|
|
scope->Var(BatchedGate)->GetMutable<framework::LoDTensor>();
|
|
|
|
|
scope->Var(CheckedCell)->GetMutable<framework::LoDTensor>();
|
|
|
|
|
|
|
|
|
|
op_desc.SetInput("H0", {});
|
|
|
|
|
op_desc.SetInput("C0", {});
|
|
|
|
|
op_desc.SetOutput("Hidden", {hidden->Name()});
|
|
|
|
|
op_desc.SetOutput("Cell", {cell->Name()});
|
|
|
|
|
op_desc.SetOutput("XX", {xx->Name()});
|
|
|
|
|
op_desc.SetOutput("BatchedGate", {BatchedGate});
|
|
|
|
|
op_desc.SetOutput("BatchCellPreAct", {BatchedCellPreAct});
|
|
|
|
|
op_desc.SetOutput("BatchedInput", {BatchedInput});
|
|
|
|
|
op_desc.SetOutput("CheckedCell", {CheckedCell});
|
|
|
|
|
op_desc.SetAttr("is_reverse", lstm->Op()->GetAttr("is_reverse"));
|
|
|
|
|
op_desc.SetAttr("use_peepholes", lstm->Op()->GetAttr("use_peepholes"));
|
|
|
|
|
// TODO(TJ): get from attr
|
|
|
|
|
op_desc.SetAttr("use_seq", true);
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
|
|
|
|
|
auto& scope = graph->Get<Scope>(kParamScopeAttr);
|
|
|
|
|
// Create temp variables.
|
|
|
|
|
#define OP_SET_OUT(x) \
|
|
|
|
|
const std::string x = patterns::UniqueKey(#x); \
|
|
|
|
|
op_desc.SetOutput(#x, {x}); \
|
|
|
|
|
scope.Var(x)->GetMutable<LoDTensor>()
|
|
|
|
|
op_desc.SetOutput(#x, {x});
|
|
|
|
|
|
|
|
|
|
OP_SET_OUT(BatchedGate);
|
|
|
|
|
OP_SET_OUT(BatchedCellPreAct);
|
|
|
|
|
OP_SET_OUT(BatchedInput);
|
|
|
|
|
OP_SET_OUT(CheckedCell);
|
|
|
|
|
OP_SET_OUT(BatchedCell);
|
|
|
|
|
OP_SET_OUT(BatchedHidden);
|
|
|
|
|
OP_SET_OUT(ReorderedH0);
|
|
|
|
@ -113,11 +100,29 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
|
|
|
|
|
#undef OP_SET_OUT
|
|
|
|
|
|
|
|
|
|
auto* op = graph->CreateOpNode(&op_desc);
|
|
|
|
|
|
|
|
|
|
IR_NODE_LINK_TO(input, op);
|
|
|
|
|
IR_NODE_LINK_TO(weight_x, op);
|
|
|
|
|
IR_NODE_LINK_TO(weight_h, op);
|
|
|
|
|
IR_NODE_LINK_TO(bias, op);
|
|
|
|
|
IR_NODE_LINK_TO(op, hidden);
|
|
|
|
|
|
|
|
|
|
#define IR_NODE(x) \
|
|
|
|
|
VarDesc key_##x(x); \
|
|
|
|
|
key_##x.SetPersistable(false); \
|
|
|
|
|
auto* node_##x = graph->CreateVarNode(&key_##x); \
|
|
|
|
|
IR_NODE_LINK_TO(op, node_##x);
|
|
|
|
|
|
|
|
|
|
IR_NODE(BatchedGate);
|
|
|
|
|
IR_NODE(BatchedCellPreAct);
|
|
|
|
|
IR_NODE(BatchedInput);
|
|
|
|
|
IR_NODE(CheckedCell);
|
|
|
|
|
IR_NODE(BatchedCell);
|
|
|
|
|
IR_NODE(BatchedHidden);
|
|
|
|
|
IR_NODE(ReorderedH0);
|
|
|
|
|
IR_NODE(ReorderedC0);
|
|
|
|
|
#undef IR_NODE
|
|
|
|
|
|
|
|
|
|
return op;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|