|
|
|
@ -295,6 +295,12 @@ class TestModel(unittest.TestCase):
|
|
|
|
|
np.testing.assert_equal(output[0].shape[0], len(self.test_dataset))
|
|
|
|
|
fluid.disable_dygraph()
|
|
|
|
|
|
|
|
|
|
def test_summary_gpu(self):
|
|
|
|
|
paddle.disable_static(self.device)
|
|
|
|
|
rnn = paddle.nn.LSTM(16, 32, 2)
|
|
|
|
|
params_info = paddle.summary(
|
|
|
|
|
rnn, [(-1, 23, 16), ((2, None, 32), (2, -1, 32))])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MyModel(paddle.nn.Layer):
|
|
|
|
|
def __init__(self):
|
|
|
|
@ -512,14 +518,33 @@ class TestModelFunction(unittest.TestCase):
|
|
|
|
|
model.summary(input_size=(20), dtype='float32')
|
|
|
|
|
|
|
|
|
|
def test_summary_nlp(self):
|
|
|
|
|
paddle.enable_static()
|
|
|
|
|
def _get_param_from_state_dict(state_dict):
|
|
|
|
|
params = 0
|
|
|
|
|
for k, v in state_dict.items():
|
|
|
|
|
params += np.prod(v.numpy().shape)
|
|
|
|
|
return params
|
|
|
|
|
|
|
|
|
|
nlp_net = paddle.nn.GRU(input_size=2,
|
|
|
|
|
hidden_size=3,
|
|
|
|
|
num_layers=3,
|
|
|
|
|
direction="bidirectional")
|
|
|
|
|
paddle.summary(nlp_net, (1, 1, 2))
|
|
|
|
|
|
|
|
|
|
rnn = paddle.nn.LSTM(16, 32, 2)
|
|
|
|
|
paddle.summary(rnn, [(-1, 23, 16), ((2, None, 32), (2, -1, 32))])
|
|
|
|
|
params_info = paddle.summary(
|
|
|
|
|
rnn, [(-1, 23, 16), ((2, None, 32), (2, -1, 32))])
|
|
|
|
|
gt_params = _get_param_from_state_dict(rnn.state_dict())
|
|
|
|
|
np.testing.assert_allclose(params_info['total_params'], gt_params / 2.0)
|
|
|
|
|
|
|
|
|
|
rnn = paddle.nn.GRU(16, 32, 2, direction='bidirectional')
|
|
|
|
|
params_info = paddle.summary(rnn, (4, 23, 16))
|
|
|
|
|
gt_params = _get_param_from_state_dict(rnn.state_dict())
|
|
|
|
|
np.testing.assert_allclose(params_info['total_params'], gt_params / 2.0)
|
|
|
|
|
|
|
|
|
|
rnn = paddle.nn.SimpleRNN(16, 32, 2, direction='bidirectional')
|
|
|
|
|
params_info = paddle.summary(rnn, (4, 23, 16))
|
|
|
|
|
gt_params = _get_param_from_state_dict(rnn.state_dict())
|
|
|
|
|
np.testing.assert_allclose(params_info['total_params'], gt_params / 2.0)
|
|
|
|
|
|
|
|
|
|
def test_summary_dtype(self):
|
|
|
|
|
input_shape = (3, 1)
|
|
|
|
|