@@ -59,10 +59,8 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
   auto b_dims = ctx->GetInputDim("LSTMBias");
   PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(LSTMBias)'s rank must be 2.");
-  PADDLE_ENFORCE_EQ(b_dims[0], 1, "LSTMBias dims should be 1 x (%d + %d).", M,
-                    D);
-  PADDLE_ENFORCE_EQ(b_dims[1], M + D, "LSTMBias dims should be 1 x (%d + %d).",
-                    M, D);
+  PADDLE_ENFORCE_EQ(b_dims[0], 1, "LSTMBias dims should be 1 x %d.", 4 * D);
+  PADDLE_ENFORCE_EQ(b_dims[1], 4 * D, "LSTMBias dims should be 1 x %d.", 4 * D);

   auto c_dims = ctx->GetInputDim("C0");
   PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2.");
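[Note, not part of the patch] The fix in this hunk follows from the gate layout the kernel comments describe further down: the LSTM weight is concat[forget, input, output, tilde] with shape (D + M) x (4 * D), so the matching fused bias is a single row of width 4 * D; the old check against 1 x (M + D) described the attention fc weight instead. A minimal standalone sketch of the shape the corrected checks enforce (CheckLstmBiasDims is an illustrative helper, not Paddle API):

    // Sketch only: the LSTMBias shape enforced by the fixed checks.
    #include <array>
    #include <cassert>
    #include <cstdint>

    // The four concatenated gates each contribute D bias entries, so the
    // fused bias is one row of width 4 * D; the input width M plays no part.
    void CheckLstmBiasDims(const std::array<int64_t, 2>& b_dims, int64_t D) {
      assert(b_dims[0] == 1);
      assert(b_dims[1] == 4 * D);
    }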
@@ -148,8 +146,8 @@ void AttentionLSTMOpMaker::Make() {
            "(Tensor) the weights of attention fc. Always relu the fc result."
            "The shape is ((M+D) x 1), where M is the dim size of x, D is the "
            "gate size of LSTM.");
-  AddInput("AttentionBias, optional",
-           "(Tensor) the bias of attention fc."
+  AddInput("AttentionBias",
+           "(Tensor, optional) the bias of attention fc."
            "The shape is (1 x 1)")
       .AsDispensable();
   AddInput("AttentionScalar",
@@ -281,7 +279,7 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
     auto* atten_w = ctx.Input<Tensor>("AttentionWeight");      // (M+D) x 1
     auto* atten_b = ctx.Input<Tensor>("AttentionBias");        // 1x1
     auto* atten_scalar = ctx.Input<Tensor>("AttentionScalar");  // 1x1
-    auto* atten_scalar_bias = ctx.Input<Tensor>("AttentionScalar");  // 1x1
+    auto* atten_scalar_bias = ctx.Input<Tensor>("AttentionScalarBias");  // 1x1
     auto* lstm_w = ctx.Input<Tensor>("LSTMWeight");  // (D+M) x D*4
     auto* lstm_b = ctx.Input<Tensor>("LSTMBias");    // 1 x D*4
@@ -319,7 +317,7 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
     // }

     const T* x_data = x->data<T>();
-    const T* h0_data = h0->data<T>();
+    const T* h0_data = h0 ? h0->data<T>() : NULL;
     const T* c0_data = c0->data<T>();
     const T* lstm_w_data = lstm_w->data<T>();
     const T* lstm_b_data = lstm_b->data<T>();
@@ -341,36 +339,35 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
     math::FCCompute<DeviceContext, T>(blas, total_T, 1, M, x_data, atten_w_data,
                                       atted_x_data, atten_b_data);

     const T* cur_atten_x_data = atted_x_data;
     const T* cur_x_data = x_data;
     const T* prev_cell_data = NULL;
     const T* prev_hidden_data = NULL;
     T* cur_cell_out_data = cell_out_data;
     T* cur_hidden_out_data = hidden_out_data;
     for (int i = 0; i < N; ++i) {
-      int seq_len = x_lod[0][i + 1];
+      int seq_len = x_lod[0][i + 1] - x_lod[0][i];
       prev_cell_data = c0_data + i * D;
-      prev_hidden_data = h0 ? h0_data + i * D : NULL;
+      prev_hidden_data = h0_data ? h0_data + i * D : NULL;
       for (int step = 0; step < seq_len; ++step) {
-        /// compute attention vector
-        // prev_cell(1xD) * fc(D) rest part of atten_wgt
-        // T = cblas_dot();
+        /// 1. compute attention vector
+        // 1a. prev_cell(1xD) * fc(D) rest part of atten_wgt
         T prev_cell_bias = blas.DOT(D, prev_cell_data, atten_w_data + M);
-        // add cell bias and relu
-        bias_relu<T>(seq_len, atted_x_data, &prev_cell_bias, fc_out_data);
-        // fc2: scalar
+        // 1b. add cell bias and relu
+        bias_relu<T>(seq_len, cur_atten_x_data, &prev_cell_bias, fc_out_data);
+        // 1c. fc scalar
         if (atten_scalar_data) {
           // x = a*x
           blas.SCAL(seq_len, *atten_scalar_data, fc_out_data);
           bias_relu<T>(seq_len, fc_out_data, atten_scalar_bias_data,
                        fc_out_data);
         }
+        // 1d. softmax
         vec_softmax<DeviceContext, T>(blas, seq_len, fc_out_data, fc_out_data);
         // mul x(seq_len*M) and sum pool
         math::FCCompute<DeviceContext, T>(blas, 1, M, seq_len, fc_out_data,
                                           cur_x_data, lstm_x_data);

-        /// compute LSTM step
+        /// 2. compute LSTM step
         // lstm weight : concat[forget , input , output , tilde]
         // shape : (D + M) x (4 * D)
         // fc inputX(1xM) * weightX(M*(4D)) => 1 x 4D
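[Note, not part of the patch] The seq_len fix in the hunk above relies on the LoD row x_lod[0] holding cumulative sequence boundaries, so the length of sequence i is the difference of adjacent offsets; the old code read the end offset alone, which is only correct for the first sequence. A minimal standalone sketch under that assumption (offsets and SeqLen are illustrative names):

    // Sketch only: sequence length from cumulative LoD-style offsets.
    #include <cstddef>
    #include <vector>

    // offsets = {0, 3, 7, 9} describes three sequences of lengths 3, 4 and 2
    // packed back to back; the length of sequence i is offsets[i + 1] - offsets[i].
    inline int SeqLen(const std::vector<std::size_t>& offsets, int i) {
      return static_cast<int>(offsets[i + 1] - offsets[i]);
    }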
@@ -407,6 +404,7 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
         cur_hidden_out_data = cur_hidden_out_data + D;
       }
       cur_x_data = cur_x_data + seq_len * M;
+      cur_atten_x_data = cur_atten_x_data + seq_len;
     }
   }
 };
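[Note, not part of the patch] The line added in the last hunk completes the per-sequence cursor bookkeeping: atted_x is packed as total_T x 1 and x as total_T x M, so after finishing sequence i each cursor must advance by that sequence's length times its row width, otherwise later batch items keep reading sequence 0's attention rows. A standalone sketch of the pattern (WalkPackedSequences and its parameters are illustrative, not the kernel itself):

    // Sketch only: advancing cursors over sequences packed back to back.
    #include <cstddef>
    #include <vector>

    void WalkPackedSequences(const std::vector<std::size_t>& offsets,  // x_lod[0]
                             const float* x, int M,         // total_T x M input
                             const float* atted_x) {        // total_T x 1 projection
      const float* cur_x = x;
      const float* cur_atten_x = atted_x;
      const int N = static_cast<int>(offsets.size()) - 1;
      for (int i = 0; i < N; ++i) {
        const int seq_len = static_cast<int>(offsets[i + 1] - offsets[i]);
        // ... per-step work reads cur_x (seq_len x M) and cur_atten_x (seq_len x 1) ...
        cur_x += seq_len * M;    // rows of x are M wide
        cur_atten_x += seq_len;  // rows of atted_x are 1 wide: the added line
      }
    }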