@@ -26,86 +26,102 @@ namespace paddle {
 namespace operators {
 
 void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
-  PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of LSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasInput("WeightX"),
-                 "Input(WeightX) of LSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasInput("WeightH"),
-                 "Input(WeightH) of LSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasInput("Bias"),
-                 "Input(Bias) of LSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("XX"),
-                 "Output(XX) of LSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
-                 "Output(Hidden) of LSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("Cell"),
-                 "Output(Cell) of LSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("BatchedGate"),
-                 "Output(BatchedGate) of LSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("BatchCellPreAct"),
-                 "Output(BatchedGate) of LSTM should not be null.");
+  PADDLE_ENFORCE(ctx->HasInput("X"),
+                 "Input(X) of AttentionLSTM should not be null.");
+  PADDLE_ENFORCE(ctx->HasInput("C0"),
+                 "Input(C0) of AttentionLSTM should not be null.");
+  PADDLE_ENFORCE(ctx->HasInput("LSTMWeight"),
+                 "Input(LSTMWeight) of AttentionLSTM should not be null.");
+  PADDLE_ENFORCE(ctx->HasInput("LSTMBias"),
+                 "Input(LSTMBias) of AttentionLSTM should not be null.");
+  PADDLE_ENFORCE(ctx->HasInput("AttentionWeight"),
+                 "Input(AttentionWeight) of AttentionLSTM should not be null.");
+  PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
+                 "Output(Hidden) of AttentionLSTM should not be null.");
+  PADDLE_ENFORCE(ctx->HasOutput("Cell"),
+                 "Output(Cell) of AttentionLSTM should not be null.");
+  PADDLE_ENFORCE(ctx->HasOutput("AttentionedX"),
+                 "Output(AttentionedX) of AttentionLSTM should not be null.");
+  PADDLE_ENFORCE(ctx->HasOutput("AttentionFCOut"),
+                 "Output(AttentionFCOut) of AttentionLSTM should not be null.");
+  PADDLE_ENFORCE(ctx->HasOutput("LSTMX"),
+                 "Output(LSTMX) of AttentionLSTM should not be null.");
+  PADDLE_ENFORCE(ctx->HasOutput("LSTMOUT"),
+                 "Output(LSTMOUT) of AttentionLSTM should not be null.");
 
   auto x_dims = ctx->GetInputDim("X");
+  const int M = x_dims[1];
   PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
 
+  auto w_dims = ctx->GetInputDim("LSTMWeight");
+  const int D = w_dims[1] / 4;
+  PADDLE_ENFORCE_EQ(w_dims.size(), 2, "Input(LSTMWeight)'s rank must be 2.");
+  PADDLE_ENFORCE_EQ(w_dims[0], D + M,
+                    "LSTMWeight dims should be (%d + %d) * %d.", D, M, 4 * D);
+
+  auto b_dims = ctx->GetInputDim("LSTMBias");
+  PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(LSTMBias)'s rank must be 2.");
+  PADDLE_ENFORCE_EQ(b_dims[0], 1, "LSTMBias dims should be 1 x (%d + %d).", M,
+                    D);
+  PADDLE_ENFORCE_EQ(b_dims[1], M + D, "LSTMBias dims should be 1 x (%d + %d).",
+                    M, D);
+
+  auto c_dims = ctx->GetInputDim("C0");
+  PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2.");
+  PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D);
   if (ctx->HasInput("H0")) {
-    PADDLE_ENFORCE(ctx->HasInput("C0"),
-                   "Input(Cell) and Input(Hidden) of LSTM should not "
-                   "be null at the same time.");
     auto h_dims = ctx->GetInputDim("H0");
-    auto c_dims = ctx->GetInputDim("C0");
     PADDLE_ENFORCE(h_dims == c_dims,
                    "The dimension of Input(H0) and Input(C0) "
                    "should be the same.");
   }
 
-  // fc_out , shape (maxseqlen,1)
-  int max_seq_len = 0;
-  auto wx_dims = ctx->GetInputDim("WeightX");
-  PADDLE_ENFORCE_EQ(wx_dims.size(), 2,
-                    "The rank of Input(WeightX) should be 2.");
-  PADDLE_ENFORCE_EQ(wx_dims[0], x_dims[1],
-                    "The first dimension of Input(WeightX) "
-                    "should be %d.",
-                    x_dims[1]);
-
-  int frame_size = wx_dims[1] / 4;
-  auto wh_dims = ctx->GetInputDim("WeightH");
-  PADDLE_ENFORCE_EQ(wh_dims.size(), 2,
-                    "The rank of Input(WeightH) should be 2.");
-  PADDLE_ENFORCE_EQ(wh_dims[0], frame_size,
-                    "The first dimension of Input(WeightH) "
-                    "should be %d.",
-                    frame_size);
-  PADDLE_ENFORCE_EQ(wh_dims[1], 4 * frame_size,
-                    "The second dimension of Input(WeightH) "
-                    "should be 4 * %d.",
-                    frame_size);
-
-  auto b_dims = ctx->GetInputDim("Bias");
-  PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
-  PADDLE_ENFORCE_EQ(b_dims[0], 1,
-                    "The first dimension of Input(Bias) should be 1.");
-  PADDLE_ENFORCE(!ctx->Attrs().Get<bool>("use_peepholes"),
-                 "Do not support peephole yet.");
-  PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size,
-                    "The second dimension of Input(Bias) should be "
-                    "4 * %d if disable peepholes connection",
-                    frame_size);
+  auto atten_w_dims = ctx->GetInputDim("AttentionWeight");
+  PADDLE_ENFORCE_EQ(atten_w_dims.size(), 2,
+                    "Input(AttentionWeight)'s rank must be 2.");
+  PADDLE_ENFORCE_EQ(atten_w_dims[0], M + D,
+                    "AttentionWeight shapes must be (%d + %d) * 1.", M, D);
+  PADDLE_ENFORCE_EQ(atten_w_dims[1], 1,
+                    "AttentionWeight shapes must be (%d + %d) * 1.", M, D);
+  if (ctx->HasInput("AttentionBias")) {
+    auto atten_b_dims = ctx->GetInputDim("AttentionBias");
+    PADDLE_ENFORCE_EQ(atten_b_dims.size(), 2,
+                      "Input(AttentionBias)'s rank must be 2.");
+    PADDLE_ENFORCE_EQ(atten_b_dims[0], 1,
+                      "AttentionBias shapes must be 1 * 1.");
+    PADDLE_ENFORCE_EQ(atten_b_dims[1], 1,
+                      "AttentionBias shapes must be 1 * 1.");
+  }
+
+  if (ctx->HasInput("AttentionScalar")) {
+    auto dims = ctx->GetInputDim("AttentionScalar");
+    PADDLE_ENFORCE_EQ(dims.size(), 2,
+                      "Input(AttentionScalar)'s rank must be 2.");
+    PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalar shapes must be 1 * 1.");
+    PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalar shapes must be 1 * 1.");
+  }
+
+  if (ctx->HasInput("AttentionScalarBias")) {
+    auto dims = ctx->GetInputDim("AttentionScalarBias");
+    PADDLE_ENFORCE(
+        ctx->HasInput("AttentionScalar"),
+        "AttentionScalar should not be null when have AttentionScalarBias.");
+    PADDLE_ENFORCE_EQ(dims.size(), 2,
+                      "Input(AttentionScalarBias)'s rank must be 2.");
+    PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalarBias shapes must be 1 * 1.");
+    PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalarBias shapes must be 1 * 1.");
+  }
 
-  framework::DDim out_dims({x_dims[0], frame_size});
+  framework::DDim out_dims({x_dims[0], D});
   ctx->SetOutputDim("Hidden", out_dims);
   ctx->SetOutputDim("Cell", out_dims);
-  ctx->SetOutputDim("BatchedGate", {x_dims[0], wx_dims[1]});
-  ctx->SetOutputDim("BatchCellPreAct", out_dims);
+  ctx->SetOutputDim("AttentionedX", {x_dims[0], 1});
+  ctx->SetOutputDim("LSTMX", {1, M});
+  ctx->SetOutputDim("LSTMOUT", {1, 4 * D});
+  // AttentionFCOut should be reshaped as (maxseqlen, 1) at runtime
   ctx->ShareLoD("X", "Hidden");
   ctx->ShareLoD("X", "Cell");
-
-  int xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
-  ctx->SetOutputDim("XX", {x_dims[0], xx_width});
-  ctx->ShareLoD("X", "XX");
 }
 
 framework::OpKernelType AttentionLSTMOp::GetExpectedKernelType(
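Note on the rewritten checks: the fused op is parameterized by two widths, M (the x frame size, so X is T x M) and D (the hidden size, recovered from LSTMWeight's 4*D columns). A minimal standalone sketch of the shape contract, with made-up M and D and plain asserts in place of PADDLE_ENFORCE:

    // Illustration only, not Paddle code; M and D are example values.
    #include <cassert>

    int main() {
      const int M = 32;  // x frame size: X is T x M
      const int D = 15;  // hidden size

      // LSTMWeight is (D + M) x 4D, so D is recovered as cols / 4.
      const int lstm_w_rows = D + M, lstm_w_cols = 4 * D;
      assert(lstm_w_cols / 4 == D);
      assert(lstm_w_rows == D + M);

      // AttentionWeight is (M + D) x 1: one score weight per element of
      // the concatenated [x_t, prev_cell] vector.
      const int atten_w_rows = M + D, atten_w_cols = 1;
      assert(atten_w_rows == M + D && atten_w_cols == 1);

      // AttentionBias, AttentionScalar and AttentionScalarBias are 1 x 1;
      // C0 (and H0, if given) is N x D for batch size N.
      return 0;
    }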
@@ -164,9 +180,8 @@ void AttentionLSTMOpMaker::Make() {
   AddOutput("Cell",
             "(LoDTensor) (same as LSTMOp) the cell state of LSTM operator. "
             "The shape is (T x D), and lod is the same with the `Input`.");
-  AddOutput(
-      "AttentionedX",
-      "(LodTensor) shape is (T x 1), the result after X * AttentionWeight,"
+  AddOutput("AttentionedX",
+            "(Tensor) shape is (T x 1), the result after X * AttentionWeight,"
             " where T is the total time steps in this mini-batch,"
             " D is the hidden size.")
       .AsIntermediate();
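Context for the doc-string change: AttentionedX is the T x 1 result of multiplying X (T x M) by the first M rows of the (M + D) x 1 AttentionWeight; the remaining D rows are folded in per step from the previous cell state (see the DOT call in a later hunk). A hedged standalone sketch of that product:

    // Plain C++ illustration with hypothetical names, not the Paddle kernel.
    #include <vector>

    std::vector<float> atted_x(const std::vector<float>& x,        // T*M, row-major
                               const std::vector<float>& atten_w,  // (M+D)*1
                               int T, int M) {
      std::vector<float> out(T, 0.f);
      for (int t = 0; t < T; ++t)
        for (int m = 0; m < M; ++m)
          out[t] += x[t * M + m] * atten_w[m];  // only rows 0..M-1 used here
      return out;
    }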
@@ -318,11 +333,30 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
     auto* hidden_out = ctx.Output<LoDTensor>("Hidden");   // TxD
     auto* cell_out = ctx.Output<LoDTensor>("Cell");       // TxD
-    auto* atted_x = ctx.Output<LoDTensor>("AttentionedX");  // T x 1
+    auto* atted_x = ctx.Output<Tensor>("AttentionedX");   // T x 1
     auto* fc_out = ctx.Output<Tensor>("AttentionFCOut");  // max_seq_len x 1
     auto* lstm_x = ctx.Output<Tensor>("LSTMX");           // 1 x M
     auto* lstm_out = ctx.Output<Tensor>("LSTMOUT");       // 1 x 4D
 
+    // some shapes should be reshaped here since infershape can not get lod info
+    auto x_lod = x->lod();
+    const int N = x_lod[0].size() - 1;  // batch size
+    auto x_dims = x->dims();            // T x M
+    auto w_dims = w->dims();            // (D+M) x 4D
+    const int total_T = x_dims[0];      // total time steps in this mini-batch
+    const int M = x_dims[1];      // x frame size
+    const int D = w_dims[1] / 4;  // gate frame size
+    const int D2 = D * 2;
+    const int D3 = D * 3;
+    const int D4 = w_dims[1];
+    int max_seq_len = x_lod[0][1];
+    for (int i = 1; i < N; ++i) {
+      int len = x_lod[0][i + 1] - x_lod[0][i];
+      max_seq_len = max_seq_len < len ? len : max_seq_len;
+    }
+    PADDLE_ENFORCE_EQ(x_lod.size(), 1, "Input(X)'s lod size must be 1.");
+    PADDLE_ENFORCE_EQ(c0->dims()[0], N, "C0 dims should be %d x %d.", N, D);
+    fc_out->Resize({max_seq_len, 1});
+
     const T* x_data = x->data<T>();
     const T* h0_data = h0 ? h0->data<T>() : NULL;
     const T* c0_data = c0->data<T>();
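The block added here derives the batch size and the longest sequence from the level-0 LoD offsets, then resizes AttentionFCOut, which InferShape could not size without lod info. A standalone illustration with a made-up LoD:

    #include <cassert>
    #include <vector>

    int main() {
      // Offsets of 3 sequences inside a T = 11 batch: lengths 4, 5 and 2.
      std::vector<size_t> lod0 = {0, 4, 9, 11};
      int N = lod0.size() - 1;    // batch size = 3
      int max_seq_len = lod0[1];  // length of the first sequence = 4
      for (int i = 1; i < N; ++i) {
        int len = lod0[i + 1] - lod0[i];
        max_seq_len = max_seq_len < len ? len : max_seq_len;
      }
      assert(N == 3 && max_seq_len == 5);  // fc_out becomes 5 x 1
      return 0;
    }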
@@ -341,16 +375,6 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
     T* lstm_x_data = lstm_x->mutable_data<T>(ctx.GetPlace());
     T* lstm_out_data = lstm_out->mutable_data<T>(ctx.GetPlace());
 
-    auto x_lod = x->lod();
-    auto x_dims = x->dims();      // T x M
-    auto w_dims = w->dims();      // (D+M) x 4D
-    const int M = x_dims[1];      // x frame size
-    const int D = w_dims[1] / 4;  // gate frame size
-    const int D2 = D * 2;
-    const int D3 = D * 3;
-    const int D4 = w_dims[1];
-    const int batch_size = x_lod[0].size() - 1;  // assert lod.size() == 1
-
     // x(TxM) * fc (Mx1) part of atten_wgt(M+D)x1
     auto blas = math::GetBlas<DeviceContext, T>(ctx);
     math::FCCompute<DeviceContext, T>(blas, total_T, 1, M, x_data, atten_w_data,
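The FCCompute call kept in context computes atted_x = X * atten_w[0:M] (+ the optional 1 x 1 AttentionBias) over all total_T rows at once; the declarations removed above became redundant once the earlier hunk hoisted them. A naive sketch of the same semantics (the real FCCompute is Paddle's fused matmul-plus-bias helper; this standalone signature is assumed for illustration):

    void fc_t_by_1(int total_T, int M, const float* x, const float* w,
                   const float* b /* nullable 1x1 bias */, float* out) {
      for (int t = 0; t < total_T; ++t) {
        float acc = b ? b[0] : 0.f;  // broadcast the 1x1 bias if present
        for (int m = 0; m < M; ++m) acc += x[t * M + m] * w[m];
        out[t] = acc;
      }
    }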
@@ -361,7 +385,7 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
     const T* prev_hidden_data = NULL;
     T* cur_cell_out_data = cell_out_data;
     T* cur_hidden_out_data = hidden_out_data;
-    for (int i = 0; i < batch_size; ++i) {
+    for (int i = 0; i < N; ++i) {
       int seq_len = x_lod[0][i + 1] - x_lod[0][i];
       prev_cell_data = c0_data + i * D;
       prev_hidden_data = h0 ? h0_data + i * D : NULL;
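With batch_size renamed to N, the loop walks the N LoD sequences and points the previous cell/hidden state at row i of C0/H0 (each N x D). Sketched standalone:

    #include <vector>

    void per_sequence(const std::vector<size_t>& lod0, const float* c0_data,
                      const float* h0_data /* may be null */, int D) {
      int N = lod0.size() - 1;
      for (int i = 0; i < N; ++i) {
        int seq_len = lod0[i + 1] - lod0[i];       // length of sequence i
        const float* prev_cell = c0_data + i * D;  // row i of C0
        const float* prev_hidden = h0_data ? h0_data + i * D : nullptr;
        (void)seq_len; (void)prev_cell; (void)prev_hidden;  // consumed per step
      }
    }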
@@ -370,13 +394,13 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
       /// compute attention vector
       // prev_cell(1xD) * fc(D) rest part of atten_wgt
       // T = cblas_dot();
-      T prev_cell_bias = blas.VDOT(D, prev_cell_data, atten_w_data + M);
+      T prev_cell_bias = blas.DOT(D, prev_cell_data, atten_w_data + M);
       // add cell bias and relu
       bias_relu<T>(seq_len, atted_x_data, &prev_cell_bias, fc_out_data);
       // fc2: scalar
       if (atten_scalar_data) {
         // x = a*x
-        blas.VSCAL(seq_len, atten_scalar_data, fc_out_data);
+        blas.SCAL(seq_len, *atten_scalar_data, fc_out_data);
         bias_relu<T>(seq_len, fc_out_data, atten_scalar_bias_data,
                      fc_out_data);
       }
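The rename VDOT/VSCAL -> DOT/SCAL tracks Paddle's Blas wrapper (note SCAL takes the scalar by value, hence the dereference). For reference, the three helpers used in this hunk written out naively (illustration only; Paddle's Blas wrappers and the op's own bias_relu helper are the authoritative versions and dispatch to MKL/vectorized kernels):

    template <typename T>
    T dot(int n, const T* x, const T* y) {  // blas.DOT: inner product
      T s = 0;
      for (int i = 0; i < n; ++i) s += x[i] * y[i];
      return s;
    }

    template <typename T>
    void scal(int n, T a, T* x) {  // blas.SCAL: x = a * x, in place
      for (int i = 0; i < n; ++i) x[i] *= a;
    }

    template <typename T>
    void bias_relu(int n, const T* x, const T* bias, T* y) {
      // broadcast a 1x1 bias if present, then relu
      for (int i = 0; i < n; ++i) {
        T v = bias ? x[i] + bias[0] : x[i];
        y[i] = v > T(0) ? v : T(0);
      }
    }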