Merge pull request #14622 from PaddlePaddle/add_cudnn_lstm

Add cudnn lstm

commit 4f71a6ee2c

@@ -0,0 +1,218 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <string>
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

class CudnnLSTMOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("Input"),
                   "Input(Input) of LSTM should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("W"),
                   "Input(W) of LSTM should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("InitH"),
                   "Input(init_h) of LSTM should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("InitC"),
                   "Input(init_c) of LSTM should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Cache"),
                   "Input(Cache) of LSTM should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of LSTM should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("last_h"),
                   "Output(last_h) of LSTM should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("last_c"),
                   "Output(last_c) of LSTM should not be null.");

    auto in_dims = ctx->GetInputDim("Input");
    PADDLE_ENFORCE_EQ(in_dims.size(), 3, "Input(Input)'s rank must be 3.");

    ctx->SetOutputDim("Out", ctx->GetInputDim("Input"));
    ctx->SetOutputDim("last_h", ctx->GetInputDim("InitH"));
    ctx->SetOutputDim("last_c", ctx->GetInputDim("InitC"));
  }
};

class CudnnLSTMOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput(
        "Input",
        "(Tensor) RNN input tensor, which supports variable-length input "
        "sequences. "
        "The shape of the Tensor MUST be (seq_len x batch_size x input_size); "
        "seq_len is the total number of time steps in this mini-batch (it CAN "
        "change across batches); "
        "batch_size is the number of instances in this batch; "
        "input_size is the hidden size of the input. "
        "input_size and the hidden_size of the next layer may differ.");
    AddInput("InitH",
             "(Tensor) the initial hidden state of the LSTM. "
             "This is a tensor with shape (num_layers x batch_size x "
             "hidden_size); "
             "when is_bidirec is True, the shape will be (num_layers*2 x "
             "batch_size x hidden_size).");
    AddInput("InitC",
             "(Tensor) the initial cell state of the LSTM. "
             "This is a tensor with shape (num_layers x batch_size x "
             "hidden_size); "
             "when is_bidirec is True, the shape will be (num_layers*2 x "
             "batch_size x hidden_size).");
    AddInput("W",
             "(Tensor) the learnable weights of the LSTM. "
             "The shape is (N), where N is the total weight size of the LSTM; "
             "cuDNN concatenates all the weights into one flat Tensor.");
    AddInput("Cache",
             "The cache of the dropout op, a RAW type variable that holds "
             "random number generator states and some descriptors used by "
             "the cuDNN kernel.")
        .AsDispensable();
    AddOutput("Out",
              "(Tensor) the hidden states of the LSTM operator. "
              "The shape is (seq_len x batch_size x hidden_size) if "
              "is_bidirec is False; "
              "when is_bidirec is True, the shape will be (seq_len x "
              "batch_size x hidden_size*2).");
    AddOutput("last_h",
              "(Tensor) the hidden state of the last step. "
              "The shape is (num_layers x batch_size x hidden_size) if "
              "is_bidirec is False; "
              "when is_bidirec is True, the shape will be (num_layers*2 x "
              "batch_size x hidden_size).");
    AddOutput("last_c",
              "(Tensor) the cell state of the last step. "
              "The shape is (num_layers x batch_size x hidden_size) if "
              "is_bidirec is False; "
              "when is_bidirec is True, the shape will be (num_layers*2 x "
              "batch_size x hidden_size).");
    AddAttr<int>("max_len",
                 "max length of the LSTM op; "
                 "the first dim of the Input can NOT be greater than max_len")
        .SetDefault(20);
    AddAttr<float>(
        "dropout_prob",
        "dropout probability of the dropout op; "
        "dropout ONLY works between LSTM layers, not between time steps; "
        "no dropout is applied to the Out tensor")
        .SetDefault(0.0);
    AddAttr<bool>("is_bidirec",
                  "whether the LSTM is bidirectional; "
                  "this affects the shapes of Out, last_h, and last_c")
        .SetDefault(false);
    AddAttr<int>("input_size", "input size of the Input Tensor").SetDefault(10);
    AddAttr<int>("hidden_size", "hidden size of the LSTM").SetDefault(100);
    AddAttr<int>("num_layers", "the total layer number of the LSTM")
        .SetDefault(1);
    AddAttr<bool>("is_test", "True if in test phase.").SetDefault(false);
    AddAttr<int>("seed", "seed used if fix_seed is True").SetDefault(-1);
    AddComment(R"DOC(
CUDNN LSTM implementation

A four-gate Long Short-Term Memory network with no peephole connections.
In the forward pass the output $h_t$ and cell output $c_t$ for a given iteration
can be computed from the recurrent input $h_{t-1}$, the cell input $c_{t-1}$ and
the previous layer input $x_t$, given matrices W, R and biases bW, bR, from the
following equations:

$$ i_t = sigmoid(W_{ix}x_{t} + W_{ih}h_{t-1} + bx_i + bh_i) $$

$$ f_t = sigmoid(W_{fx}x_{t} + W_{fh}h_{t-1} + bx_f + bh_f) $$

$$ o_t = sigmoid(W_{ox}x_{t} + W_{oh}h_{t-1} + bx_o + bh_o) $$

$$ \\tilde{c_t} = tanh(W_{cx}x_t + W_{ch}h_{t-1} + bx_c + bh_c) $$

$$ c_t = f_t \\odot c_{t-1} + i_t \\odot \\tilde{c_t} $$

$$ h_t = o_t \\odot tanh(c_t) $$

- W terms denote weight matrices (e.g. $W_{ix}$ is the matrix
  of weights from the input to the input gate).
- The b terms denote bias vectors ($bx_i$ and $bh_i$ are the input gate bias vectors).
- sigmoid is the logistic sigmoid function.
- $i, f, o$ and $c$ are the input gate, forget gate, output gate,
  and cell activation vectors, respectively, all of which have the same size as
  the cell output activation vector $h$.
- $\odot$ is the element-wise product of the vectors.
- `tanh` is the activation function.
- $\tilde{c_t}$ is the candidate cell state,
  which is computed from the current input and the previous hidden state.

Here sigmoid is the logistic function $sigmoid(x) = 1 / (1 + e^{-x})$, $\odot$
represents point-wise multiplication, and terms such as $W_{ix}x_t$ are matrix
multiplications.
)DOC");
  }
};

class CudnnLSTMGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("Input"),
                   "Input(Input) of LSTM should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) of LSTM should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("last_h"),
                   "Input(last_h) of LSTM should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("last_c"),
                   "Input(last_c) of LSTM should not be null.");

    PADDLE_ENFORCE(ctx->HasInput("Cache"),
                   "Input(Cache) of LSTM should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("InitH"),
                   "Input(init_h) of LSTM should not be null.");

    PADDLE_ENFORCE(ctx->HasInput("InitC"),
                   "Input(init_c) of LSTM should not be null.");

    auto SetOutGradDim = [&ctx](const std::string& name) {
      auto g_name = framework::GradVarName(name);
      if (ctx->HasOutput(g_name)) {
        ctx->SetOutputDim(g_name, ctx->GetInputDim(name));
      }
    };

    SetOutGradDim("Input");
    SetOutGradDim("W");
    SetOutGradDim("InitH");
    SetOutGradDim("InitC");
  }
};

template <typename T>
class NotImpleKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_THROW(
        "The CPU kernel is not supported for this op yet; it will be added "
        "in the future.");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(cudnn_lstm, ops::CudnnLSTMOp, ops::CudnnLSTMOpMaker,
                  paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(cudnn_lstm_grad, ops::CudnnLSTMGradOp);

REGISTER_OP_CPU_KERNEL(cudnn_lstm, ops::NotImpleKernel<float>);
REGISTER_OP_CPU_KERNEL(cudnn_lstm_grad, ops::NotImpleKernel<float>);
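The DOC comment above describes a single LSTM time step. As an illustration only (this is not the operator's cuDNN kernel; the per-gate weight dictionaries, names and sizes below are invented for the example), one step can be sketched in NumPy as:

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x_t, h_prev, c_prev, W, R, bW, bR):
    # One step of the four-gate LSTM from the equations above.
    # W holds the input weights and R the recurrent weights for the
    # i, f, c, o gates; bW and bR are the two bias sets.
    i_t = sigmoid(np.matmul(x_t, W["i"]) + np.matmul(h_prev, R["i"]) + bW["i"] + bR["i"])
    f_t = sigmoid(np.matmul(x_t, W["f"]) + np.matmul(h_prev, R["f"]) + bW["f"] + bR["f"])
    o_t = sigmoid(np.matmul(x_t, W["o"]) + np.matmul(h_prev, R["o"]) + bW["o"] + bR["o"])
    c_tilde = np.tanh(np.matmul(x_t, W["c"]) + np.matmul(h_prev, R["c"]) + bW["c"] + bR["c"])
    c_t = f_t * c_prev + i_t * c_tilde
    h_t = o_t * np.tanh(c_t)
    return h_t, c_t

# Tiny usage example with made-up sizes (batch_size=2, input_size=3, hidden_size=4).
rng = np.random.RandomState(0)
gates = ["i", "f", "c", "o"]
W = {g: rng.uniform(-0.1, 0.1, (3, 4)) for g in gates}
R = {g: rng.uniform(-0.1, 0.1, (4, 4)) for g in gates}
bW = {g: np.zeros(4) for g in gates}
bR = {g: np.zeros(4) for g in gates}
x_t = rng.uniform(-0.1, 0.1, (2, 3))
h, c = lstm_step(x_t, np.zeros((2, 4)), np.zeros((2, 4)), W, R, bW, bR)

The lstm_naive helper in the test file below unrolls the same step over a whole sequence, but reads the gate weights out of one flat weight vector instead of per-gate dictionaries.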
File diff suppressed because it is too large
@@ -0,0 +1,192 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np

import paddle.fluid.core as core
from op_test import OpTest
import paddle.fluid as fluid

SIGMOID_THRESHOLD_MIN = -40.0
SIGMOID_THRESHOLD_MAX = 13.0
EXP_MAX_INPUT = 40.0


def lstm_naive(input, w):
    seq_len, batch_size, hidden_size = input.shape

    # The flat weight vector w is unpacked as: four input-to-hidden matrices
    # (i, f, c, o), then four hidden-to-hidden matrices, then eight bias
    # vectors (two per gate).
    offset = 0
    wi = w[offset:offset + hidden_size * hidden_size].reshape(
        (hidden_size, hidden_size)).transpose()
    offset += hidden_size * hidden_size
    wf = w[offset:offset + hidden_size * hidden_size].reshape(
        (hidden_size, hidden_size)).transpose()
    offset += hidden_size * hidden_size
    wc = w[offset:offset + hidden_size * hidden_size].reshape(
        (hidden_size, hidden_size)).transpose()
    offset += hidden_size * hidden_size
    wo = w[offset:offset + hidden_size * hidden_size].reshape(
        (hidden_size, hidden_size)).transpose()
    offset += hidden_size * hidden_size
    ri = w[offset:offset + hidden_size * hidden_size].reshape(
        (hidden_size, hidden_size)).transpose()
    offset += hidden_size * hidden_size
    rf = w[offset:offset + hidden_size * hidden_size].reshape(
        (hidden_size, hidden_size)).transpose()
    offset += hidden_size * hidden_size
    rc = w[offset:offset + hidden_size * hidden_size].reshape(
        (hidden_size, hidden_size)).transpose()
    offset += hidden_size * hidden_size
    ro = w[offset:offset + hidden_size * hidden_size].reshape(
        (hidden_size, hidden_size)).transpose()
    offset += hidden_size * hidden_size

    bi_1 = w[offset:offset + hidden_size]
    offset += hidden_size
    bf_1 = w[offset:offset + hidden_size]
    offset += hidden_size
    bc_1 = w[offset:offset + hidden_size]
    offset += hidden_size
    bo_1 = w[offset:offset + hidden_size]
    offset += hidden_size

    bi_2 = w[offset:offset + hidden_size]
    offset += hidden_size
    bf_2 = w[offset:offset + hidden_size]
    offset += hidden_size
    bc_2 = w[offset:offset + hidden_size]
    offset += hidden_size
    bo_2 = w[offset:offset + hidden_size]

    def sigmoid(x):
        # Clip the input before exponentiating to avoid overflow.
        y = np.copy(x)
        y[x < SIGMOID_THRESHOLD_MIN] = SIGMOID_THRESHOLD_MIN
        y[x > SIGMOID_THRESHOLD_MAX] = SIGMOID_THRESHOLD_MAX
        return 1. / (1. + np.exp(-y))

    def tanh(x):
        y = -2. * x
        y[y > EXP_MAX_INPUT] = EXP_MAX_INPUT
        return (2. / (1. + np.exp(y))) - 1.

    output = []
    pre_h = np.zeros((batch_size, hidden_size), dtype=input.dtype)
    pre_c = np.zeros((batch_size, hidden_size), dtype=input.dtype)

    for i in range(seq_len):
        emb_1 = input[i]

        input_gate = sigmoid(
            np.matmul(emb_1, wi) + np.matmul(pre_h, ri) + bi_1 + bi_2)
        forget_gate = sigmoid(
            np.matmul(emb_1, wf) + np.matmul(pre_h, rf) + bf_1 + bf_2)
        output_gate = sigmoid(
            np.matmul(emb_1, wo) + np.matmul(pre_h, ro) + bo_1 + bo_2)
        c_t_temp = tanh(
            np.matmul(emb_1, wc) + np.matmul(pre_h, rc) + bc_1 + bc_2)
        new_c = input_gate * c_t_temp + forget_gate * pre_c
        new_h = output_gate * tanh(new_c)

        pre_h = new_h
        pre_c = new_c

        output.append(new_h)

    output = np.concatenate(output, -1)
    output = output.reshape((batch_size, -1, hidden_size))
    output = output.transpose((1, 0, 2))

    return output, pre_h, pre_c

class TestCUDNNLstmOp(OpTest):
    def setUp(self):
        self.op_type = "cudnn_lstm"
        self.dtype = np.float32

        num_steps = 20
        batch_size = 5
        hidden_size = 20

        input_weight_size = (hidden_size * hidden_size) * 4
        hidden_weight_size = (hidden_size * hidden_size) * 4
        weight_size = input_weight_size + hidden_weight_size
        weight_size += hidden_size * 8

        input = np.random.uniform(
            low=-0.1, high=0.1, size=(num_steps, batch_size,
                                      hidden_size)).astype(self.dtype)
        flat_w = np.random.uniform(
            low=-0.1, high=0.1, size=(weight_size)).astype(self.dtype)

        output, last_hidden, last_cell = lstm_naive(input, flat_w)

        init_h = np.zeros((batch_size, hidden_size), dtype=np.float32)
        init_c = np.zeros((batch_size, hidden_size), dtype=np.float32)
        scope = core.Scope()
        program = fluid.Program()
        block = program.global_block()

        cache_temp = block.create_var(
            name="Cache",
            persistable=True,
            type=core.VarDesc.VarType.RAW,
            stop_gradient=True)
        self.inputs = {
            'Input': OpTest.np_dtype_to_fluid_dtype(input),
            'W': OpTest.np_dtype_to_fluid_dtype(flat_w),
            'InitH': OpTest.np_dtype_to_fluid_dtype(init_h),
            'InitC': OpTest.np_dtype_to_fluid_dtype(init_c),
        }
        self.cache_name_list = ['Cache']
        self.attrs = {
            'max_len': num_steps,
            'dropout_prob': 0.0,
            'is_bidirec': False,
            'input_size': hidden_size,
            'hidden_size': hidden_size,
            'num_layers': 1,
        }
        self.outputs = {
            'Out': output,
            'last_h': last_hidden,
            'last_c': last_cell
        }

    def test_output_with_place(self):
        if self.testcuda():
            place = core.CUDAPlace(0)
            self.check_output_with_place(place, atol=1e-5)

    def test_grad_with_place(self):
        if core.is_compiled_with_cuda():
            place = core.CUDAPlace(0)
            self.check_grad_with_place(
                place,
                set(['Input', 'W', 'InitH', 'InitC']),
                ['Out', 'last_h', 'last_c'],
                max_relative_error=0.02)

    def testcuda(self):
        return core.is_compiled_with_cuda()


if __name__ == '__main__':
    unittest.main()
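As a quick sanity check on the flat-weight layout that lstm_naive unpacks (four input-to-hidden matrices, four hidden-to-hidden matrices, eight bias vectors), the expected size of W for the test configuration can be computed by hand; the snippet below is illustrative only and is not part of the test:

hidden_size = 20
# four input-to-hidden plus four hidden-to-hidden (hidden_size x hidden_size) matrices
matrix_params = 2 * 4 * hidden_size * hidden_size
# eight bias vectors of length hidden_size (two per gate)
bias_params = 8 * hidden_size
weight_size = matrix_params + bias_params
print(weight_size)  # 3360, matching weight_size computed in setUp for hidden_size = 20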