|
|
|
@ -229,7 +229,7 @@ def dynamic_flops(model, inputs, custom_ops=None, print_detail=False):
|
|
|
|
|
else:
|
|
|
|
|
if m_type not in types_collection:
|
|
|
|
|
print(
|
|
|
|
|
"Cannot find suitable count function for {}. Treat it as zero Macs.".
|
|
|
|
|
"Cannot find suitable count function for {}. Treat it as zero FLOPs.".
|
|
|
|
|
format(m_type))
|
|
|
|
|
|
|
|
|
|
if flops_fn is not None:
|
|
|
|
@ -256,9 +256,9 @@ def dynamic_flops(model, inputs, custom_ops=None, print_detail=False):
|
|
|
|
|
continue
|
|
|
|
|
total_ops += m.total_ops
|
|
|
|
|
total_params += m.total_params
|
|
|
|
|
|
|
|
|
|
total_ops = int(total_ops)
|
|
|
|
|
total_params = int(total_params)
|
|
|
|
|
if hasattr(m, 'total_ops') and hasattr(m, 'total_params'):
|
|
|
|
|
total_ops = int(total_ops)
|
|
|
|
|
total_params = int(total_params)
|
|
|
|
|
|
|
|
|
|
if training:
|
|
|
|
|
model.train()
|
|
|
|
|