|
|
|
@ -23,23 +23,31 @@ class LstmUnitOp : public framework::OperatorWithKernel {
|
|
|
|
|
// Inherit the constructors of the base class framework::OperatorWithKernel.
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of LSTM should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("C_prev"),
|
|
|
|
|
"Input(C_prev) of LSTM should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("C"),
|
|
|
|
|
"Output(C) of LSTM should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("H"),
|
|
|
|
|
"Output(H) of LSTM should not be null.");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "lstm_unit");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("C_prev"), "Input", "C_prev", "lstm_unit");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("C"), "Output", "C", "lstm_unit");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("H"), "Output", "H", "lstm_unit");
|
|
|
|
|
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|
auto c_prev_dims = ctx->GetInputDim("C_prev");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
x_dims.size(), 2,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(X)'s rank must be 2. Received %d instead.", x_dims.size()));
|
|
|
|
|
if (ctx->IsRuntime()) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims[0], c_prev_dims[0],
|
|
|
|
|
"Batch size of inputs and states must be equal");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Batch size of inputs and states must be equal, "
|
|
|
|
|
"but received %d (inputs)"
|
|
|
|
|
"vs %d (states).",
|
|
|
|
|
x_dims[0], c_prev_dims[0]));
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims[1], c_prev_dims[1] * 4,
|
|
|
|
|
"Dimension of FC should equal to prev state * 4");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Dimension of FC should equal to prev state * 4, "
|
|
|
|
|
"but received %d (dimension of FC)"
|
|
|
|
|
"vs %d (prev state * 4).",
|
|
|
|
|
x_dims[1], c_prev_dims[1] * 4));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int b_size = c_prev_dims[0]; // batch size
|
|
|
|
@ -85,10 +93,10 @@ class LstmUnitGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
// Inherit the constructors of the base class framework::OperatorWithKernel.
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("C")),
|
|
|
|
|
"Input(C@GRAD) should not be null");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("H")),
|
|
|
|
|
"Input(H@GRAD) should not be null");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("C")), "Input",
|
|
|
|
|
framework::GradVarName("C"), "lstm_unit");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("H")), "Input",
|
|
|
|
|
framework::GradVarName("H"), "lstm_unit");
|
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
|
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("C_prev"),
|
|
|
|
|
ctx->GetInputDim("C_prev"));
|
|
|
|
|