@@ -23,97 +23,119 @@ namespace paddle {
 namespace operators {
 
 void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
-  PADDLE_ENFORCE(ctx->HasInput("X"),
-                 "Assert only one Input(X) of AttentionLSTM.");
-  PADDLE_ENFORCE(ctx->HasInput("C0"),
-                 "Assert only one Input(C0) of AttentionLSTM.");
-  PADDLE_ENFORCE(ctx->HasInput("LSTMWeight"),
-                 "Assert only one Input(LSTMWeight) of AttentionLSTM.");
-  PADDLE_ENFORCE(ctx->HasInput("LSTMBias"),
-                 "Assert only one Input(LSTMBias) of AttentionLSTM.");
-  PADDLE_ENFORCE(ctx->HasInput("AttentionWeight"),
-                 "Assert only one Input(AttentionWeight) of AttentionLSTM.");
-
-  PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
-                 "Assert only one Output(Hidden) of AttentionLSTM.");
-  PADDLE_ENFORCE(ctx->HasOutput("Cell"),
-                 "Assert only one Output(Cell) of AttentionLSTM.");
-  PADDLE_ENFORCE(ctx->HasOutput("AttentionedX"),
-                 "Assert only one Output(AttentionedX) of AttentionLSTM.");
-  PADDLE_ENFORCE(ctx->HasOutput("AttentionFCOut"),
-                 "Assert only one Output(AttentionFCOut) of AttentionLSTM.");
-  PADDLE_ENFORCE(ctx->HasOutput("LSTMX"),
-                 "Assert only one Output(LSTMX) of AttentionLSTM.");
-  PADDLE_ENFORCE(ctx->HasOutput("LSTMOUT"),
-                 "Assert only one Output(LSTMOUT) of AttentionLSTM.");
+  OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "AttentionLstm");
+  OP_INOUT_CHECK(ctx->HasInput("C0"), "Input", "C0", "AttentionLstm");
+  OP_INOUT_CHECK(ctx->HasInput("LSTMWeight"), "Input", "LSTMWeight",
+                 "AttentionLstm");
+  OP_INOUT_CHECK(ctx->HasInput("LSTMBias"), "Input", "LSTMBias",
+                 "AttentionLstm");
+  OP_INOUT_CHECK(ctx->HasInput("AttentionWeight"), "Input", "AttentionWeight",
+                 "AttentionLstm");
+
+  OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "AttentionLstm");
+  OP_INOUT_CHECK(ctx->HasOutput("Cell"), "Output", "Cell", "AttentionLstm");
+  OP_INOUT_CHECK(ctx->HasOutput("AttentionedX"), "Output", "AttentionedX",
+                 "AttentionLstm");
+  OP_INOUT_CHECK(ctx->HasOutput("AttentionFCOut"), "Output", "AttentionFCOut",
+                 "AttentionLstm");
+  OP_INOUT_CHECK(ctx->HasOutput("LSTMX"), "Output", "LSTMX", "AttentionLstm");
+  OP_INOUT_CHECK(ctx->HasOutput("LSTMOUT"), "Output", "LSTMOUT",
+                 "AttentionLstm");
 
   auto x_dims = ctx->GetInputDim("X");
   const int M = x_dims[1];
-  PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
+  PADDLE_ENFORCE_EQ(x_dims.size(), 2, platform::errors::InvalidArgument(
+                                          "Input(X)'s rank must be 2."));
 
   auto w_dims = ctx->GetInputDim("LSTMWeight");
   const int D = w_dims[1] / 4;
-  PADDLE_ENFORCE_EQ(w_dims.size(), 2, "Input(LSTMWeight)'s rank must be 2.");
-  PADDLE_ENFORCE_EQ(w_dims[0], D + M,
-                    "LSTMWeight dims should be (%d + %d) * %d.", D, M, 4 * D);
+  PADDLE_ENFORCE_EQ(
+      w_dims.size(), 2,
+      platform::errors::InvalidArgument("Input(LSTMWeight)'s rank must be 2."));
+  PADDLE_ENFORCE_EQ(
+      w_dims[0], D + M,
+      platform::errors::InvalidArgument(
+          "LSTMWeight dims should be (%d + %d) * %d.", D, M, 4 * D));
 
   auto b_dims = ctx->GetInputDim("LSTMBias");
-  PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(LSTMBias)'s rank must be 2.");
-  PADDLE_ENFORCE_EQ(b_dims[0], 1, "LSTMBias dims should be 1 x %d.", 4 * D);
-  PADDLE_ENFORCE_EQ(b_dims[1], 4 * D, "LSTMBias dims should be 1 x %d.", 4 * D);
+  PADDLE_ENFORCE_EQ(b_dims.size(), 2, platform::errors::InvalidArgument(
+                                          "Input(LSTMBias)'s rank must be 2."));
+  PADDLE_ENFORCE_EQ(b_dims[0], 1,
+                    platform::errors::InvalidArgument(
+                        "LSTMBias dims should be 1 x %d.", 4 * D));
+  PADDLE_ENFORCE_EQ(b_dims[1], 4 * D,
+                    platform::errors::InvalidArgument(
+                        "LSTMBias dims should be 1 x %d.", 4 * D));
 
   auto c_dims = ctx->GetInputDim("C0");
-  PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2.");
+  PADDLE_ENFORCE_EQ(c_dims.size(), 2, platform::errors::InvalidArgument(
+                                          "Input(C0)'s rank must be 2."));
   if (ctx->IsRuntime()) {
-    PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D);
+    PADDLE_ENFORCE_EQ(c_dims[1], D, platform::errors::InvalidArgument(
+                                        "C0 dims should be N x %d.", D));
   }
 
   if (ctx->HasInput("H0")) {
     auto h_dims = ctx->GetInputDim("H0");
-    PADDLE_ENFORCE_EQ(h_dims.size(), 2UL, "Input(H0)'s rank must be 2.");
+    PADDLE_ENFORCE_EQ(h_dims.size(), 2UL, platform::errors::InvalidArgument(
+                                              "Input(H0)'s rank must be 2."));
     if (ctx->IsRuntime() ||
         (framework::product(c_dims) > 0 && framework::product(h_dims) > 0)) {
-      PADDLE_ENFORCE(h_dims == c_dims,
-                     "The dimension of Input(H0) and Input(C0) "
-                     "should be the same.");
+      PADDLE_ENFORCE_EQ(h_dims, c_dims,
+                        platform::errors::InvalidArgument(
+                            "The dimension of Input(H0) and Input(C0) "
+                            "should be the same."));
     }
   }
 
   auto atten_w_dims = ctx->GetInputDim("AttentionWeight");
   PADDLE_ENFORCE_EQ(atten_w_dims.size(), 2,
-                    "Input(AttentionWeight)'s rank must be 2.");
+                    platform::errors::InvalidArgument(
+                        "Input(AttentionWeight)'s rank must be 2."));
   PADDLE_ENFORCE_EQ(atten_w_dims[0], M + D,
-                    "AttentionWeight shapes must be (%d + %d) * 1.", M, D);
+                    platform::errors::InvalidArgument(
+                        "AttentionWeight shapes must be (%d + %d) * 1.", M, D));
   PADDLE_ENFORCE_EQ(atten_w_dims[1], 1,
-                    "AttentionWeight shapes must be (%d + %d) * 1.", M, D);
+                    platform::errors::InvalidArgument(
+                        "AttentionWeight shapes must be (%d + %d) * 1.", M, D));
 
   if (ctx->HasInput("AttentionBias")) {
     auto atten_b_dims = ctx->GetInputDim("AttentionBias");
     PADDLE_ENFORCE_EQ(atten_b_dims.size(), 2,
-                      "Input(AttentionBias)'s rank must be 2.");
+                      platform::errors::InvalidArgument(
+                          "Input(AttentionBias)'s rank must be 2."));
     PADDLE_ENFORCE_EQ(atten_b_dims[0], 1,
-                      "AttentionBias shapes must be 1 * 1.");
+                      platform::errors::InvalidArgument(
+                          "AttentionBias shapes must be 1 * 1."));
     PADDLE_ENFORCE_EQ(atten_b_dims[1], 1,
-                      "AttentionBias shapes must be 1 * 1.");
+                      platform::errors::InvalidArgument(
+                          "AttentionBias shapes must be 1 * 1."));
   }
 
   if (ctx->HasInput("AttentionScalar")) {
     auto dims = ctx->GetInputDim("AttentionScalar");
     PADDLE_ENFORCE_EQ(dims.size(), 2,
-                      "Input(AttentionScalar)'s rank must be 2.");
-    PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalar shapes must be 1 * 1.");
-    PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalar shapes must be 1 * 1.");
+                      platform::errors::InvalidArgument(
+                          "Input(AttentionScalar)'s rank must be 2."));
+    PADDLE_ENFORCE_EQ(dims[0], 1, platform::errors::InvalidArgument(
+                                      "AttentionScalar shapes must be 1 * 1."));
+    PADDLE_ENFORCE_EQ(dims[1], 1, platform::errors::InvalidArgument(
+                                      "AttentionScalar shapes must be 1 * 1."));
   }
 
   if (ctx->HasInput("AttentionScalarBias")) {
     auto dims = ctx->GetInputDim("AttentionScalarBias");
-    PADDLE_ENFORCE(
-        ctx->HasInput("AttentionScalar"),
-        "AttentionScalar should not be null when have AttentionScalarBias.");
+    OP_INOUT_CHECK(ctx->HasInput("AttentionScalar"), "Input", "AttentionScalar",
+                   "AttentionLstm");
     PADDLE_ENFORCE_EQ(dims.size(), 2,
-                      "Input(AttentionScalarBias)'s rank must be 2.");
-    PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalarBias shapes must be 1 * 1.");
-    PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalarBias shapes must be 1 * 1.");
+                      platform::errors::InvalidArgument(
+                          "Input(AttentionScalarBias)'s rank must be 2."));
+    PADDLE_ENFORCE_EQ(dims[0], 1,
+                      platform::errors::InvalidArgument(
+                          "AttentionScalarBias shapes must be 1 * 1."));
+    PADDLE_ENFORCE_EQ(dims[1], 1,
+                      platform::errors::InvalidArgument(
+                          "AttentionScalarBias shapes must be 1 * 1."));
   }
 
   framework::DDim out_dims({x_dims[0], D});
@@ -301,8 +323,11 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
       int len = x_lod[0][i + 1] - x_lod[0][i];
       max_seq_len = max_seq_len < len ? len : max_seq_len;
     }
-    PADDLE_ENFORCE_EQ(x_lod.size(), 1UL, "Input(X)'s lod size must be 1.");
-    PADDLE_ENFORCE_EQ(c0->dims()[0], N, "C0 dims should be %d x %d.", N, D);
+    PADDLE_ENFORCE_EQ(x_lod.size(), 1UL, platform::errors::InvalidArgument(
+                                             "Input(X)'s lod size must be 1."));
+    PADDLE_ENFORCE_EQ(
+        c0->dims()[0], N,
+        platform::errors::InvalidArgument("C0 dims should be %d x %d.", N, D));
     fc_out->Resize({max_seq_len, 1});
 
     std::function<void(const int, const T *, T *)> act_gate, act_cell, act_cand;
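Note: every hunk above applies the same mechanical change: a bare-string PADDLE_ENFORCE becomes either OP_INOUT_CHECK (for input/output existence) or a PADDLE_ENFORCE_EQ carrying a platform::errors::InvalidArgument payload, so failures report a typed, formatted error instead of a raw assertion string. For readers without the Paddle tree at hand, here is a minimal standalone sketch of that check-with-typed-error pattern; EnforceEq and InvalidArgumentError are hypothetical stand-ins for PADDLE_ENFORCE_EQ and platform::errors::InvalidArgument, not Paddle APIs.

    // Standalone illustration of the shape-check pattern (not Paddle code).
    #include <cstdint>
    #include <cstdio>
    #include <stdexcept>
    #include <string>
    #include <vector>

    // Typed error, analogous in spirit to platform::errors::InvalidArgument.
    struct InvalidArgumentError : std::runtime_error {
      explicit InvalidArgumentError(const std::string& msg)
          : std::runtime_error(msg) {}
    };

    // Compare two values; on mismatch, throw with the caller-supplied message.
    template <typename A, typename B>
    void EnforceEq(const A& actual, const B& expected, const std::string& msg) {
      if (!(actual == expected)) {
        throw InvalidArgumentError(msg);
      }
    }

    int main() {
      std::vector<std::int64_t> x_dims = {32, 128};  // stand-in for Input(X)'s shape
      // Mirrors PADDLE_ENFORCE_EQ(x_dims.size(), 2, ...): rank must be exactly 2.
      EnforceEq(x_dims.size(), std::size_t{2}, "Input(X)'s rank must be 2.");
      std::printf("rank check passed: %zu dims\n", x_dims.size());
      return 0;
    }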