|
|
@ -106,6 +106,12 @@ def summary(net, input_size, dtypes=None):
|
|
|
|
warnings.warn(
|
|
|
|
warnings.warn(
|
|
|
|
"Your model was created in static mode, this may not get correct summary information!"
|
|
|
|
"Your model was created in static mode, this may not get correct summary information!"
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
in_train_mode = False
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
in_train_mode = net.training
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if in_train_mode:
|
|
|
|
|
|
|
|
net.eval()
|
|
|
|
|
|
|
|
|
|
|
|
def _is_shape(shape):
|
|
|
|
def _is_shape(shape):
|
|
|
|
for item in shape:
|
|
|
|
for item in shape:
|
|
|
@ -143,9 +149,13 @@ def summary(net, input_size, dtypes=None):
|
|
|
|
result, params_info = summary_string(net, _input_size, dtypes)
|
|
|
|
result, params_info = summary_string(net, _input_size, dtypes)
|
|
|
|
print(result)
|
|
|
|
print(result)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if in_train_mode:
|
|
|
|
|
|
|
|
net.train()
|
|
|
|
|
|
|
|
|
|
|
|
return params_info
|
|
|
|
return params_info
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@paddle.no_grad()
|
|
|
|
def summary_string(model, input_size, dtypes=None):
|
|
|
|
def summary_string(model, input_size, dtypes=None):
|
|
|
|
def _all_is_numper(items):
|
|
|
|
def _all_is_numper(items):
|
|
|
|
for item in items:
|
|
|
|
for item in items:
|
|
|
|