|
|
|
@ -93,10 +93,14 @@ class TestSimpleRNN(unittest.TestCase):
|
|
|
|
|
np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5)
|
|
|
|
|
np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5)
|
|
|
|
|
|
|
|
|
|
def test_predict(self):
|
|
|
|
|
predict_test_util(self.place, "SimpleRNN")
|
|
|
|
|
|
|
|
|
|
def runTest(self):
|
|
|
|
|
self.test_with_initial_state()
|
|
|
|
|
self.test_with_zero_state()
|
|
|
|
|
self.test_with_input_lengths()
|
|
|
|
|
self.test_predict()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestGRU(unittest.TestCase):
|
|
|
|
@ -175,10 +179,14 @@ class TestGRU(unittest.TestCase):
|
|
|
|
|
np.testing.assert_allclose(y1, y2.numpy(), atol=1e-8, rtol=1e-5)
|
|
|
|
|
np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5)
|
|
|
|
|
|
|
|
|
|
def test_predict(self):
|
|
|
|
|
predict_test_util(self.place, "GRU")
|
|
|
|
|
|
|
|
|
|
def runTest(self):
|
|
|
|
|
self.test_with_initial_state()
|
|
|
|
|
self.test_with_zero_state()
|
|
|
|
|
self.test_with_input_lengths()
|
|
|
|
|
self.test_predict()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestLSTM(unittest.TestCase):
|
|
|
|
@ -258,18 +266,31 @@ class TestLSTM(unittest.TestCase):
|
|
|
|
|
np.testing.assert_allclose(c1, c2.numpy(), atol=1e-8, rtol=1e-5)
|
|
|
|
|
|
|
|
|
|
def test_predict(self):
|
|
|
|
|
place = paddle.set_device(self.place)
|
|
|
|
|
predict_test_util(self.place, "LSTM")
|
|
|
|
|
|
|
|
|
|
def runTest(self):
|
|
|
|
|
self.test_with_initial_state()
|
|
|
|
|
self.test_with_zero_state()
|
|
|
|
|
self.test_with_input_lengths()
|
|
|
|
|
self.test_predict()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def predict_test_util(place, mode):
|
|
|
|
|
place = paddle.set_device(place)
|
|
|
|
|
paddle.seed(123)
|
|
|
|
|
np.random.seed(123)
|
|
|
|
|
|
|
|
|
|
class Net(paddle.nn.Layer):
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super(Net, self).__init__()
|
|
|
|
|
self.rnn1 = paddle.nn.LSTM(
|
|
|
|
|
16, 32, 2, direction="bidirectional", dropout=0.1)
|
|
|
|
|
self.rnn = getattr(paddle.nn, mode)(16,
|
|
|
|
|
32,
|
|
|
|
|
2,
|
|
|
|
|
direction="bidirectional",
|
|
|
|
|
dropout=0.1)
|
|
|
|
|
|
|
|
|
|
def forward(self, input):
|
|
|
|
|
return self.rnn1(input)
|
|
|
|
|
return self.rnn(input)
|
|
|
|
|
|
|
|
|
|
x = paddle.randn((4, 10, 16))
|
|
|
|
|
x.stop_gradient = False
|
|
|
|
@ -277,7 +298,7 @@ class TestLSTM(unittest.TestCase):
|
|
|
|
|
mask = sequence_mask(seq_len, maxlen=10, dtype=x.dtype)
|
|
|
|
|
mask = paddle.unsqueeze(mask, [2])
|
|
|
|
|
rnn = Net()
|
|
|
|
|
y, (h, c) = rnn(x)
|
|
|
|
|
y, _ = rnn(x)
|
|
|
|
|
y = y * mask
|
|
|
|
|
loss = paddle.mean(y)
|
|
|
|
|
loss.backward()
|
|
|
|
@ -285,16 +306,15 @@ class TestLSTM(unittest.TestCase):
|
|
|
|
|
learning_rate=0.1, parameters=rnn.parameters())
|
|
|
|
|
optimizer.step()
|
|
|
|
|
rnn.eval()
|
|
|
|
|
y, (h, c) = rnn(x)
|
|
|
|
|
y, _ = rnn(x)
|
|
|
|
|
# `jit.to_static` would include a train_program, eval mode might cause
|
|
|
|
|
# some errors currently, such as dropout grad op gets `is_test == True`.
|
|
|
|
|
rnn.train()
|
|
|
|
|
|
|
|
|
|
rnn = paddle.jit.to_static(
|
|
|
|
|
rnn,
|
|
|
|
|
[paddle.static.InputSpec(
|
|
|
|
|
rnn, [paddle.static.InputSpec(
|
|
|
|
|
shape=[None, None, 16], dtype=x.dtype)])
|
|
|
|
|
paddle.jit.save(rnn, "./inference/lstm_infer")
|
|
|
|
|
paddle.jit.save(rnn, "./inference/%s_infer" % mode)
|
|
|
|
|
|
|
|
|
|
paddle.enable_static()
|
|
|
|
|
|
|
|
|
@ -305,8 +325,8 @@ class TestLSTM(unittest.TestCase):
|
|
|
|
|
fetch_targets] = paddle.static.load_inference_model(
|
|
|
|
|
dirname="./inference",
|
|
|
|
|
executor=exe,
|
|
|
|
|
model_filename="lstm_infer.pdmodel",
|
|
|
|
|
params_filename="lstm_infer.pdiparams")
|
|
|
|
|
model_filename="%s_infer.pdmodel" % mode,
|
|
|
|
|
params_filename="%s_infer.pdiparams" % mode)
|
|
|
|
|
results = exe.run(inference_program,
|
|
|
|
|
feed={feed_target_names[0]: x.numpy()},
|
|
|
|
|
fetch_list=fetch_targets)
|
|
|
|
@ -314,12 +334,6 @@ class TestLSTM(unittest.TestCase):
|
|
|
|
|
y.numpy(), results[0]) # eval results equal predict results
|
|
|
|
|
paddle.disable_static()
|
|
|
|
|
|
|
|
|
|
def runTest(self):
|
|
|
|
|
self.test_with_initial_state()
|
|
|
|
|
self.test_with_zero_state()
|
|
|
|
|
self.test_with_input_lengths()
|
|
|
|
|
self.test_predict()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_tests(loader, tests, pattern):
|
|
|
|
|
suite = unittest.TestSuite()
|
|
|
|
|