Reproduce summary api (#27367)

* reproduce summary api
revert-27520-disable_pr
LielinJiang 4 years ago committed by GitHub
parent 29f1560d8f
commit 78a27a2b0d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1813,7 +1813,7 @@ class Model(object):
return logs, outputs return logs, outputs
return logs return logs
def summary(self, input_size=None, batch_size=None, dtype=None): def summary(self, input_size=None, dtype=None):
"""Prints a string summary of the network. """Prints a string summary of the network.
Args: Args:
@ -1822,7 +1822,6 @@ class Model(object):
one input, input_size can be tuple or InputSpec. if model have multiple one input, input_size can be tuple or InputSpec. if model have multiple
input, input_size must be a list which contain every input's shape. input, input_size must be a list which contain every input's shape.
Default: None. Default: None.
batch_size (int, optional): batch size of input tensor, Default: None.
dtypes (str, optional): if dtypes is None, 'float32' will be used, Default: None. dtypes (str, optional): if dtypes is None, 'float32' will be used, Default: None.
Returns: Returns:
@ -1859,7 +1858,7 @@ class Model(object):
_input_size = input_size _input_size = input_size
else: else:
_input_size = self._inputs _input_size = self._inputs
return summary(self.network, _input_size, batch_size, dtype) return summary(self.network, _input_size, dtype)
def _verify_spec(self, specs, is_input=False): def _verify_spec(self, specs, is_input=False):
out_specs = [] out_specs = []

File diff suppressed because it is too large Load Diff

@ -494,17 +494,22 @@ class TestModelFunction(unittest.TestCase):
model.summary(input_size=(20)) model.summary(input_size=(20))
model.summary(input_size=[(20)]) model.summary(input_size=[(20)])
model.summary(input_size=(20), batch_size=2) model.summary(input_size=(20), dtype='float32')
def test_summary_nlp(self): def test_summary_nlp(self):
paddle.enable_static() paddle.enable_static()
nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3) nlp_net = paddle.nn.GRU(input_size=2,
paddle.summary(nlp_net, (1, 2)) hidden_size=3,
num_layers=3,
direction="bidirectional")
paddle.summary(nlp_net, (1, 1, 2))
rnn = paddle.nn.LSTM(16, 32, 2)
paddle.summary(rnn, [(-1, 23, 16), ((2, None, 32), (2, -1, 32))])
def test_summary_error(self): def test_summary_error(self):
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3) nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3)
paddle.summary(nlp_net, (1, '2')) paddle.summary(nlp_net, (1, 1, '2'))
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3) nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3)
@ -512,7 +517,7 @@ class TestModelFunction(unittest.TestCase):
paddle.disable_static() paddle.disable_static()
nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3) nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3)
paddle.summary(nlp_net, (1, 2)) paddle.summary(nlp_net, (1, 1, 2))
def test_export_deploy_model(self): def test_export_deploy_model(self):
for dynamic in [True, False]: for dynamic in [True, False]:

Loading…
Cancel
Save