@ -21,7 +21,6 @@ class LSTMOp : public framework::OperatorWithKernel {
public :
using framework : : OperatorWithKernel : : OperatorWithKernel ;
protected :
void InferShape ( framework : : InferShapeContext * ctx ) const override {
PADDLE_ENFORCE ( ctx - > HasInput ( " Input " ) ,
" Input(Input) of LSTM should not be null. " ) ;
@ -29,9 +28,13 @@ class LSTMOp : public framework::OperatorWithKernel {
" Output(Hidden) of LSTM should not be null. " ) ;
PADDLE_ENFORCE ( ctx - > HasOutput ( " Cell " ) ,
" Output(Cell) of LSTM should not be null. " ) ;
PADDLE_ENFORCE ( ctx - > HasOutput ( " BatchGate " ) ,
" Output(BatchGate) of LSTM should not be null. " ) ;
PADDLE_ENFORCE ( ctx - > HasOutput ( " BatchCellPreAct " ) ,
" Output(BatchGate) of LSTM should not be null. " ) ;
auto x_dims = ctx - > GetInputDim ( " Input " ) ;
PADDLE_ENFORCE_EQ ( x_dims . size ( ) , 2 , " Input(X)'s rank must be 2. " ) ;
auto in _dims = ctx - > GetInputDim ( " Input " ) ;
PADDLE_ENFORCE_EQ ( in _dims. size ( ) , 2 , " Input(X)'s rank must be 2. " ) ;
if ( ctx - > HasInput ( " H0 " ) ) {
PADDLE_ENFORCE ( ctx - > HasInput ( " C0 " ) ,
@ -44,7 +47,7 @@ class LSTMOp : public framework::OperatorWithKernel {
" should be the same. " ) ;
}
int frame_size = x _dims[ 1 ] / 4 ;
int frame_size = in _dims[ 1 ] / 4 ;
auto w_dims = ctx - > GetInputDim ( " Weight " ) ;
PADDLE_ENFORCE_EQ ( w_dims . size ( ) , 2 ,
" The rank of Input(Weight) should be 2. " ) ;
@ -71,12 +74,21 @@ class LSTMOp : public framework::OperatorWithKernel {
" 4 * %d if disable peepholes connection " ,
frame_size ) ;
}
ctx - > SetOutputDim ( " Hidden " , { x_dims [ 0 ] , frame_size } ) ;
ctx - > SetOutputDim ( " Cell " , { x_dims [ 0 ] , frame_size } ) ;
ctx - > SetOutputDim ( " BatchGate " , x_dims ) ;
framework : : DDim out_dims ( { in_dims [ 0 ] , frame_size } ) ;
ctx - > SetOutputDim ( " Hidden " , out_dims ) ;
ctx - > SetOutputDim ( " Cell " , out_dims ) ;
ctx - > SetOutputDim ( " BatchGate " , in_dims ) ;
ctx - > SetOutputDim ( " BatchCellPreAct " , out_dims ) ;
ctx - > ShareLoD ( " Input " , " Hidden " ) ;
ctx - > ShareLoD ( " Input " , " Cell " ) ;
}
protected :
framework : : DataType IndicateDataType (
const framework : : ExecutionContext & ctx ) const override {
return framework : : ToDataType (
ctx . Input < framework : : LoDTensor > ( " Input " ) - > type ( ) ) ;
}
} ;
// Proto maker for the LSTM op: declares its inputs, outputs and attributes.
// NOTE(review): this region is a garbled unified diff — several hunks kept
// BOTH their removed ("-") and added ("+") lines as adjacent near-duplicates.
// The later line of each annotated pair is the intended ("+") text. Code is
// left byte-identical here; only comments are added. The maker's constructor
// continues beyond this chunk (remaining attributes, AddComment, class close
// are not visible).
class LSTMOpMaker : public framework : : OpProtoAndCheckerMaker {
@ -86,16 +98,18 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput ( " Input " ,
" (LoDTensor) the first input is a LodTensor, which support "
" variable-time length input sequence. The underlying tensor in "
// Duplicate pair: next line is the stale "-" side (stray comma after
// "where"); the one after it is the corrected "+" side.
" this LoDTensor is a matrix with shape (T X 4D), where , T is the "
" this LoDTensor is a matrix with shape (T X 4D), where T is the "
" total time steps in this mini-batch, D is the hidden size. " ) ;
AddInput ( " H0 " ,
" (Tensor, optional) the initial hidden state is an optional "
" input. This is a tensor with shape (N x D), where N is the "
// Duplicate pair: old call closed with ");"; new call marks H0 optional
// via .AsDispensable().
" batch size, D is the hidden size. " ) ;
" batch size, D is the hidden size. " )
. AsDispensable ( ) ;
AddInput ( " C0 " ,
" (Tensor, optional) the initial cell state is an optional "
" input. This is a tensor with shape (N x D), where N is the "
// Duplicate pair: as above, the "+" side adds .AsDispensable().
" batch size. `H0` and `C0` can be NULL but only at the same time " ) ;
" batch size. `H0` and `C0` can be NULL but only at the same time " )
. AsDispensable ( ) ;
AddInput ( " Weight " ,
" (Tensor) the learnable hidden-hidden weights. "
" - The shape is (D x 4D), where D is the hidden size. "
@ -109,22 +123,27 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
" - Bias = {b_c, b_i, b_f, b_o}. "
" 2. `usePeepholes = True` "
" - The shape is (1 x 7D). "
// Duplicate pair: old Bias call closed with ");"; the "+" side adds
// .AsDispensable().
" - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}. " ) ;
" - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}. " )
. AsDispensable ( ) ;
// Stale "-" side: Hidden/Cell are re-declared further below with updated
// descriptions; this first pair is the pre-patch text.
AddOutput ( " Hidden " ,
" (LoDTensor) the hidden state of LSTM operator. "
" The shape is (T x D), and lod is the same with the `Input`. " ) ;
AddOutput ( " Cell " ,
" (LoDTensor) the cell state of LSTM operator. "
" The shape is (T x D), and lod is the same with the `Input`. " ) ;
AddOutput ( " BatchGate " ,
" (LoDTensor) This LoDTensor contains input gate, forget gate "
" and output gate after the nonlinear computation. This "
" LoDTensor has the same shape with the reorganized input, which "
// Duplicate pair: "was also be called" (old) vs garbled "i s also be
// called" (new); both phrasings are as in the original strings.
" was also be called batch input. The LoD size is 2. The first "
" i s also be called batch input. The LoD size is 2. The first "
" LoD is the batch offsets and the second LoD contains the "
" indexes, which denote the position of reorganized sequence "
" in the raw input. " )
. AsIntermediate ( ) ;
// "+" side replacements for the Hidden/Cell declarations above.
AddOutput ( " Hidden " ,
" (LoDTensor) the hidden state lod tensor of LSTM operator. "
" The shape and lod is the same with the `Input`. " ) ;
AddOutput ( " Cell " ,
" (LoDTensor) the cell state lod tensor of LSTM operator. "
" The shape and lod is the same with the `Input`. " ) ;
AddOutput ( " BatchCellPreAct " ,
" (LoDTensor) This LoDTensor is got in the forward and used "
" in the backward. " )
. AsIntermediate ( ) ;
// TODO(review): "defalut" below is a typo for "default" inside the attr
// description string — left unchanged in this comments-only pass.
AddAttr < bool > ( " usePeepholes " ,
" (bool, defalut: True) "
" whether to enable diagonal/peephole connections. " )
@ -202,15 +221,37 @@ class LSTMGradOp : public framework::OperatorWithKernel {
public :
using framework : : OperatorWithKernel : : OperatorWithKernel ;
protected :
void InferShape ( framework : : InferShapeContext * ctx ) const override {
PADDLE_ENFORCE ( ctx - > HasInput ( framework : : GradVarName ( " Hidden " ) ) ,
" Input(Hidden@GRAD) should not be null " ) ;
PADDLE_ENFORCE ( ctx - > HasInput ( framework : : GradVarName ( " Cell " ) ) ,
" Input(Cell@GRAD) should not be null " ) ;
ctx - > SetOutputDim ( framework : : GradVarName ( " Weight " ) ,
ctx - > GetInputDim ( " Weight " ) ) ;
ctx - > SetOutputDim ( framework : : GradVarName ( " Bias " ) , ctx - > GetInputDim ( " Bias " ) ) ;
PADDLE_ENFORCE ( ctx - > HasInput ( " Input " ) ,
" Input(Input) of LSTM should not be null. " ) ;
PADDLE_ENFORCE ( ctx - > HasInput ( " Hidden " ) ,
" Input(Hidden) of LSTM should not be null. " ) ;
PADDLE_ENFORCE ( ctx - > HasInput ( " Cell " ) ,
" Input(Cell) of LSTM should not be null. " ) ;
PADDLE_ENFORCE ( ctx - > HasInput ( " BatchGate " ) ,
" Input(BatchGate) of LSTM should not be null. " ) ;
PADDLE_ENFORCE ( ctx - > HasInput ( " BatchCellPreAct " ) ,
" Input(BatchGate) of LSTM should not be null. " ) ;
auto in_g_name = framework : : GradVarName ( " Input " ) ;
if ( ctx - > HasOutput ( in_g_name ) )
ctx - > SetOutputDim ( in_g_name , ctx - > GetInputDim ( " Input " ) ) ;
auto w_g_name = framework : : GradVarName ( " Weight " ) ;
if ( ctx - > HasOutput ( w_g_name ) )
ctx - > SetOutputDim ( w_g_name , ctx - > GetInputDim ( " Weight " ) ) ;
auto b_g_name = framework : : GradVarName ( " Bias " ) ;
if ( ctx - > HasOutput ( b_g_name ) )
ctx - > SetOutputDim ( b_g_name , ctx - > GetInputDim ( " Bias " ) ) ;
}
protected :
framework : : DataType IndicateDataType (
const framework : : ExecutionContext & ctx ) const override {
return framework : : ToDataType (
ctx . Input < framework : : LoDTensor > ( " Input " ) - > type ( ) ) ;
}
} ;