|
|
|
|
@ -221,7 +221,8 @@ def dynamic_flops(model, inputs, custom_ops=None, print_detail=False):
|
|
|
|
|
if m_type in custom_ops:
|
|
|
|
|
flops_fn = custom_ops[m_type]
|
|
|
|
|
if m_type not in types_collection:
|
|
|
|
|
print("Customize Function has been appied to {}".format(m_type))
|
|
|
|
|
print("Customize Function has been applied to {}".format(
|
|
|
|
|
m_type))
|
|
|
|
|
elif m_type in register_hooks:
|
|
|
|
|
flops_fn = register_hooks[m_type]
|
|
|
|
|
if m_type not in types_collection:
|
|
|
|
|
@ -254,11 +255,9 @@ def dynamic_flops(model, inputs, custom_ops=None, print_detail=False):
|
|
|
|
|
for m in model.sublayers():
|
|
|
|
|
if len(list(m.children())) > 0:
|
|
|
|
|
continue
|
|
|
|
|
total_ops += m.total_ops
|
|
|
|
|
total_params += m.total_params
|
|
|
|
|
if hasattr(m, 'total_ops') and hasattr(m, 'total_params'):
|
|
|
|
|
total_ops = int(total_ops)
|
|
|
|
|
total_params = int(total_params)
|
|
|
|
|
if hasattr(m, 'total_ops') and hasattr(m, 'total_params'):
|
|
|
|
|
total_ops += m.total_ops
|
|
|
|
|
total_params += m.total_params
|
|
|
|
|
|
|
|
|
|
if training:
|
|
|
|
|
model.train()
|
|
|
|
|
@ -277,7 +276,8 @@ def dynamic_flops(model, inputs, custom_ops=None, print_detail=False):
|
|
|
|
|
for n, m in model.named_sublayers():
|
|
|
|
|
if len(list(m.children())) > 0:
|
|
|
|
|
continue
|
|
|
|
|
if "total_ops" in m._buffers:
|
|
|
|
|
if set(['total_ops', 'total_params', 'input_shape',
|
|
|
|
|
'output_shape']).issubset(set(list(m._buffers.keys()))):
|
|
|
|
|
table.add_row([
|
|
|
|
|
m.full_name(), list(m.input_shape.numpy()),
|
|
|
|
|
list(m.output_shape.numpy()), int(m.total_params),
|
|
|
|
|
@ -289,6 +289,6 @@ def dynamic_flops(model, inputs, custom_ops=None, print_detail=False):
|
|
|
|
|
m._buffers.pop('output_shape')
|
|
|
|
|
if (print_detail):
|
|
|
|
|
print(table)
|
|
|
|
|
print('Total Flops: {} Total Params: {}'.format(total_ops,
|
|
|
|
|
total_params))
|
|
|
|
|
return total_ops
|
|
|
|
|
print('Total Flops: {} Total Params: {}'.format(
|
|
|
|
|
int(total_ops), int(total_params)))
|
|
|
|
|
return int(total_ops)
|
|
|
|
|
|