remove scope in cudnn lstm (#25188)

GaoWei8 5 years ago committed by GitHub
parent da29760d58
commit 1fbee267d4

@@ -24,34 +24,62 @@ class CudnnLSTMOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("Input"),
-                   "Input(Input) of LSTM should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("W"),
-                   "Input(Weight) of LSTM should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("InitH"),
-                   "Input(init_h) of LSTM should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("InitC"),
-                   "Input(init_c) of LSTM should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Cache"),
-                   "Input(Cache) of LSTM should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("Out"),
-                   "Output(Out) of LSTM should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("last_h"),
-                   "Output(last_h) of LSTM should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("last_c"),
-                   "Output(last_c) of LSTM should not be null.");
+    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "CudnnLSTM");
+    OP_INOUT_CHECK(ctx->HasInput("W"), "Input", "W", "CudnnLSTM");
+    OP_INOUT_CHECK(ctx->HasInput("InitH"), "Input", "InitH", "CudnnLSTM");
+    OP_INOUT_CHECK(ctx->HasInput("InitC"), "Input", "InitC", "CudnnLSTM");
+    OP_INOUT_CHECK(ctx->HasOutput("Reserve"), "Output", "Reserve", "CudnnLSTM");
+    OP_INOUT_CHECK(ctx->HasOutput("StateOut"), "Output", "StateOut",
+                   "CudnnLSTM");
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "CudnnLSTM");
+    OP_INOUT_CHECK(ctx->HasOutput("LastH"), "Output", "LastH", "CudnnLSTM");
+    OP_INOUT_CHECK(ctx->HasOutput("LastC"), "Output", "LastC", "CudnnLSTM");
 
     auto in_dims = ctx->GetInputDim("Input");
-    PADDLE_ENFORCE_EQ(in_dims.size(), 3, "Input(X)'s rank must be 3.");
+    auto init_dims = ctx->GetInputDim("InitH");
+    PADDLE_ENFORCE_EQ(in_dims.size(), 3,
+                      platform::errors::InvalidArgument(
+                          "The rank of Input in CudnnLSTM must be 3. But "
+                          "received Input's rank is %d.",
+                          in_dims.size()));
+    PADDLE_ENFORCE_EQ(init_dims.size(), 3,
+                      platform::errors::InvalidArgument(
+                          "The rank of InitH in CudnnLSTM must be 3. But "
+                          "received InitH's rank is %d.",
+                          init_dims.size()));
+
+    PADDLE_ENFORCE_EQ(in_dims[1], init_dims[1],
+                      platform::errors::InvalidArgument(
+                          "The in_dims[1] (Input dims) and init_dims[1] (InitH "
+                          "dims) should be equal. But "
+                          "received in_dims[1] is %d and init_dims[1] is %d.",
+                          in_dims[1], init_dims[1]));
+    PADDLE_ENFORCE_EQ(in_dims[2], init_dims[2],
+                      platform::errors::InvalidArgument(
+                          "The in_dims[2] (Input dims) and init_dims[2] (InitH "
+                          "dims) should be equal. But "
+                          "received in_dims[2] is %d and init_dims[2] is %d.",
+                          in_dims[2], init_dims[2]));
 
     auto out_dims = in_dims;
     auto hidden_size = ctx->Attrs().Get<int>("hidden_size");
-    out_dims[2] = hidden_size;
+    bool is_bidirec = ctx->Attrs().Get<bool>("is_bidirec");
+    out_dims[2] = is_bidirec ? hidden_size * 2 : hidden_size;
+
+    auto last_dims = init_dims;
+    last_dims[0] = is_bidirec ? last_dims[0] * 2 : last_dims[0];
+
     ctx->SetOutputDim("Out", out_dims);
-    ctx->SetOutputDim("last_h", ctx->GetInputDim("InitH"));
-    ctx->SetOutputDim("last_c", ctx->GetInputDim("InitC"));
+    ctx->SetOutputDim("LastH", last_dims);
+    ctx->SetOutputDim("LastC", last_dims);
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
+        ctx.device_context());
   }
 };
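
To make the new shape logic concrete, here is a small NumPy-style sketch of the rule the rewritten InferShape enforces (illustrative only; the sizes below are placeholder values, not taken from this commit):

import numpy as np

# Placeholder sizes; InferShape works on symbolic dims, this just mirrors the arithmetic.
seq_len, batch_size, hidden_size, num_layers = 20, 5, 20, 1
is_bidirec = False

in_dims = (seq_len, batch_size, hidden_size)       # "Input"
init_dims = (num_layers, batch_size, hidden_size)  # "InitH" / "InitC"

# The new rank and dimension checks, expressed as asserts.
assert len(in_dims) == 3 and len(init_dims) == 3
assert in_dims[1] == init_dims[1]
assert in_dims[2] == init_dims[2]

out_dims = (in_dims[0], in_dims[1],
            hidden_size * 2 if is_bidirec else hidden_size)       # "Out"
last_dims = ((init_dims[0] * 2 if is_bidirec else init_dims[0]),
             init_dims[1], init_dims[2])                          # "LastH" / "LastC"
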
@@ -84,33 +112,31 @@ class CudnnLSTMOpMaker : public framework::OpProtoAndCheckerMaker {
              "(Tensor) the learnable hidden-hidden weights."
              " The shape is (N), where N is total weight size of the LSTM. "
              " cudnn concatenate all the weight to one Tensor");
-    AddInput("Cache",
-             "The cache of dropout op, a RAW type variable including random "
-             "number generator states and some descriptors, which is used in "
-             "cudnn kernel.")
-        .AsDispensable();
+    AddOutput("Reserve",
+              "(Tensor, a temporary output Tensor to store the reserve_data "
+              "of cudnn kernel.")
+        .AsIntermediate();
+    AddOutput("StateOut",
+              "Share memory with State. "
+              "Store the global drop state when training");
     AddOutput("Out",
               "(Tensor) the hidden state of LSTM operator. "
               "The shape is ( seq_len x batch_size x hidden_size) if "
               "is_bidirec is False"
               "and When is_bidirec is True, the shape will be ( seq_len x "
               "batch_size x hidden_size * 2) ");
-    AddOutput("last_h",
+    AddOutput("LastH",
               "(Tensor) the hidden state of the last step. "
               "The shape is ( num_layers x batch_size x hidden_size) if "
               "is_bidirec is False"
               "and When is_bidirec is True, the shape will be (num_layers*2 x "
               "batch_size x hidden_size)");
-    AddOutput("last_c",
+    AddOutput("LastC",
               "(Tensor) the cell state of the last step"
               "The shape is ( num_layers x batch_size x hidden_size) if "
               "is_bidirec is False"
               "and When is_bidirect is True, the shape will be (num_layers*2 x "
               "batch_size x hidden_size*2)");
-    AddAttr<int>("max_len",
-                 "max length of the LSTM op"
-                 "the first dim of the Input can NOT be greater than max_len")
-        .SetDefault(20);
     AddAttr<float>(
         "dropout_prob",
         "dropout prob of the dropout op"
@@ -120,14 +146,14 @@ class CudnnLSTMOpMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr<bool>("is_bidirec",
                   "is_bidirec"
                   "if it is bidirectional rnn"
-                  "The will affect the shape of the Out, last_h, and last_c")
+                  "The will affect the shape of the Out, LastH, and LastC")
         .SetDefault(false);
     AddAttr<int>("input_size", "input size ot the Input Tensor").SetDefault(10);
     AddAttr<int>("hidden_size", "hidden size of the LSTM").SetDefault(100);
     AddAttr<int>("num_layers", "the total layer number of the LSTM")
         .SetDefault(1);
     AddAttr<bool>("is_test", "True if in test phase.").SetDefault(false);
-    AddAttr<int>("seed", "seed to used if fix_seed is True").SetDefault(-1);
+    AddAttr<int>("seed", "seed to used if fix_seed is True").SetDefault(0);
     AddComment(R"DOC(
 CUDNN LSTM implementation
@@ -172,16 +198,10 @@ class CudnnLSTMGradOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("Input"),
-                   "Input(Input) of LSTM should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) of LSTM should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Cache"),
-                   "Input(last_c) of LSTM should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("InitH"),
-                   "Input(init_h) of LSTM should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("InitC"),
-                   "Input(init_c) of LSTM should not be null.");
+    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "CudnnLSTMGrad");
+    OP_INOUT_CHECK(ctx->HasInput("W"), "Input", "W", "CudnnLSTMGrad");
+    OP_INOUT_CHECK(ctx->HasInput("InitH"), "Input", "InitH", "CudnnLSTMGrad");
+    OP_INOUT_CHECK(ctx->HasInput("InitC"), "Input", "InitC", "CudnnLSTMGrad");
 
     auto SetOutGradDim = [&ctx](const std::string& name) {
       auto g_name = framework::GradVarName(name);
@@ -195,6 +215,12 @@ class CudnnLSTMGradOp : public framework::OperatorWithKernel {
     SetOutGradDim("InitH");
     SetOutGradDim("InitC");
   }
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   ctx.device_context());
+  }
 };
 
 template <typename T>
@@ -209,13 +235,12 @@ class CudnnLSTMGradOpMaker : public framework::SingleGradOpMaker<T> {
     op->SetInput("InitH", this->Input("InitH"));
     op->SetInput("InitC", this->Input("InitC"));
     op->SetInput("W", this->Input("W"));
-    if (this->HasInput("Cache")) {
-      op->SetInput("Cache", this->Input("Cache"));
-    }
+    op->SetInput("Reserve", this->Output("Reserve"));
+    op->SetInput("StateOut", this->Output("StateOut"));
     op->SetInput("Out", this->Output("Out"));
     op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
-    op->SetInput(framework::GradVarName("last_c"), this->OutputGrad("last_c"));
-    op->SetInput(framework::GradVarName("last_h"), this->OutputGrad("last_h"));
+    op->SetInput(framework::GradVarName("LastC"), this->OutputGrad("LastC"));
+    op->SetInput(framework::GradVarName("LastH"), this->OutputGrad("LastH"));
 
     op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
     op->SetOutput(framework::GradVarName("W"), this->InputGrad("W"));
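
In short, the scope-held "Cache" variable (a RAW variable carrying the dropout state and cudnn descriptors) is replaced by two explicit forward outputs, "Reserve" and "StateOut", which the grad op maker now forwards to the backward op. A minimal summary of the resulting wiring, written out in Python using Paddle's usual @GRAD naming (a sketch for orientation, not code from the commit):

# Variables consumed by cudnn_lstm_grad after this change.
cudnn_lstm_grad_inputs = [
    'Input', 'InitH', 'InitC', 'W',          # forward inputs
    'Reserve', 'StateOut', 'Out',            # forward outputs reused by backward
    'Out@GRAD', 'LastH@GRAD', 'LastC@GRAD',  # incoming gradients
]
# Gradients produced by cudnn_lstm_grad.
cudnn_lstm_grad_outputs = ['Input@GRAD', 'W@GRAD', 'InitH@GRAD', 'InitC@GRAD']
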

(Two further file diffs in this commit are suppressed because they are too large.)

@@ -100,6 +100,7 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
   __macro(cudnnCreateDropoutDescriptor);  \
   __macro(cudnnDropoutGetStatesSize);     \
   __macro(cudnnSetDropoutDescriptor);     \
+  __macro(cudnnRestoreDropoutDescriptor); \
   __macro(cudnnCreateRNNDescriptor);      \
   __macro(cudnnGetRNNParamsSize);         \
   __macro(cudnnGetRNNWorkspaceSize);      \

@@ -2213,9 +2213,9 @@ def lstm(input,
         input ( :ref:`api_guide_Variable_en` ): LSTM input tensor, 3-D Tensor of shape :math:`[batch\_size, seq\_len, input\_dim]` . Data type is float32 or float64
         init_h( :ref:`api_guide_Variable_en` ): The initial hidden state of the LSTM, 3-D Tensor of shape :math:`[num\_layers, batch\_size, hidden\_size]` .
                        If is_bidirec = True, shape should be :math:`[num\_layers*2, batch\_size, hidden\_size]` . Data type is float32 or float64.
+        max_len (int): This parameter has no effect and will be discarded.
         init_c( :ref:`api_guide_Variable_en` ): The initial cell state of the LSTM, 3-D Tensor of shape :math:`[num\_layers, batch\_size, hidden\_size]` .
                        If is_bidirec = True, shape should be :math:`[num\_layers*2, batch\_size, hidden\_size]` . Data type is float32 or float64.
-        max_len (int): max length of LSTM. the first dim of input tensor CAN NOT greater than max_len.
         hidden_size (int): hidden size of the LSTM.
         num_layers (int): total layers number of the LSTM.
         dropout_prob(float, optional): dropout prob, dropout ONLY work between rnn layers, NOT between time steps
@@ -2256,7 +2256,6 @@ def lstm(input,
             data = fluid.data(name='x', shape=[None, 100], dtype='int64')
             emb = fluid.embedding(input=data, size=[vocab_size, emb_dim], is_sparse=True)
             batch_size = 20
-            max_len = 100
             dropout_prob = 0.2
             input_size = 100
             hidden_size = 150
@@ -2309,9 +2308,11 @@ def lstm(input,
     out = helper.create_variable_for_type_inference(dtype)
     last_h = helper.create_variable_for_type_inference(dtype)
     last_c = helper.create_variable_for_type_inference(dtype)
-
-    cache = helper.create_variable(
-        persistable=True, type=core.VarDesc.VarType.RAW, stop_gradient=True)
+    reserve = helper.create_variable_for_type_inference(
+        dtype=core.VarDesc.VarType.UINT8, stop_gradient=True)
+    state_out = helper.create_variable_for_type_inference(
+        dtype=core.VarDesc.VarType.UINT8, stop_gradient=True)
+    state_out.persistable = True
 
     helper.append_op(
         type='cudnn_lstm',
@@ -2320,15 +2321,15 @@ def lstm(input,
             'InitH': init_h,
             'InitC': init_c,
             'W': weight,
-            'Cache': cache,
         },
         outputs={
             'Out': out,
-            'last_h': last_h,
-            'last_c': last_c,
+            'LastH': last_h,
+            'LastC': last_c,
+            'Reserve': reserve,
+            'StateOut': state_out,
         },
         attrs={
-            'max_len': max_len,
             'is_bidirec': is_bidirec,
             'input_size': input_size,
             'hidden_size': hidden_size,
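
For reference, a minimal usage sketch of the updated Python API (placeholder sizes; it mirrors the docstring example above and the new unit test below). The max_len argument is still accepted for compatibility but no longer has any effect, and no Cache variable is passed:

import paddle.fluid as fluid
import paddle.fluid.layers as layers

seq_len, batch_size, hidden_size, num_layers = 20, 5, 150, 1
data = fluid.data(name='x', shape=[seq_len, batch_size, hidden_size],
                  dtype='float32')
init_h = layers.fill_constant([num_layers, batch_size, hidden_size], 'float32', 0.0)
init_c = layers.fill_constant([num_layers, batch_size, hidden_size], 'float32', 0.0)

# The 4th positional argument is the (now ignored) max_len; Reserve and
# StateOut are created internally by the layer.
rnn_out, last_h, last_c = layers.lstm(data, init_h, init_c, seq_len,
                                      hidden_size, num_layers)
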

@@ -20,15 +20,14 @@ import numpy as np
 import paddle.fluid.core as core
 from op_test import OpTest
 import paddle.fluid as fluid
+import paddle.fluid.layers as layers
 
 SIGMOID_THRESHOLD_MIN = -40.0
 SIGMOID_THRESHOLD_MAX = 13.0
 EXP_MAX_INPUT = 40.0
 
 
-def lstm_naive(
-        input,
-        w, ):
+def lstm_naive(input, w):
     seq_len, batch_size, hidden_size = input.shape
 
     offset = 0
@@ -86,8 +85,8 @@ def lstm_naive(
         return (2. / (1. + np.exp(y))) - 1.
 
     output = []
-    pre_h = np.zeros((batch_size, hidden_size), dtype=input.dtype)
-    pre_c = np.zeros((batch_size, hidden_size), dtype=input.dtype)
+    pre_h = np.zeros((1, batch_size, hidden_size), dtype=input.dtype)
+    pre_c = np.zeros((1, batch_size, hidden_size), dtype=input.dtype)
 
     for i in range(seq_len):
         emb_1 = input[i]
@@ -110,7 +109,6 @@ def lstm_naive(
     output = np.concatenate(output, -1)
     output = output.reshape((batch_size, -1, hidden_size))
     output = output.transpose((1, 0, 2))
-
     return output, pre_h, pre_c
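
As context for the change above: lstm_naive builds a NumPy reference for the cudnn kernel, and pre_h/pre_c now carry an explicit leading dimension of 1 so they line up with the op's LastH/LastC layout (num_layers first). A simplified single-step sketch of the standard LSTM recurrence such a reference implements (shapes here are 2-D (batch, hidden); weight names are placeholders, and the clipping that lstm_naive applies to the activations is omitted):

import numpy as np

def lstm_step(x, pre_h, pre_c, Wi, Ui, bi, Wf, Uf, bf, Wc, Uc, bc, Wo, Uo, bo):
    # Standard LSTM cell: gates from the current input and the previous hidden state.
    sigmoid = lambda z: 1.0 / (1.0 + np.exp(-z))
    i = sigmoid(x @ Wi.T + pre_h @ Ui.T + bi)   # input gate
    f = sigmoid(x @ Wf.T + pre_h @ Uf.T + bf)   # forget gate
    g = np.tanh(x @ Wc.T + pre_h @ Uc.T + bc)   # candidate cell state
    o = sigmoid(x @ Wo.T + pre_h @ Uo.T + bo)   # output gate
    c = f * pre_c + i * g                       # new cell state
    h = o * np.tanh(c)                          # new hidden state
    return h, c
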
@@ -119,11 +117,12 @@ def lstm_naive(
 
 @unittest.skipIf(not core.is_compiled_with_cuda(),
                  "core is not compiled with CUDA")
 class TestCUDNNLstmOp(OpTest):
+    # TODO(GaoWei8):when input dtype is fp64, precision threshold should be removed.
     def setUp(self):
         self.op_type = "cudnn_lstm"
-        self.dtype = np.float32
+        self.dtype = np.float64
 
-        num_steps = 20
+        seq_length = 20
         batch_size = 5
         hidden_size = 20
@@ -133,33 +132,24 @@ class TestCUDNNLstmOp(OpTest):
         weight_size += hidden_size * 8
 
         input = np.random.uniform(
-            low=-0.1, high=0.1, size=(num_steps, batch_size,
+            low=-0.1, high=0.1, size=(seq_length, batch_size,
                                       hidden_size)).astype(self.dtype)
         flat_w = np.random.uniform(
             low=-0.1, high=0.1, size=(weight_size)).astype(self.dtype)
 
         output, last_hidden, last_cell = lstm_naive(input, flat_w)
 
-        init_h = np.zeros((batch_size, hidden_size), dtype=np.float32)
-        init_c = np.zeros((batch_size, hidden_size), dtype=np.float32)
-
-        scope = core.Scope()
-        program = fluid.Program()
-        block = program.global_block()
-
-        cache_temp = block.create_var(
-            name="Cache",
-            persistable=True,
-            type=core.VarDesc.VarType.RAW,
-            stop_gradient=True)
+        init_h = np.zeros((1, batch_size, hidden_size), dtype=np.float64)
+        init_c = np.zeros((1, batch_size, hidden_size), dtype=np.float64)
+        state_out = np.ndarray((300)).astype("uint8")
+
         self.inputs = {
-            'Input': OpTest.np_dtype_to_fluid_dtype(input),
-            'W': OpTest.np_dtype_to_fluid_dtype(flat_w),
-            'InitH': OpTest.np_dtype_to_fluid_dtype(init_h),
-            'InitC': OpTest.np_dtype_to_fluid_dtype(init_c),
+            'Input': input,
+            'W': flat_w,
+            'InitH': init_h,
+            'InitC': init_c
         }
-        self.cache_name_list = ['Cache']
         self.attrs = {
-            'max_len': num_steps,
             'dropout_prob': 0.0,
             'is_bidirec': False,
             'input_size': hidden_size,
@@ -168,22 +158,61 @@ class TestCUDNNLstmOp(OpTest):
         }
 
         self.outputs = {
             'Out': output,
-            "last_h": last_hidden,
-            'last_c': last_cell
+            "LastH": last_hidden,
+            'LastC': last_cell,
+            'Reserve': np.ndarray((400)).astype("uint8"),
+            'StateOut': state_out
         }
 
     def test_output_with_place(self):
         # depend on the scope structure
         place = core.CUDAPlace(0)
-        self.check_output_with_place(place, atol=1e-5, check_dygraph=False)
+        self.check_output_with_place(
+            place, no_check_set=['Reserve', 'StateOut'])
 
     def test_grad_with_place(self):
         # depend on the scope structure
         place = core.CUDAPlace(0)
         self.check_grad_with_place(
             place,
-            set(['Input', 'W', 'InitH', 'InitC']), ['Out', 'last_h', 'last_c'],
-            check_dygraph=False)
+            set(['Input', 'W', 'InitH', 'InitC']), ['Out', 'LastH', 'LastC'],
+            max_relative_error=1e-4)
+
+
+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestCUDNNlstmAPI(unittest.TestCase):
+    def test_lstm(self):
+        seq_len = 20
+        batch_size = 5
+        hidden_size = 20
+        dropout_prob = 0.0
+        num_layers = 1
+        input = fluid.data(
+            name='input',
+            shape=[seq_len, batch_size, hidden_size],
+            dtype='float64')
+        init_h = layers.fill_constant([num_layers, batch_size, hidden_size],
+                                      'float64', 0.0)
+        init_c = layers.fill_constant([num_layers, batch_size, hidden_size],
+                                      'float64', 0.0)
+        rnn_out, last_h, last_c = layers.lstm(input, init_h, init_c, seq_len,
+                                              hidden_size, num_layers,
+                                              dropout_prob)
+        exe = fluid.Executor(fluid.CUDAPlace(0))
+        exe.run(fluid.default_startup_program())
+        input_i = np.random.uniform(
+            low=-0.1, high=0.1, size=(seq_len, batch_size,
+                                      hidden_size)).astype("float64")
+        out = exe.run(fluid.default_main_program(),
+                      feed={'input': input_i},
+                      fetch_list=[rnn_out, last_h, last_c, 'cudnn_lstm_0.w_0'])
+        output, last_hidden, last_cell = lstm_naive(input_i, out[3])
+
+        self.assertTrue(np.allclose(output, out[0], atol=1e-5))
+        self.assertTrue(np.allclose(last_hidden, out[1], atol=1e-5))
+        self.assertTrue(np.allclose(last_cell, out[2], atol=1e-5))
 
 
 if __name__ == '__main__':

@@ -26,4 +26,5 @@ no_check_set_white_list = [
     'cross_entropy2',
     'seed',
     'amp_check_finite_and_scale',
+    'cudnn_lstm',
 ]

@@ -41,7 +41,8 @@ NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST = [
     'unpool', \
     'yolov3_loss', \
     'inverse', \
-    'bilateral_slice'
+    'bilateral_slice',\
+    'cudnn_lstm'
 ]
 
 NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST = ['bilinear_interp']
