|
|
|
@ -24,74 +24,92 @@ class LSTMPOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Input"),
|
|
|
|
|
"Input(Input) of LSTMP operator should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Weight"),
|
|
|
|
|
"Input(Weight) of LSTMP operator should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("ProjWeight"),
|
|
|
|
|
"Input(ProjWeight) of LSTMP operator should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Bias"),
|
|
|
|
|
"Input(Bias) of LSTMP operator should not be null.");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Projection"),
|
|
|
|
|
"Output(Projection) of LSTMP operator should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Cell"),
|
|
|
|
|
"Output(Cell) of LSTMP operator should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("BatchGate"),
|
|
|
|
|
"Output(BatchGate) of LSTMP operator should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("BatchCellPreAct"),
|
|
|
|
|
"Output(BatchCellPreAct) of LSTMP operator should not be "
|
|
|
|
|
"null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("BatchHidden"),
|
|
|
|
|
"Output(BatchHidden) of LSTMP operator should not be null.");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "LSTMP");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "LSTMP");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("ProjWeight"), "Input", "ProjWeight", "LSTMP");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias", "LSTMP");
|
|
|
|
|
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Projection"), "Output", "Projection",
|
|
|
|
|
"LSTMP");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Cell"), "Output", "Cell", "LSTMP");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("BatchGate"), "Output", "BatchGate", "LSTMP");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("BatchCellPreAct"), "Output",
|
|
|
|
|
"BatchCellPreAct", "LSTMP");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("BatchHidden"), "Output", "BatchHidden",
|
|
|
|
|
"LSTMP");
|
|
|
|
|
|
|
|
|
|
auto in_dims = ctx->GetInputDim("Input");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(in_dims.size(), 2,
|
|
|
|
|
"Input(X)'s rank of LSTMP operator must be 2.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
in_dims.size(), 2,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(X)'s rank of LSTMP operator must be 2, but received %d.",
|
|
|
|
|
in_dims.size()));
|
|
|
|
|
|
|
|
|
|
int frame_size = in_dims[1] / 4;
|
|
|
|
|
auto w_dims = ctx->GetInputDim("Weight");
|
|
|
|
|
auto proj_dims = ctx->GetInputDim("ProjWeight");
|
|
|
|
|
PADDLE_ENFORCE_EQ(w_dims.size(), 2,
|
|
|
|
|
"The rank of Input(Weight) should be 2.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(w_dims[0], proj_dims[1],
|
|
|
|
|
"The first dimension of Input(Weight) "
|
|
|
|
|
"should be %d.",
|
|
|
|
|
proj_dims[1]);
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
w_dims.size(), 2,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The rank of Input(Weight) should be 2, but received %d.",
|
|
|
|
|
w_dims.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
w_dims[0], proj_dims[1],
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The first dimension of Input(Weight) and the second dimension of "
|
|
|
|
|
"Input(ProjWeight) should be the same, but received %d vs %d.",
|
|
|
|
|
w_dims[0], proj_dims[1]));
|
|
|
|
|
PADDLE_ENFORCE_EQ(w_dims[1], 4 * frame_size,
|
|
|
|
|
"The second dimension of Input(Weight) "
|
|
|
|
|
"should be 4 * %d.",
|
|
|
|
|
frame_size);
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(proj_dims.size(), 2,
|
|
|
|
|
"The rank of Input(ProjWeight) should be 2.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The second dimension of Input(Weight) should be 4 * "
|
|
|
|
|
"%d, but received %d.",
|
|
|
|
|
frame_size, w_dims[1]));
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
proj_dims.size(), 2,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The rank of Input(ProjWeight) should be 2, but received %d.",
|
|
|
|
|
proj_dims.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(proj_dims[0], frame_size,
|
|
|
|
|
"The first dimension of Input(ProjWeight) "
|
|
|
|
|
"should be %d.",
|
|
|
|
|
frame_size);
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The first dimension of Input(ProjWeight) should be "
|
|
|
|
|
"%d, but received %d.",
|
|
|
|
|
frame_size, proj_dims[0]));
|
|
|
|
|
|
|
|
|
|
if (ctx->HasInput("H0")) {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("C0"),
|
|
|
|
|
"Input(C0) of LSTMP operator should not be null after "
|
|
|
|
|
"Input(H0) provided.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->HasInput("C0"), true,
|
|
|
|
|
platform::errors::NotFound("Input(C0) of LSTMP operator should not "
|
|
|
|
|
"be null after Input(H0) provided."));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto b_dims = ctx->GetInputDim("Bias");
|
|
|
|
|
PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(b_dims[0], 1,
|
|
|
|
|
"The first dimension of Input(Bias) should be 1.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
b_dims.size(), 2,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The rank of Input(Bias) should be 2, but received %d.",
|
|
|
|
|
b_dims.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
b_dims[0], 1,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The first dimension of Input(Bias) should be 1, but received %d.",
|
|
|
|
|
b_dims[0]));
|
|
|
|
|
|
|
|
|
|
if (ctx->Attrs().Get<bool>("use_peepholes")) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(b_dims[1], 7 * frame_size,
|
|
|
|
|
"The second dimension of Input(Bias) should be "
|
|
|
|
|
"7 * %d if enable peepholes connection",
|
|
|
|
|
frame_size);
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
b_dims[1], 7 * frame_size,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The second dimension of Input(Bias) should be 7 * %d if enable "
|
|
|
|
|
"peepholes connection, but received %d.",
|
|
|
|
|
frame_size, b_dims[1]));
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size,
|
|
|
|
|
"The second dimension of Input(Bias) should be "
|
|
|
|
|
"4 * %d if disable peepholes connection",
|
|
|
|
|
frame_size);
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
b_dims[1], 4 * frame_size,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The second dimension of Input(Bias) should be 4 * %d if disable "
|
|
|
|
|
"peepholes connection, but received %d.",
|
|
|
|
|
frame_size, b_dims[1]));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
framework::DDim out_dims({in_dims[0], frame_size});
|
|
|
|
@ -314,21 +332,18 @@ class LSTMPGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Projection"),
|
|
|
|
|
"Input(Projection) of LSTMP operator should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Cell"),
|
|
|
|
|
"Input(Cell) of LSTMP operator should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Weight"),
|
|
|
|
|
"Input(Weight) of LSTMP operator should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("ProjWeight"),
|
|
|
|
|
"Input(ProjWeight) of LSTMP operator should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Bias"),
|
|
|
|
|
"Input(Bias) of LSTMP operator should not be null.");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("BatchGate"),
|
|
|
|
|
"Input(BatchGate) of LSTMP operator should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("BatchCellPreAct"),
|
|
|
|
|
"Input(BatchGate) of LSTMP operator should not be null.");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Projection"), "Input", "Projection",
|
|
|
|
|
"LSTMP@Grad");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Cell"), "Input", "Cell", "LSTMP@Grad");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "LSTMP@Grad");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("ProjWeight"), "Input", "ProjWeight",
|
|
|
|
|
"LSTMP@Grad");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias", "LSTMP@Grad");
|
|
|
|
|
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("BatchGate"), "Input", "BatchGate",
|
|
|
|
|
"LSTMP@Grad");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("BatchCellPreAct"), "Input", "BatchCellPreAct",
|
|
|
|
|
"LSTMP@Grad");
|
|
|
|
|
|
|
|
|
|
auto SetOutGradDim = [&ctx](const std::string& name) {
|
|
|
|
|
auto g_name = framework::GradVarName(name);
|
|
|
|
|