|
|
|
@ -21,7 +21,6 @@ class LSTMOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Input"),
|
|
|
|
|
"Input(Input) of LSTM should not be null.");
|
|
|
|
@ -30,8 +29,8 @@ class LSTMOp : public framework::OperatorWithKernel {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Cell"),
|
|
|
|
|
"Output(Cell) of LSTM should not be null.");
|
|
|
|
|
|
|
|
|
|
auto x_dims = ctx->GetInputDim("Input");
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
|
|
|
|
|
auto in_dims = ctx->GetInputDim("Input");
|
|
|
|
|
PADDLE_ENFORCE_EQ(in_dims.size(), 2, "Input(X)'s rank must be 2.");
|
|
|
|
|
|
|
|
|
|
if (ctx->HasInput("H0")) {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("C0"),
|
|
|
|
@ -44,7 +43,7 @@ class LSTMOp : public framework::OperatorWithKernel {
|
|
|
|
|
"should be the same.");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int frame_size = x_dims[1] / 4;
|
|
|
|
|
int frame_size = in_dims[1] / 4;
|
|
|
|
|
auto w_dims = ctx->GetInputDim("Weight");
|
|
|
|
|
PADDLE_ENFORCE_EQ(w_dims.size(), 2,
|
|
|
|
|
"The rank of Input(Weight) should be 2.");
|
|
|
|
@ -71,9 +70,11 @@ class LSTMOp : public framework::OperatorWithKernel {
|
|
|
|
|
"4 * %d if disable peepholes connection",
|
|
|
|
|
frame_size);
|
|
|
|
|
}
|
|
|
|
|
ctx->SetOutputDim("Hidden", {x_dims[0], frame_size});
|
|
|
|
|
ctx->SetOutputDim("Cell", {x_dims[0], frame_size});
|
|
|
|
|
ctx->SetOutputDim("BatchGate", x_dims);
|
|
|
|
|
framework::DDim out_dims({in_dims[0], frame_size});
|
|
|
|
|
ctx->SetOutputDim("Hidden", out_dims);
|
|
|
|
|
ctx->SetOutputDim("Cell", out_dims);
|
|
|
|
|
ctx->SetOutputDim("BatchGate", in_dims);
|
|
|
|
|
ctx->SetOutputDim("BatchCellPreAct", out_dims);
|
|
|
|
|
ctx->ShareLoD("Input", "Hidden");
|
|
|
|
|
ctx->ShareLoD("Input", "Cell");
|
|
|
|
|
}
|
|
|
|
@ -86,7 +87,7 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
AddInput("Input",
|
|
|
|
|
"(LoDTensor) the first input is a LodTensor, which support "
|
|
|
|
|
"variable-time length input sequence. The underlying tensor in "
|
|
|
|
|
"this LoDTensor is a matrix with shape (T X 4D), where, T is the "
|
|
|
|
|
"this LoDTensor is a matrix with shape (T X 4D), where T is the "
|
|
|
|
|
"total time steps in this mini-batch, D is the hidden size.");
|
|
|
|
|
AddInput("H0",
|
|
|
|
|
"(Tensor, optional) the initial hidden state is an optional "
|
|
|
|
@ -110,21 +111,25 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
"2. `usePeepholes = True` "
|
|
|
|
|
" - The shape is (1 x 7D). "
|
|
|
|
|
" - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}.");
|
|
|
|
|
AddOutput("Hidden",
|
|
|
|
|
"(LoDTensor) the hidden state lod tensor of LSTM operator. "
|
|
|
|
|
"The shape and lod is the same with the `Input`.");
|
|
|
|
|
AddOutput("Cell",
|
|
|
|
|
"(LoDTensor) the cell state lod tensor of LSTM operator. "
|
|
|
|
|
"The shape and lod is the same with the `Input`.");
|
|
|
|
|
AddOutput("BatchGate",
|
|
|
|
|
"(LoDTensor) This LoDTensor contains input gate, forget gate "
|
|
|
|
|
"and output gate after the nonlinear computation. This "
|
|
|
|
|
"LoDTensor has the same shape with the reorganized input, which "
|
|
|
|
|
"was also be called batch input. The LoD size is 2. The first "
|
|
|
|
|
"is also be called batch input. The LoD size is 2. The first "
|
|
|
|
|
"LoD is the batch offsets and the second LoD contains the "
|
|
|
|
|
"indexes, which denote the position of reorganized sequence "
|
|
|
|
|
"in the raw input.")
|
|
|
|
|
.AsIntermediate();
|
|
|
|
|
AddOutput("Hidden",
|
|
|
|
|
"(LoDTensor) the hidden state lod tensor of LSTM operator. "
|
|
|
|
|
"The shape and lod is the same with the `Input`.");
|
|
|
|
|
AddOutput("Cell",
|
|
|
|
|
"(LoDTensor) the cell state lod tensor of LSTM operator. "
|
|
|
|
|
"The shape and lod is the same with the `Input`.");
|
|
|
|
|
AddOutput("BatchCellPreAct",
|
|
|
|
|
"(LoDTensor) This LoDTensor is get in the forward and used "
|
|
|
|
|
"in the backward.")
|
|
|
|
|
.AsIntermediate();
|
|
|
|
|
AddAttr<bool>("usePeepholes",
|
|
|
|
|
"(bool, defalut: True) "
|
|
|
|
|
"whether to enable diagonal/peephole connections.")
|
|
|
|
@ -202,15 +207,28 @@ class LSTMGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
// Shape inference for the LSTM gradient operator.
// Checks that the gradients of the forward outputs Hidden and Cell are
// supplied as inputs, then gives each gradient output the same dims as the
// forward input it belongs to (Input, Weight, Bias, H0, C0).
//
// NOTE(review): the Weight/Bias gradient dims are set twice below -- once
// unconditionally and once behind HasInput() guards. This text looks like a
// diff rendering that kept both the pre-change and post-change lines;
// presumably only the guarded form is current -- confirm against the real file.
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
// The gradient w.r.t. the forward output Hidden must be provided.
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Hidden")),
|
|
|
|
|
"Input(Hidden@GRAD) should not be null");
|
|
|
|
|
// The gradient w.r.t. the forward output Cell must be provided.
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Cell")),
|
|
|
|
|
"Input(Cell@GRAD) should not be null");
|
|
|
|
|
// Unguarded variant: Weight gradient takes the dims of Weight.
// NOTE(review): duplicated by the guarded block further down.
ctx->SetOutputDim(framework::GradVarName("Weight"),
|
|
|
|
|
ctx->GetInputDim("Weight"));
|
|
|
|
|
// Unguarded variant: Bias gradient takes the dims of Bias.
// NOTE(review): duplicated by the guarded block further down.
ctx->SetOutputDim(framework::GradVarName("Bias"), ctx->GetInputDim("Bias"));
|
|
|
|
|
|
|
|
|
|
// Input gradient always takes the dims of Input.
ctx->SetOutputDim(framework::GradVarName("Input"),
|
|
|
|
|
ctx->GetInputDim("Input"));
|
|
|
|
|
// Guarded variant: only set the Weight gradient dims when Weight is present.
if (ctx->HasInput("Weight")) {
|
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("Weight"),
|
|
|
|
|
ctx->GetInputDim("Weight"));
|
|
|
|
|
}
|
|
|
|
|
// Guarded variant: only set the Bias gradient dims when Bias is present.
if (ctx->HasInput("Bias")) {
|
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("Bias"),
|
|
|
|
|
ctx->GetInputDim("Bias"));
|
|
|
|
|
}
|
|
|
|
|
// H0 is checked before use; its gradient mirrors the dims of H0.
if (ctx->HasInput("H0")) {
|
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("H0"), ctx->GetInputDim("H0"));
|
|
|
|
|
}
|
|
|
|
|
// C0 is checked before use; its gradient mirrors the dims of C0.
if (ctx->HasInput("C0")) {
|
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("C0"), ctx->GetInputDim("C0"));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|