remove scope in cudnn lstm (#25188)

revert-24895-update_cub
GaoWei8 5 years ago committed by GitHub
parent da29760d58
commit 1fbee267d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -24,34 +24,62 @@ class CudnnLSTMOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(Input) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("W"),
"Input(Weight) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("InitH"),
"Input(init_h) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("InitC"),
"Input(init_c) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Cache"),
"Input(Cache) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("last_h"),
"Output(last_h) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("last_c"),
"Output(last_c) of LSTM should not be null.");
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "CudnnLSTM");
OP_INOUT_CHECK(ctx->HasInput("W"), "Input", "W", "CudnnLSTM");
OP_INOUT_CHECK(ctx->HasInput("InitH"), "Input", "InitH", "CudnnLSTM");
OP_INOUT_CHECK(ctx->HasInput("InitC"), "Input", "InitC", "CudnnLSTM");
OP_INOUT_CHECK(ctx->HasOutput("Reserve"), "Output", "Reserve", "CudnnLSTM");
OP_INOUT_CHECK(ctx->HasOutput("StateOut"), "Output", "StateOut",
"CudnnLSTM");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "CudnnLSTM");
OP_INOUT_CHECK(ctx->HasOutput("LastH"), "Output", "LastH", "CudnnLSTM");
OP_INOUT_CHECK(ctx->HasOutput("LastC"), "Output", "LastC", "CudnnLSTM");
auto in_dims = ctx->GetInputDim("Input");
PADDLE_ENFORCE_EQ(in_dims.size(), 3, "Input(X)'s rank must be 3.");
auto init_dims = ctx->GetInputDim("InitH");
PADDLE_ENFORCE_EQ(in_dims.size(), 3,
platform::errors::InvalidArgument(
"The rank of Input in CudnnLSTM must be 3. But "
"received Input's rank is %d.",
in_dims.size()));
PADDLE_ENFORCE_EQ(init_dims.size(), 3,
platform::errors::InvalidArgument(
"The rank of InitH in CudnnLSTM must be 3. But "
"received InitH's rank is %d.",
init_dims.size()));
PADDLE_ENFORCE_EQ(in_dims[1], init_dims[1],
platform::errors::InvalidArgument(
"The in_dims[1] (Input dims) and init_dims[1] (InitH "
"dims) should be equal. But "
"received in_dims[1] is %d and init_dims[1] is %d.",
in_dims[1], init_dims[1]));
PADDLE_ENFORCE_EQ(in_dims[2], init_dims[2],
platform::errors::InvalidArgument(
"The in_dims[2] (Input dims) and init_dims[2] (InitH "
"dims) should be equal. But "
"received in_dims[2] is %d and init_dims[2] is %d.",
in_dims[2], init_dims[2]));
auto out_dims = in_dims;
auto hidden_size = ctx->Attrs().Get<int>("hidden_size");
out_dims[2] = hidden_size;
bool is_bidirec = ctx->Attrs().Get<bool>("is_bidirec");
out_dims[2] = is_bidirec ? hidden_size * 2 : hidden_size;
auto last_dims = init_dims;
last_dims[0] = is_bidirec ? last_dims[0] * 2 : last_dims[0];
ctx->SetOutputDim("Out", out_dims);
ctx->SetOutputDim("last_h", ctx->GetInputDim("InitH"));
ctx->SetOutputDim("last_c", ctx->GetInputDim("InitC"));
ctx->SetOutputDim("LastH", last_dims);
ctx->SetOutputDim("LastC", last_dims);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.device_context());
}
};
@ -84,33 +112,31 @@ class CudnnLSTMOpMaker : public framework::OpProtoAndCheckerMaker {
"(Tensor) the learnable hidden-hidden weights."
" The shape is (N), where N is total weight size of the LSTM. "
" cudnn concatenate all the weight to one Tensor");
AddInput("Cache",
"The cache of dropout op, a RAW type variable including random "
"number generator states and some descriptors, which is used in "
"cudnn kernel.")
.AsDispensable();
AddOutput("Reserve",
"(Tensor, a temporary output Tensor to store the reserve_data "
"of cudnn kernel.")
.AsIntermediate();
AddOutput("StateOut",
"Share memory with State. "
"Store the global drop state when training");
AddOutput("Out",
"(Tensor) the hidden state of LSTM operator. "
"The shape is ( seq_len x batch_size x hidden_size) if "
"is_bidirec is False"
"and When is_bidirec is True, the shape will be ( seq_len x "
"batch_size x hidden_size * 2) ");
AddOutput("last_h",
AddOutput("LastH",
"(Tensor) the hidden state of the last step. "
"The shape is ( num_layers x batch_size x hidden_size) if "
"is_bidirec is False"
"and When is_bidirec is True, the shape will be (num_layers*2 x "
"batch_size x hidden_size)");
AddOutput("last_c",
AddOutput("LastC",
"(Tensor) the cell state of the last step"
"The shape is ( num_layers x batch_size x hidden_size) if "
"is_bidirec is False"
"and When is_bidirect is True, the shape will be (num_layers*2 x "
"batch_size x hidden_size*2)");
AddAttr<int>("max_len",
"max length of the LSTM op"
"the first dim of the Input can NOT be greater than max_len")
.SetDefault(20);
AddAttr<float>(
"dropout_prob",
"dropout prob of the dropout op"
@ -120,14 +146,14 @@ class CudnnLSTMOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<bool>("is_bidirec",
"is_bidirec"
"if it is bidirectional rnn"
"The will affect the shape of the Out, last_h, and last_c")
"The will affect the shape of the Out, LastH, and LastC")
.SetDefault(false);
AddAttr<int>("input_size", "input size ot the Input Tensor").SetDefault(10);
AddAttr<int>("hidden_size", "hidden size of the LSTM").SetDefault(100);
AddAttr<int>("num_layers", "the total layer number of the LSTM")
.SetDefault(1);
AddAttr<bool>("is_test", "True if in test phase.").SetDefault(false);
AddAttr<int>("seed", "seed to used if fix_seed is True").SetDefault(-1);
AddAttr<int>("seed", "seed to used if fix_seed is True").SetDefault(0);
AddComment(R"DOC(
CUDNN LSTM implementation
@ -172,16 +198,10 @@ class CudnnLSTMGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(Input) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Cache"),
"Input(last_c) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("InitH"),
"Input(init_h) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("InitC"),
"Input(init_c) of LSTM should not be null.");
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "CudnnLSTMGrad");
OP_INOUT_CHECK(ctx->HasInput("W"), "Input", "W", "CudnnLSTMGrad");
OP_INOUT_CHECK(ctx->HasInput("InitH"), "Input", "InitH", "CudnnLSTMGrad");
OP_INOUT_CHECK(ctx->HasInput("InitC"), "Input", "InitC", "CudnnLSTMGrad");
auto SetOutGradDim = [&ctx](const std::string& name) {
auto g_name = framework::GradVarName(name);
@ -195,6 +215,12 @@ class CudnnLSTMGradOp : public framework::OperatorWithKernel {
SetOutGradDim("InitH");
SetOutGradDim("InitC");
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
};
template <typename T>
@ -209,13 +235,12 @@ class CudnnLSTMGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetInput("InitH", this->Input("InitH"));
op->SetInput("InitC", this->Input("InitC"));
op->SetInput("W", this->Input("W"));
if (this->HasInput("Cache")) {
op->SetInput("Cache", this->Input("Cache"));
}
op->SetInput("Reserve", this->Output("Reserve"));
op->SetInput("StateOut", this->Output("StateOut"));
op->SetInput("Out", this->Output("Out"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetInput(framework::GradVarName("last_c"), this->OutputGrad("last_c"));
op->SetInput(framework::GradVarName("last_h"), this->OutputGrad("last_h"));
op->SetInput(framework::GradVarName("LastC"), this->OutputGrad("LastC"));
op->SetInput(framework::GradVarName("LastH"), this->OutputGrad("LastH"));
op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
op->SetOutput(framework::GradVarName("W"), this->InputGrad("W"));

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

@ -100,6 +100,7 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
__macro(cudnnCreateDropoutDescriptor); \
__macro(cudnnDropoutGetStatesSize); \
__macro(cudnnSetDropoutDescriptor); \
__macro(cudnnRestoreDropoutDescriptor); \
__macro(cudnnCreateRNNDescriptor); \
__macro(cudnnGetRNNParamsSize); \
__macro(cudnnGetRNNWorkspaceSize); \

@ -2213,9 +2213,9 @@ def lstm(input,
input ( :ref:`api_guide_Variable_en` ): LSTM input tensor, 3-D Tensor of shape :math:`[batch\_size, seq\_len, input\_dim]` . Data type is float32 or float64
init_h( :ref:`api_guide_Variable_en` ): The initial hidden state of the LSTM, 3-D Tensor of shape :math:`[num\_layers, batch\_size, hidden\_size]` .
If is_bidirec = True, shape should be :math:`[num\_layers*2, batch\_size, hidden\_size]` . Data type is float32 or float64.
max_len (int): This parameter has no effect and will be discarded.
init_c( :ref:`api_guide_Variable_en` ): The initial cell state of the LSTM, 3-D Tensor of shape :math:`[num\_layers, batch\_size, hidden\_size]` .
If is_bidirec = True, shape should be :math:`[num\_layers*2, batch\_size, hidden\_size]` . Data type is float32 or float64.
max_len (int): max length of LSTM. the first dim of input tensor CAN NOT greater than max_len.
hidden_size (int): hidden size of the LSTM.
num_layers (int): total layers number of the LSTM.
dropout_prob(float, optional): dropout prob, dropout ONLY work between rnn layers, NOT between time steps
@ -2256,7 +2256,6 @@ def lstm(input,
data = fluid.data(name='x', shape=[None, 100], dtype='int64')
emb = fluid.embedding(input=data, size=[vocab_size, emb_dim], is_sparse=True)
batch_size = 20
max_len = 100
dropout_prob = 0.2
input_size = 100
hidden_size = 150
@ -2309,9 +2308,11 @@ def lstm(input,
out = helper.create_variable_for_type_inference(dtype)
last_h = helper.create_variable_for_type_inference(dtype)
last_c = helper.create_variable_for_type_inference(dtype)
cache = helper.create_variable(
persistable=True, type=core.VarDesc.VarType.RAW, stop_gradient=True)
reserve = helper.create_variable_for_type_inference(
dtype=core.VarDesc.VarType.UINT8, stop_gradient=True)
state_out = helper.create_variable_for_type_inference(
dtype=core.VarDesc.VarType.UINT8, stop_gradient=True)
state_out.persistable = True
helper.append_op(
type='cudnn_lstm',
@ -2320,15 +2321,15 @@ def lstm(input,
'InitH': init_h,
'InitC': init_c,
'W': weight,
'Cache': cache,
},
outputs={
'Out': out,
'last_h': last_h,
'last_c': last_c,
'LastH': last_h,
'LastC': last_c,
'Reserve': reserve,
'StateOut': state_out,
},
attrs={
'max_len': max_len,
'is_bidirec': is_bidirec,
'input_size': input_size,
'hidden_size': hidden_size,

@ -20,15 +20,14 @@ import numpy as np
import paddle.fluid.core as core
from op_test import OpTest
import paddle.fluid as fluid
import paddle.fluid.layers as layers
SIGMOID_THRESHOLD_MIN = -40.0
SIGMOID_THRESHOLD_MAX = 13.0
EXP_MAX_INPUT = 40.0
def lstm_naive(
input,
w, ):
def lstm_naive(input, w):
seq_len, batch_size, hidden_size = input.shape
offset = 0
@ -86,8 +85,8 @@ def lstm_naive(
return (2. / (1. + np.exp(y))) - 1.
output = []
pre_h = np.zeros((batch_size, hidden_size), dtype=input.dtype)
pre_c = np.zeros((batch_size, hidden_size), dtype=input.dtype)
pre_h = np.zeros((1, batch_size, hidden_size), dtype=input.dtype)
pre_c = np.zeros((1, batch_size, hidden_size), dtype=input.dtype)
for i in range(seq_len):
emb_1 = input[i]
@ -110,7 +109,6 @@ def lstm_naive(
output = np.concatenate(output, -1)
output = output.reshape((batch_size, -1, hidden_size))
output = output.transpose((1, 0, 2))
return output, pre_h, pre_c
@ -119,11 +117,12 @@ def lstm_naive(
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNLstmOp(OpTest):
# TODO(GaoWei8):when input dtype is fp64, precision threshold should be removed.
def setUp(self):
self.op_type = "cudnn_lstm"
self.dtype = np.float32
self.dtype = np.float64
num_steps = 20
seq_length = 20
batch_size = 5
hidden_size = 20
@ -133,33 +132,24 @@ class TestCUDNNLstmOp(OpTest):
weight_size += hidden_size * 8
input = np.random.uniform(
low=-0.1, high=0.1, size=(num_steps, batch_size,
low=-0.1, high=0.1, size=(seq_length, batch_size,
hidden_size)).astype(self.dtype)
flat_w = np.random.uniform(
low=-0.1, high=0.1, size=(weight_size)).astype(self.dtype)
output, last_hidden, last_cell = lstm_naive(input, flat_w)
init_h = np.zeros((batch_size, hidden_size), dtype=np.float32)
init_c = np.zeros((batch_size, hidden_size), dtype=np.float32)
scope = core.Scope()
program = fluid.Program()
block = program.global_block()
cache_temp = block.create_var(
name="Cache",
persistable=True,
type=core.VarDesc.VarType.RAW,
stop_gradient=True)
init_h = np.zeros((1, batch_size, hidden_size), dtype=np.float64)
init_c = np.zeros((1, batch_size, hidden_size), dtype=np.float64)
state_out = np.ndarray((300)).astype("uint8")
self.inputs = {
'Input': OpTest.np_dtype_to_fluid_dtype(input),
'W': OpTest.np_dtype_to_fluid_dtype(flat_w),
'InitH': OpTest.np_dtype_to_fluid_dtype(init_h),
'InitC': OpTest.np_dtype_to_fluid_dtype(init_c),
'Input': input,
'W': flat_w,
'InitH': init_h,
'InitC': init_c
}
self.cache_name_list = ['Cache']
self.attrs = {
'max_len': num_steps,
'dropout_prob': 0.0,
'is_bidirec': False,
'input_size': hidden_size,
@ -168,22 +158,61 @@ class TestCUDNNLstmOp(OpTest):
}
self.outputs = {
'Out': output,
"last_h": last_hidden,
'last_c': last_cell
"LastH": last_hidden,
'LastC': last_cell,
'Reserve': np.ndarray((400)).astype("uint8"),
'StateOut': state_out
}
def test_output_with_place(self):
# depend on the scope structure
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=1e-5, check_dygraph=False)
self.check_output_with_place(
place, no_check_set=['Reserve', 'StateOut'])
def test_grad_with_place(self):
# depend on the scope structure
place = core.CUDAPlace(0)
self.check_grad_with_place(
place,
set(['Input', 'W', 'InitH', 'InitC']), ['Out', 'last_h', 'last_c'],
check_dygraph=False)
set(['Input', 'W', 'InitH', 'InitC']), ['Out', 'LastH', 'LastC'],
max_relative_error=1e-4)
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNlstmAPI(unittest.TestCase):
def test_lstm(self):
seq_len = 20
batch_size = 5
hidden_size = 20
dropout_prob = 0.0
num_layers = 1
input = fluid.data(
name='input',
shape=[seq_len, batch_size, hidden_size],
dtype='float64')
init_h = layers.fill_constant([num_layers, batch_size, hidden_size],
'float64', 0.0)
init_c = layers.fill_constant([num_layers, batch_size, hidden_size],
'float64', 0.0)
rnn_out, last_h, last_c = layers.lstm(input, init_h, init_c, seq_len,
hidden_size, num_layers,
dropout_prob)
exe = fluid.Executor(fluid.CUDAPlace(0))
exe.run(fluid.default_startup_program())
input_i = np.random.uniform(
low=-0.1, high=0.1, size=(seq_len, batch_size,
hidden_size)).astype("float64")
out = exe.run(fluid.default_main_program(),
feed={'input': input_i},
fetch_list=[rnn_out, last_h, last_c, 'cudnn_lstm_0.w_0'])
output, last_hidden, last_cell = lstm_naive(input_i, out[3])
self.assertTrue(np.allclose(output, out[0], atol=1e-5))
self.assertTrue(np.allclose(last_hidden, out[1], atol=1e-5))
self.assertTrue(np.allclose(last_cell, out[2], atol=1e-5))
if __name__ == '__main__':

@ -26,4 +26,5 @@ no_check_set_white_list = [
'cross_entropy2',
'seed',
'amp_check_finite_and_scale',
'cudnn_lstm',
]

@ -41,7 +41,8 @@ NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST = [
'unpool', \
'yolov3_loss', \
'inverse', \
'bilateral_slice'
'bilateral_slice',\
'cudnn_lstm'
]
NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST = ['bilinear_interp']

Loading…
Cancel
Save