|
|
|
@ -28,6 +28,10 @@ class LSTMOp : public framework::OperatorWithKernel {
|
|
|
|
|
"Output(Hidden) of LSTM should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Cell"),
|
|
|
|
|
"Output(Cell) of LSTM should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("BatchGate"),
|
|
|
|
|
"Output(BatchGate) of LSTM should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("BatchCellPreAct"),
|
|
|
|
|
"Output(BatchGate) of LSTM should not be null.");
|
|
|
|
|
|
|
|
|
|
auto in_dims = ctx->GetInputDim("Input");
|
|
|
|
|
PADDLE_ENFORCE_EQ(in_dims.size(), 2, "Input(X)'s rank must be 2.");
|
|
|
|
@ -92,11 +96,13 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
AddInput("H0",
|
|
|
|
|
"(Tensor, optional) the initial hidden state is an optional "
|
|
|
|
|
"input. This is a tensor with shape (N x D), where N is the "
|
|
|
|
|
"batch size, D is the hidden size.");
|
|
|
|
|
"batch size, D is the hidden size.")
|
|
|
|
|
.AsDispensable();
|
|
|
|
|
AddInput("C0",
|
|
|
|
|
"(Tensor, optional) the initial cell state is an optional "
|
|
|
|
|
"input. This is a tensor with shape (N x D), where N is the "
|
|
|
|
|
"batch size. `H0` and `C0` can be NULL but only at the same time");
|
|
|
|
|
"batch size. `H0` and `C0` can be NULL but only at the same time")
|
|
|
|
|
.AsDispensable();
|
|
|
|
|
AddInput("Weight",
|
|
|
|
|
"(Tensor) the learnable hidden-hidden weights."
|
|
|
|
|
" - The shape is (D x 4D), where D is the hidden size. "
|
|
|
|
@ -110,7 +116,8 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
" - Bias = {b_c, b_i, b_f, b_o}."
|
|
|
|
|
"2. `usePeepholes = True` "
|
|
|
|
|
" - The shape is (1 x 7D). "
|
|
|
|
|
" - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}.");
|
|
|
|
|
" - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}.")
|
|
|
|
|
.AsDispensable();
|
|
|
|
|
AddOutput("Hidden",
|
|
|
|
|
"(LoDTensor) the hidden state lod tensor of LSTM operator. "
|
|
|
|
|
"The shape and lod is the same with the `Input`.");
|
|
|
|
@ -208,27 +215,29 @@ class LSTMGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
// Shape inference for the LSTM gradient operator.
//
// Verifies that the inputs this operator reads are present and gives each
// gradient output the same dimensions as the forward variable it is the
// gradient of.
//
// @param ctx  shape-inference context used to query inputs/outputs and to
//             record output dimensions.
void InferShape(framework::InferShapeContext* ctx) const override {
  PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Hidden")),
                 "Input(Hidden@GRAD) should not be null");
  PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Cell")),
                 "Input(Cell@GRAD) should not be null");

  // NOTE(review): the Input/Weight/Bias gradient shapes are assigned here and
  // again further down behind HasOutput() guards. This looks like two
  // revisions of the same logic merged together -- confirm which pass is the
  // intended one before removing either.
  ctx->SetOutputDim(framework::GradVarName("Input"),
                    ctx->GetInputDim("Input"));
  if (ctx->HasInput("Weight")) {
    ctx->SetOutputDim(framework::GradVarName("Weight"),
                      ctx->GetInputDim("Weight"));
  }
  if (ctx->HasInput("Bias")) {
    ctx->SetOutputDim(framework::GradVarName("Bias"),
                      ctx->GetInputDim("Bias"));
  }
  // H0 / C0 gradients are only shaped when the corresponding input exists.
  if (ctx->HasInput("H0")) {
    ctx->SetOutputDim(framework::GradVarName("H0"), ctx->GetInputDim("H0"));
  }
  if (ctx->HasInput("C0")) {
    ctx->SetOutputDim(framework::GradVarName("C0"), ctx->GetInputDim("C0"));
  }

  PADDLE_ENFORCE(ctx->HasInput("Input"),
                 "Input(Input) of LSTM should not be null.");
  PADDLE_ENFORCE(ctx->HasInput("Hidden"),
                 "Input(Hidden) of LSTM should not be null.");
  PADDLE_ENFORCE(ctx->HasInput("Cell"),
                 "Input(Cell) of LSTM should not be null.");

  PADDLE_ENFORCE(ctx->HasInput("BatchGate"),
                 "Input(BatchGate) of LSTM should not be null.");
  // The message previously named BatchGate here, which misreported which
  // input was actually missing.
  PADDLE_ENFORCE(ctx->HasInput("BatchCellPreAct"),
                 "Input(BatchCellPreAct) of LSTM should not be null.");

  // Each gradient below is shaped only if that output is present.
  auto in_g_name = framework::GradVarName("Input");
  if (ctx->HasOutput(in_g_name))
    ctx->SetOutputDim(in_g_name, ctx->GetInputDim("Input"));

  auto w_g_name = framework::GradVarName("Weight");
  if (ctx->HasOutput(w_g_name))
    ctx->SetOutputDim(w_g_name, ctx->GetInputDim("Weight"));

  auto b_g_name = framework::GradVarName("Bias");
  if (ctx->HasOutput(b_g_name))
    ctx->SetOutputDim(b_g_name, ctx->GetInputDim("Bias"));
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|