* Add rnn_op.
test=develop

* Fix rnn_op grad maker's drop_empty_grad.
test=develop
TCChenlong-patch-1
Guo Sheng 5 years ago committed by GitHub
parent 0f4b6247c8
commit 9a600df373

File diff suppressed because it is too large

File diff suppressed because it is too large

@@ -361,6 +361,12 @@ class ScopedDropoutDescriptor {
                                        float dropout_prob_,
                                        framework::Tensor* dropout_state_,
                                        int seed, size_t state_size) {
+    if (dropout_state_ == nullptr) {  // for no dropout or test
+      PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnSetDropoutDescriptor(
+          desc_, handle, 0 /* dropout */, nullptr, 0 /* state_size */,
+          0 /* seed */));
+      return desc_;
+    }
     auto* dropout_state_data = dropout_state_->data<uint8_t>();
     if (!initialized) {
       PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnSetDropoutDescriptor(
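The added branch configures the cuDNN dropout descriptor with zero dropout and no state buffer whenever no dropout-state tensor is supplied, which is the no-dropout / inference case. A minimal Python-level sketch of the situation expected to hit this path, assuming a GPU build where the fused cuDNN kernel is selected (layer sizes here are arbitrary and only illustrative):

import paddle

# Minimal sketch, assuming the cuDNN-backed kernel is used on a GPU device.
# In eval mode the op runs with is_test=True, so the kernel can build the
# dropout descriptor without allocating a persistable dropout-state buffer,
# which is the nullptr branch added above.
net = paddle.nn.LSTM(16, 32, num_layers=2, dropout=0.1)
net.eval()
x = paddle.randn([4, 10, 16])
y, (h, c) = net(x)  # inference forward pass; dropout is inactive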

@@ -93,10 +93,14 @@ class TestSimpleRNN(unittest.TestCase):
         np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5)
         np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5)

+    def test_predict(self):
+        predict_test_util(self.place, "SimpleRNN")
+
     def runTest(self):
         self.test_with_initial_state()
         self.test_with_zero_state()
         self.test_with_input_lengths()
+        self.test_predict()


 class TestGRU(unittest.TestCase):
@@ -175,10 +179,14 @@ class TestGRU(unittest.TestCase):
         np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5)
         np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5)

+    def test_predict(self):
+        predict_test_util(self.place, "GRU")
+
     def runTest(self):
         self.test_with_initial_state()
         self.test_with_zero_state()
         self.test_with_input_lengths()
+        self.test_predict()


 class TestLSTM(unittest.TestCase):
@@ -258,18 +266,31 @@ class TestLSTM(unittest.TestCase):
         np.testing.assert_allclose(c1, c2.numpy(), atol=1e-8, rtol=1e-5)

     def test_predict(self):
-        place = paddle.set_device(self.place)
+        predict_test_util(self.place, "LSTM")
+
+    def runTest(self):
+        self.test_with_initial_state()
+        self.test_with_zero_state()
+        self.test_with_input_lengths()
+        self.test_predict()
+
+
+def predict_test_util(place, mode):
+    place = paddle.set_device(place)
     paddle.seed(123)
     np.random.seed(123)

     class Net(paddle.nn.Layer):
         def __init__(self):
             super(Net, self).__init__()
-            self.rnn1 = paddle.nn.LSTM(
-                16, 32, 2, direction="bidirectional", dropout=0.1)
+            self.rnn = getattr(paddle.nn, mode)(16,
+                                                32,
+                                                2,
+                                                direction="bidirectional",
+                                                dropout=0.1)

         def forward(self, input):
-            return self.rnn1(input)
+            return self.rnn(input)

     x = paddle.randn((4, 10, 16))
     x.stop_gradient = False
@@ -277,7 +298,7 @@ class TestLSTM(unittest.TestCase):
     mask = sequence_mask(seq_len, maxlen=10, dtype=x.dtype)
     mask = paddle.unsqueeze(mask, [2])
     rnn = Net()
-    y, (h, c) = rnn(x)
+    y, _ = rnn(x)
     y = y * mask
     loss = paddle.mean(y)
     loss.backward()
@@ -285,16 +306,15 @@ class TestLSTM(unittest.TestCase):
         learning_rate=0.1, parameters=rnn.parameters())
     optimizer.step()

     rnn.eval()
-    y, (h, c) = rnn(x)
+    y, _ = rnn(x)
     # `jit.to_static` would include a train_program, eval mode might cause
     # some errors currently, such as dropout grad op gets `is_test == True`.
     rnn.train()
     rnn = paddle.jit.to_static(
-        rnn,
-        [paddle.static.InputSpec(
+        rnn, [paddle.static.InputSpec(
             shape=[None, None, 16], dtype=x.dtype)])
-    paddle.jit.save(rnn, "./inference/lstm_infer")
+    paddle.jit.save(rnn, "./inference/%s_infer" % mode)

     paddle.enable_static()
@@ -305,8 +325,8 @@ class TestLSTM(unittest.TestCase):
      fetch_targets] = paddle.static.load_inference_model(
          dirname="./inference",
          executor=exe,
-         model_filename="lstm_infer.pdmodel",
-         params_filename="lstm_infer.pdiparams")
+         model_filename="%s_infer.pdmodel" % mode,
+         params_filename="%s_infer.pdiparams" % mode)
     results = exe.run(inference_program,
                       feed={feed_target_names[0]: x.numpy()},
                       fetch_list=fetch_targets)
@@ -314,12 +334,6 @@ class TestLSTM(unittest.TestCase):
         y.numpy(), results[0])  # eval results equal predict results
     paddle.disable_static()

-    def runTest(self):
-        self.test_with_initial_state()
-        self.test_with_zero_state()
-        self.test_with_input_lengths()
-        self.test_predict()
-

 def load_tests(loader, tests, pattern):
     suite = unittest.TestSuite()
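Taken together, the test changes fold the LSTM-only predict test into the shared predict_test_util(place, mode) helper, so SimpleRNN, GRU, and LSTM all exercise dygraph training, dynamic-to-static conversion, saving, and static-graph inference. A simplified standalone sketch of that export-and-reload workflow follows; the GRU choice, file prefix, and CPU place are illustrative assumptions, not part of the diff:

import numpy as np
import paddle

# Sketch of the dygraph -> static export path the helper exercises.
paddle.set_device("cpu")  # assumption: run everything on one device
net = paddle.nn.GRU(16, 32, num_layers=2, direction="bidirectional")
x = paddle.randn([4, 10, 16])
y_dygraph, _ = net(x)

# Convert to a static program and save an inference model.
net = paddle.jit.to_static(
    net, [paddle.static.InputSpec(shape=[None, None, 16], dtype=x.dtype)])
paddle.jit.save(net, "./inference/gru_demo")

# Reload the saved program and run it with a static-graph executor,
# mirroring the load_inference_model call used in the test above.
paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())
[program, feed_names, fetch_targets] = paddle.static.load_inference_model(
    dirname="./inference",
    executor=exe,
    model_filename="gru_demo.pdmodel",
    params_filename="gru_demo.pdiparams")
results = exe.run(program,
                  feed={feed_names[0]: x.numpy()},
                  fetch_list=fetch_targets)
print(np.abs(y_dygraph.numpy() - results[0]).max())  # expected to be ~0
paddle.disable_static()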

@@ -990,7 +990,6 @@ class RNNBase(LayerList):
         self.could_use_cudnn &= direction != "backward"
         self.could_use_cudnn &= len(self.parameters()) == num_layers * 4 * (
             2 if direction == "bidirectional" else 1)
-        self.could_use_cudnn &= mode == "LSTM"  # currently only support LSTM

         # Expose params as RNN's attribute, which can make it compatible when
         # replacing small ops composed rnn with cpp rnn kernel.
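Dropping the mode == "LSTM" guard means the remaining could_use_cudnn conditions alone decide whether GRU and SimpleRNN layers also take the fused kernel path. A small hedged check; could_use_cudnn is an internal RNNBase attribute and the GPU device is an assumption:

import paddle

# Hedged sketch: with the LSTM-only check removed, a GRU built on a GPU
# device can also satisfy the remaining fused-kernel conditions above.
paddle.set_device("gpu")  # assumption: a CUDA build with a visible GPU
gru = paddle.nn.GRU(16, 32, num_layers=2, direction="bidirectional")
print(getattr(gru, "could_use_cudnn", "attribute not present"))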
@@ -1062,22 +1061,18 @@ class RNNBase(LayerList):
     def _cudnn_impl(self, inputs, initial_states, sequence_length):
         if not self.time_major:
             inputs = paddle.tensor.transpose(inputs, [1, 0, 2])
-        # unify LSTM/GRU/SimpleRNN later, currently only support LSTM
-        # TODO(guosheng): use `core.ops.cudnn_lstm` in dygraph mode if support
-        # specify output, since `dropout_state` should be a persistable tensor
-        # rather than a temporary on.
         out = self._helper.create_variable_for_type_inference(inputs.dtype)
-        last_h = self._helper.create_variable_for_type_inference(inputs.dtype)
-        last_c = self._helper.create_variable_for_type_inference(inputs.dtype)
+        state = [
+            self._helper.create_variable_for_type_inference(inputs.dtype)
+            for i in range(self.state_components)
+        ]
         reserve = self._helper.create_variable_for_type_inference(
             dtype=fluid.core.VarDesc.VarType.UINT8, stop_gradient=True)

         inputs = {
             'Input': inputs,
-            # 'W': self._flat_weight, # would be unused_var
             'WeightList': self._all_weights,
-            'InitH': initial_states[0],
-            'InitC': initial_states[1],
+            'PreState': initial_states,
             'SequenceLength': sequence_length
         }
         attrs = {
@@ -1086,23 +1081,22 @@ class RNNBase(LayerList):
             'input_size': self.input_size,
             'hidden_size': self.hidden_size,
             'num_layers': self.num_layers,
+            'mode': self.mode,
             'is_test': not self.training
         }

         outputs = {
             'Out': out,
-            'LastH': last_h,
-            'LastC': last_c,
+            'State': state,
             'Reserve': reserve,
-            'StateOut': self._dropout_state,
+            'DropoutState': self._dropout_state,
         }

         self._helper.append_op(
-            type="cudnn_lstm", inputs=inputs, outputs=outputs, attrs=attrs)
+            type="rnn", inputs=inputs, outputs=outputs, attrs=attrs)
         out = paddle.tensor.transpose(out,
                                       [1, 0, 2]) if not self.time_major else out
-        states = (last_h, last_c)
-        return out, states
+        return out, tuple(state) if len(state) > 1 else state[0]

     def forward(self, inputs, initial_states=None, sequence_length=None):
         batch_index = 1 if self.time_major else 0
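Switching from the LSTM-specific cudnn_lstm op to the unified rnn op replaces the fixed InitH/InitC inputs and LastH/LastC outputs with PreState/State lists sized by the layer's state_components (2 for LSTM, 1 for GRU and SimpleRNN), with the new mode attribute selecting the cell type. A short sketch of what that means at the Python API level (layer sizes are arbitrary):

import paddle

paddle.seed(123)
x = paddle.randn([4, 10, 16])

# LSTM has two state components, so its returned state stays an (h, c) pair.
lstm = paddle.nn.LSTM(16, 32, num_layers=2, direction="bidirectional")
y, (h, c) = lstm(x)   # state_components == 2 -> tuple(state)

# GRU and SimpleRNN carry a single state component, so the unified op hands
# back just the final hidden state (state[0] in _cudnn_impl above).
gru = paddle.nn.GRU(16, 32, num_layers=2, direction="bidirectional")
y, h = gru(x)         # state_components == 1 -> state[0]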
