@@ -136,22 +136,22 @@ void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input,
 void PrepareParameters(Graph* graph, const Param& param) {
   // Check parameters
   PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
-  auto* scope = graph->Get<Scope*>(kParamScopeAttr);
+  auto& scope = graph->Get<Scope>(kParamScopeAttr);
 
   // Create new parameters.
-  scope->Var(param.LSTMWeight)->GetMutable<LoDTensor>();
-  scope->Var(param.LSTMBias)->GetMutable<LoDTensor>();
-  scope->Var(param.Hidden)->GetMutable<LoDTensor>();
-  scope->Var(param.Cell)->GetMutable<LoDTensor>();
-  scope->Var(param.AttentionedX)->GetMutable<LoDTensor>();
-  scope->Var(param.AttentionFCOut)->GetMutable<LoDTensor>();
-  scope->Var(param.LSTMX)->GetMutable<LoDTensor>();
-  scope->Var(param.LSTMOUT)->GetMutable<LoDTensor>();
+  scope.Var(param.LSTMWeight)->GetMutable<LoDTensor>();
+  scope.Var(param.LSTMBias)->GetMutable<LoDTensor>();
+  scope.Var(param.Hidden)->GetMutable<LoDTensor>();
+  scope.Var(param.Cell)->GetMutable<LoDTensor>();
+  scope.Var(param.AttentionedX)->GetMutable<LoDTensor>();
+  scope.Var(param.AttentionFCOut)->GetMutable<LoDTensor>();
+  scope.Var(param.LSTMX)->GetMutable<LoDTensor>();
+  scope.Var(param.LSTMOUT)->GetMutable<LoDTensor>();
 
 #define GATE_W(name__)                                               \
-  auto* W_##name__##_w0 = scope->FindVar(#name__ ".w_0");            \
-  auto* W_##name__##_w1 = scope->FindVar(#name__ ".w_1");            \
-  auto* W_##name__##_b0 = scope->FindVar(#name__ ".b_0");            \
+  auto* W_##name__##_w0 = scope.FindVar(#name__ ".w_0");             \
+  auto* W_##name__##_w1 = scope.FindVar(#name__ ".w_1");             \
+  auto* W_##name__##_b0 = scope.FindVar(#name__ ".b_0");             \
   CHECK_P3(W_##name__##_w0, W_##name__##_w1, W_##name__##_b0);       \
   VLOG(4) << #name__ "_w0"                                           \
           << " shape: " << W_##name__##_w0->Get<LoDTensor>().dims(); \
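The GATE_W macro in the hunk above leans on two preprocessor features: `#name__` stringizes the macro argument (adjacent string literals then concatenate, so `GATE_W(c)` looks up `"c.w_0"`), and `##` pastes tokens to build identifiers like `W_c_w0`. A minimal standalone sketch of the idiom, with a hypothetical Scope stand-in rather than Paddle's real class:

// Sketch of the GATE_W stringize/paste idiom; Scope here is a stand-in.
#include <cassert>
#include <map>
#include <string>

struct Scope {
  int* FindVar(const std::string& name) {
    auto it = vars_.find(name);
    return it == vars_.end() ? nullptr : &it->second;
  }
  std::map<std::string, int> vars_;
};

#define GATE_W(name__)                                   \
  auto* W_##name__##_w0 = scope.FindVar(#name__ ".w_0"); \
  assert(W_##name__##_w0 != nullptr);

int main() {
  Scope scope;
  scope.vars_["c.w_0"] = 42;
  GATE_W(c);  // expands to: auto* W_c_w0 = scope.FindVar("c.w_0"); assert(...);
  return *W_c_w0 == 42 ? 0 : 1;
}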
@@ -169,26 +169,26 @@ void PrepareParameters(Graph* graph, const Param& param) {
   GATE_W(c);
 #undef GATE_W
 
-  auto* attention_fc_w = scope->FindVar("attention_fc.w_0");
-  auto* attention_fc_b = scope->FindVar("attention_fc.b_0");
-  auto* attention_output_w = scope->FindVar("attention_output.w_0");
-  auto* attention_output_b = scope->FindVar("attention_output.b_0");
+  auto* attention_fc_w = scope.FindVar("attention_fc.w_0");
+  auto* attention_fc_b = scope.FindVar("attention_fc.b_0");
+  auto* attention_output_w = scope.FindVar("attention_output.w_0");
+  auto* attention_output_b = scope.FindVar("attention_output.b_0");
   CHECK_P4(attention_fc_w, attention_fc_b, attention_output_w,
            attention_output_b);
 
-  auto* lstm_weight = scope->Var(param.LSTMWeight);
+  auto* lstm_weight = scope.Var(param.LSTMWeight);
   auto* lstm_weight_t = lstm_weight->GetMutable<LoDTensor>();
-  auto* lstm_bias = scope->Var(param.LSTMBias);
+  auto* lstm_bias = scope.Var(param.LSTMBias);
   auto* lstm_bias_t = lstm_bias->GetMutable<LoDTensor>();
 
   // reshape attention_bias
   auto* attention_bias_t =
-      scope->FindVar(param.AttentionBias)->GetMutable<LoDTensor>();
+      scope.FindVar(param.AttentionBias)->GetMutable<LoDTensor>();
   PADDLE_ENFORCE_EQ(attention_bias_t->dims().size(), 1);
   attention_bias_t->Resize(make_ddim({1, attention_bias_t->dims()[0]}));
 
   auto* attention_scalar_bias_t =
-      scope->FindVar(param.AttentionScalarBias)->GetMutable<LoDTensor>();
+      scope.FindVar(param.AttentionScalarBias)->GetMutable<LoDTensor>();
   attention_scalar_bias_t->Resize(
       make_ddim({1, attention_scalar_bias_t->dims()[0]}));
 
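Both hunks implement one mechanical change: the parameter scope attribute is now fetched with `Get<Scope>`, which yields a `Scope&`, instead of `Get<Scope*>`, which yielded a raw pointer, so every `scope->` call site becomes `scope.`. A minimal standalone sketch of that pattern, using hypothetical stand-ins rather than the real Paddle Graph/Scope API:

// Sketch of the pointer-to-reference attribute access; types are stand-ins.
#include <cassert>
#include <map>
#include <string>

struct Scope {
  int* Var(const std::string& name) { return &vars_[name]; }  // create-or-get
  std::map<std::string, int> vars_;
};

struct Graph {
  bool Has(const std::string& attr) const { return attrs_.count(attr) > 0; }
  // Mirrors graph->Get<Scope>(kParamScopeAttr): dereference the stored
  // pointer and hand back a reference, so callers use '.' access.
  Scope& Get(const std::string& attr) { return *attrs_.at(attr); }
  std::map<std::string, Scope*> attrs_;
};

int main() {
  Scope param_scope;
  Graph graph;
  graph.attrs_["param_scope"] = &param_scope;  // stand-in for kParamScopeAttr

  assert(graph.Has("param_scope"));
  Scope& scope = graph.Get("param_scope");  // reference, not pointer
  scope.Var("lstm_weight");                 // '.' instead of '->'
  return 0;
}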