@@ -24,34 +24,62 @@ class CudnnLSTMOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("Input"),
-                   "Input(Input) of LSTM should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("W"),
-                   "Input(Weight) of LSTM should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("InitH"),
-                   "Input(init_h) of LSTM should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("InitC"),
-                   "Input(init_c) of LSTM should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Cache"),
-                   "Input(Cache) of LSTM should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("Out"),
-                   "Output(Out) of LSTM should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("last_h"),
-                   "Output(last_h) of LSTM should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("last_c"),
-                   "Output(last_c) of LSTM should not be null.");
+    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "CudnnLSTM");
+    OP_INOUT_CHECK(ctx->HasInput("W"), "Input", "W", "CudnnLSTM");
+    OP_INOUT_CHECK(ctx->HasInput("InitH"), "Input", "InitH", "CudnnLSTM");
+    OP_INOUT_CHECK(ctx->HasInput("InitC"), "Input", "InitC", "CudnnLSTM");
+
+    OP_INOUT_CHECK(ctx->HasOutput("Reserve"), "Output", "Reserve", "CudnnLSTM");
+    OP_INOUT_CHECK(ctx->HasOutput("StateOut"), "Output", "StateOut",
+                   "CudnnLSTM");
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "CudnnLSTM");
+    OP_INOUT_CHECK(ctx->HasOutput("LastH"), "Output", "LastH", "CudnnLSTM");
+    OP_INOUT_CHECK(ctx->HasOutput("LastC"), "Output", "LastC", "CudnnLSTM");
 
     auto in_dims = ctx->GetInputDim("Input");
-    PADDLE_ENFORCE_EQ(in_dims.size(), 3, "Input(X)'s rank must be 3.");
+    auto init_dims = ctx->GetInputDim("InitH");
+    PADDLE_ENFORCE_EQ(in_dims.size(), 3,
+                      platform::errors::InvalidArgument(
+                          "The rank of Input in CudnnLSTM must be 3. But "
+                          "received Input's rank is %d.",
+                          in_dims.size()));
+    PADDLE_ENFORCE_EQ(init_dims.size(), 3,
+                      platform::errors::InvalidArgument(
+                          "The rank of InitH in CudnnLSTM must be 3. But "
+                          "received InitH's rank is %d.",
+                          init_dims.size()));
+
+    PADDLE_ENFORCE_EQ(in_dims[1], init_dims[1],
+                      platform::errors::InvalidArgument(
+                          "The in_dims[1] (Input dims) and init_dims[1] (InitH "
+                          "dims) should be equal. But "
+                          "received in_dims[1] is %d and init_dims[1] is %d.",
+                          in_dims[1], init_dims[1]));
+    PADDLE_ENFORCE_EQ(in_dims[2], init_dims[2],
+                      platform::errors::InvalidArgument(
+                          "The in_dims[2] (Input dims) and init_dims[2] (InitH "
+                          "dims) should be equal. But "
+                          "received in_dims[2] is %d and init_dims[2] is %d.",
+                          in_dims[2], init_dims[2]));
 
     auto out_dims = in_dims;
     auto hidden_size = ctx->Attrs().Get<int>("hidden_size");
-    out_dims[2] = hidden_size;
+    bool is_bidirec = ctx->Attrs().Get<bool>("is_bidirec");
+    out_dims[2] = is_bidirec ? hidden_size * 2 : hidden_size;
+
+    auto last_dims = init_dims;
+    last_dims[0] = is_bidirec ? last_dims[0] * 2 : last_dims[0];
     ctx->SetOutputDim("Out", out_dims);
-    ctx->SetOutputDim("last_h", ctx->GetInputDim("InitH"));
-    ctx->SetOutputDim("last_c", ctx->GetInputDim("InitC"));
+    ctx->SetOutputDim("LastH", last_dims);
+    ctx->SetOutputDim("LastC", last_dims);
   }
 
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
         OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
         ctx.device_context());
   }
 };
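
The rewritten InferShape above is the whole shape contract: Out keeps the input's seq_len and batch_size but widens its feature dimension when is_bidirec is set, while LastH and LastC double their layer dimension instead. A minimal standalone sketch of that arithmetic, with plain std::array standing in for Paddle's DDim and illustrative dimension values:

    #include <array>
    #include <cstdio>

    int main() {
      // Input: {seq_len, batch_size, input_size}; InitH: {num_layers, batch_size, hidden_size}.
      std::array<int, 3> in_dims = {20, 4, 64};
      std::array<int, 3> init_dims = {2, 4, 64};
      const int hidden_size = 64;
      const bool is_bidirec = true;

      // Out widens the feature dimension for a bidirectional net.
      std::array<int, 3> out_dims = in_dims;
      out_dims[2] = is_bidirec ? hidden_size * 2 : hidden_size;

      // LastH/LastC double the layer dimension instead.
      std::array<int, 3> last_dims = init_dims;
      last_dims[0] = is_bidirec ? last_dims[0] * 2 : last_dims[0];

      std::printf("Out={%d,%d,%d}  Last={%d,%d,%d}\n", out_dims[0], out_dims[1],
                  out_dims[2], last_dims[0], last_dims[1], last_dims[2]);
      return 0;
    }
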
@@ -84,33 +112,31 @@ class CudnnLSTMOpMaker : public framework::OpProtoAndCheckerMaker {
              "(Tensor) the learnable hidden-hidden weights."
              " The shape is (N), where N is total weight size of the LSTM. "
              " cudnn concatenate all the weight to one Tensor");
-    AddInput("Cache",
-             "The cache of dropout op, a RAW type variable including random "
-             "number generator states and some descriptors, which is used in "
-             "cudnn kernel.")
-        .AsDispensable();
+    AddOutput("Reserve",
+              "(Tensor) a temporary output Tensor to store the reserve_data "
+              "of cudnn kernel.")
+        .AsIntermediate();
+    AddOutput("StateOut",
+              "Share memory with State. "
+              "Store the global dropout state when training");
     AddOutput("Out",
               "(Tensor) the hidden state of LSTM operator. "
               "The shape is ( seq_len x batch_size x hidden_size) if "
               "is_bidirec is False"
               "and When is_bidirec is True, the shape will be ( seq_len x "
               "batch_size x hidden_size * 2) ");
-    AddOutput("last_h",
+    AddOutput("LastH",
               "(Tensor) the hidden state of the last step. "
               "The shape is ( num_layers x batch_size x hidden_size) if "
               "is_bidirec is False"
               "and When is_bidirec is True, the shape will be (num_layers*2 x "
               "batch_size x hidden_size)");
-    AddOutput("last_c",
+    AddOutput("LastC",
               "(Tensor) the cell state of the last step"
               "The shape is ( num_layers x batch_size x hidden_size) if "
               "is_bidirec is False"
               "and When is_bidirect is True, the shape will be (num_layers*2 x "
               "batch_size x hidden_size*2)");
-    AddAttr<int>("max_len",
-                 "max length of the LSTM op"
-                 "the first dim of the Input can NOT be greater than max_len")
-        .SetDefault(20);
     AddAttr<float>(
         "dropout_prob",
         "dropout prob of the dropout op"
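
W is documented as a single flat tensor of size N because cudnn packs every layer's gate weights and biases into one buffer. The sketch below is a back-of-envelope for N under cuDNN's canonical LSTM layout (four gates, separate ih/hh weights and biases); it illustrates the layout convention and is not code taken from this operator:

    #include <cstdio>

    int main() {
      const int input_size = 64, hidden_size = 64, num_layers = 2;
      const bool is_bidirec = false;
      const int dirs = is_bidirec ? 2 : 1;

      long long n = 0;
      for (int layer = 0; layer < num_layers; ++layer) {
        // Layer 0 sees the raw input; deeper layers see the (possibly
        // concatenated) output of the previous layer.
        const int in = (layer == 0) ? input_size : hidden_size * dirs;
        n += dirs * (4LL * hidden_size * (in + hidden_size)  // W_ih, W_hh for 4 gates
                     + 8LL * hidden_size);                   // b_ih, b_hh for 4 gates
      }
      std::printf("N = %lld\n", n);  // flat weight size cudnn expects
      return 0;
    }
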
@@ -120,14 +146,14 @@ class CudnnLSTMOpMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr<bool>("is_bidirec",
                   "is_bidirec"
                   "if it is bidirectional rnn"
-                  "The will affect the shape of the Out, last_h, and last_c")
+                  "This will affect the shape of the Out, LastH, and LastC")
         .SetDefault(false);
     AddAttr<int>("input_size", "input size ot the Input Tensor").SetDefault(10);
     AddAttr<int>("hidden_size", "hidden size of the LSTM").SetDefault(100);
     AddAttr<int>("num_layers", "the total layer number of the LSTM")
         .SetDefault(1);
     AddAttr<bool>("is_test", "True if in test phase.").SetDefault(false);
-    AddAttr<int>("seed", "seed to used if fix_seed is True").SetDefault(-1);
+    AddAttr<int>("seed", "seed to be used if fix_seed is True").SetDefault(0);
     AddComment(R"DOC(
 CUDNN LSTM implementation
 
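
Note the seed default moving from -1 to 0. A common convention, assumed here rather than stated anywhere in this file, is that a zero seed means "draw one from a nondeterministic source", so dropout varies across runs unless the caller pins the seed:

    #include <random>

    // Hedged sketch of that convention; not this operator's actual seeding code.
    std::mt19937 MakeEngine(int seed) {
      if (seed == 0) {
        std::random_device rd;  // nondeterministic seed source
        seed = static_cast<int>(rd());
      }
      return std::mt19937(static_cast<unsigned>(seed));
    }
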
@@ -172,16 +198,10 @@ class CudnnLSTMGradOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("Input"),
-                   "Input(Input) of LSTM should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) of LSTM should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Cache"),
-                   "Input(last_c) of LSTM should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("InitH"),
-                   "Input(init_h) of LSTM should not be null.");
-
-    PADDLE_ENFORCE(ctx->HasInput("InitC"),
-                   "Input(init_c) of LSTM should not be null.");
+    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "CudnnLSTMGrad");
+    OP_INOUT_CHECK(ctx->HasInput("W"), "Input", "W", "CudnnLSTMGrad");
+    OP_INOUT_CHECK(ctx->HasInput("InitH"), "Input", "InitH", "CudnnLSTMGrad");
+    OP_INOUT_CHECK(ctx->HasInput("InitC"), "Input", "InitC", "CudnnLSTMGrad");
 
     auto SetOutGradDim = [&ctx](const std::string& name) {
       auto g_name = framework::GradVarName(name);
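
The SetOutGradDim lambda that this hunk ends on is the entire backward shape rule: each gradient mirrors the dims of its forward tensor. A toy version with a std::map in place of InferShapeContext; the "@GRAD" suffix mimics what framework::GradVarName produces, and all names and shapes are illustrative:

    #include <map>
    #include <string>
    #include <vector>

    int main() {
      std::map<std::string, std::vector<int>> dims = {{"Input", {20, 4, 64}},
                                                      {"W", {66560}},
                                                      {"InitH", {2, 4, 64}},
                                                      {"InitC", {2, 4, 64}}};
      // Each gradient gets the shape of its forward-pass tensor.
      auto SetOutGradDim = [&dims](const std::string& name) {
        dims[name + "@GRAD"] = dims[name];
      };
      SetOutGradDim("Input");
      SetOutGradDim("W");
      SetOutGradDim("InitH");
      SetOutGradDim("InitC");
      return 0;
    }
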
@@ -195,6 +215,12 @@ class CudnnLSTMGradOp : public framework::OperatorWithKernel {
     SetOutGradDim("InitH");
     SetOutGradDim("InitC");
   }
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   ctx.device_context());
+  }
 };
 
 template <typename T>
@@ -209,13 +235,12 @@ class CudnnLSTMGradOpMaker : public framework::SingleGradOpMaker<T> {
     op->SetInput("InitH", this->Input("InitH"));
     op->SetInput("InitC", this->Input("InitC"));
     op->SetInput("W", this->Input("W"));
-    if (this->HasInput("Cache")) {
-      op->SetInput("Cache", this->Input("Cache"));
-    }
+    op->SetInput("Reserve", this->Output("Reserve"));
+    op->SetInput("StateOut", this->Output("StateOut"));
     op->SetInput("Out", this->Output("Out"));
     op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
-    op->SetInput(framework::GradVarName("last_c"), this->OutputGrad("last_c"));
-    op->SetInput(framework::GradVarName("last_h"), this->OutputGrad("last_h"));
+    op->SetInput(framework::GradVarName("LastC"), this->OutputGrad("LastC"));
+    op->SetInput(framework::GradVarName("LastH"), this->OutputGrad("LastH"));
 
     op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
     op->SetOutput(framework::GradVarName("W"), this->InputGrad("W"));
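
Together with the grad op's GetExpectedKernelType above, which keys the kernel dtype on the Out gradient rather than on Input, this maker wires the backward op: it consumes the forward op's saved Reserve and StateOut buffers plus the incoming output gradients, and emits gradients for Input, W, InitH, and InitC. A toy rendering of that wiring with plain structs, not the real SingleGradOpMaker API:

    #include <map>
    #include <string>

    struct ToyOpDesc {
      std::map<std::string, std::string> inputs, outputs;
    };

    int main() {
      ToyOpDesc grad;
      // Saved forward-pass state the cudnn backward call needs.
      grad.inputs["Reserve"] = "fwd.Reserve";
      grad.inputs["StateOut"] = "fwd.StateOut";
      // Incoming gradients of the forward outputs.
      grad.inputs["Out@GRAD"] = "dL/dOut";
      grad.inputs["LastH@GRAD"] = "dL/dLastH";
      grad.inputs["LastC@GRAD"] = "dL/dLastC";
      // Gradients the backward op produces.
      grad.outputs["Input@GRAD"] = "dL/dInput";
      grad.outputs["W@GRAD"] = "dL/dW";
      grad.outputs["InitH@GRAD"] = "dL/dInitH";
      grad.outputs["InitC@GRAD"] = "dL/dInitC";
      return 0;
    }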