add model_stat.py (#16512)
* Add a tool to summary model's PARAMS, FLOPs in paddle/fluid/contrib.revert-16555-model_data_cryption_link_all_lib
parent
d4f63d82ac
commit
e18ab78f67
@ -0,0 +1,194 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
'''
|
||||
Example:
|
||||
>>from paddle.fluid.contrib.model_stat import summary
|
||||
>>main_program = ...
|
||||
>>summary(main_program)
|
||||
+-----+------------+----------------+----------------+---------+------------+
|
||||
| No. | TYPE | INPUT | OUTPUT | PARAMs | FLOPs |
|
||||
+-----+------------+----------------+----------------+---------+------------+
|
||||
| 0 | conv2d | (3, 200, 200) | (64, 100, 100) | 9408 | 188160000 |
|
||||
| 1 | batch_norm | (64, 100, 100) | (64, 100, 100) | 256 | 640000 |
|
||||
| 2 | relu | (64, 100, 100) | (64, 100, 100) | 0 | 640000 |
|
||||
| 3 | pool2d | (64, 100, 100) | (64, 50, 50) | 0 | 1440000 |
|
||||
...
|
||||
| 176 | conv2d | (512, 7, 7) | (512, 7, 7) | 2359296 | 231211008 |
|
||||
| 177 | relu | (512, 7, 7) | (512, 7, 7) | 0 | 25088 |
|
||||
| 178 | conv2d | (512, 7, 7) | (2048, 7, 7) | 1048576 | 102760448 |
|
||||
| 179 | relu | (2048, 7, 7) | (2048, 7, 7) | 0 | 100352 |
|
||||
| 180 | pool2d | (2048, 7, 7) | (2048, 1, 1) | 0 | 100352 |
|
||||
+-----+------------+----------------+----------------+---------+------------+
|
||||
Total PARAMs: 48017344(0.0480G)
|
||||
Total FLOPs: 11692747751(11.69G)
|
||||
'''
|
||||
from collections import OrderedDict
|
||||
from prettytable import PrettyTable
|
||||
|
||||
|
||||
def summary(main_prog):
|
||||
'''
|
||||
It can summary model's PARAMS, FLOPs until now.
|
||||
It support common operator like conv, fc, pool, relu, sigmoid, bn etc.
|
||||
Args:
|
||||
main_prog: main program
|
||||
Returns:
|
||||
print summary on terminal
|
||||
'''
|
||||
collected_ops_list = []
|
||||
for one_b in main_prog.blocks:
|
||||
block_vars = one_b.vars
|
||||
for one_op in one_b.ops:
|
||||
op_info = OrderedDict()
|
||||
spf_res = _summary_model(block_vars, one_op)
|
||||
if spf_res is None:
|
||||
continue
|
||||
# TODO: get the operator name
|
||||
op_info['type'] = one_op.type
|
||||
op_info['input_shape'] = spf_res[0][1:]
|
||||
op_info['out_shape'] = spf_res[1][1:]
|
||||
op_info['PARAMs'] = spf_res[2]
|
||||
op_info['FLOPs'] = spf_res[3]
|
||||
collected_ops_list.append(op_info)
|
||||
|
||||
summary_table, total = _format_summary(collected_ops_list)
|
||||
_print_summary(summary_table, total)
|
||||
|
||||
|
||||
def _summary_model(block_vars, one_op):
|
||||
'''
|
||||
Compute operator's params and flops.
|
||||
Args:
|
||||
block_vars: all vars of one block
|
||||
one_op: one operator to count
|
||||
Returns:
|
||||
in_data_shape: one operator's input data shape
|
||||
out_data_shape: one operator's output data shape
|
||||
params: one operator's PARAMs
|
||||
flops: : one operator's FLOPs
|
||||
'''
|
||||
if one_op.type in ['conv2d', 'depthwise_conv2d']:
|
||||
k_arg_shape = block_vars[one_op.input("Filter")[0]].shape
|
||||
in_data_shape = block_vars[one_op.input("Input")[0]].shape
|
||||
out_data_shape = block_vars[one_op.output("Output")[0]].shape
|
||||
c_out, c_in, k_h, k_w = k_arg_shape
|
||||
_, c_out_, h_out, w_out = out_data_shape
|
||||
assert c_out == c_out_, 'shape error!'
|
||||
k_groups = one_op.attr("groups")
|
||||
kernel_ops = k_h * k_w * (c_in / k_groups)
|
||||
bias_ops = 0 if one_op.input("Bias") == [] else 1
|
||||
params = c_out * (kernel_ops + bias_ops)
|
||||
flops = h_out * w_out * c_out * (kernel_ops + bias_ops)
|
||||
# base nvidia paper, include mul and add
|
||||
flops = 2 * flops
|
||||
|
||||
elif one_op.type == 'pool2d':
|
||||
in_data_shape = block_vars[one_op.input("X")[0]].shape
|
||||
out_data_shape = block_vars[one_op.output("Out")[0]].shape
|
||||
_, c_out, h_out, w_out = out_data_shape
|
||||
k_size = one_op.attr("ksize")
|
||||
params = 0
|
||||
flops = h_out * w_out * c_out * (k_size[0] * k_size[1])
|
||||
|
||||
elif one_op.type == 'mul':
|
||||
k_arg_shape = block_vars[one_op.input("Y")[0]].shape
|
||||
in_data_shape = block_vars[one_op.input("X")[0]].shape
|
||||
out_data_shape = block_vars[one_op.output("Out")[0]].shape
|
||||
# TODO: fc has mul ops
|
||||
# add attr to mul op, tell us whether it belongs to 'fc'
|
||||
# this's not the best way
|
||||
if 'fc' not in one_op.output("Out")[0]:
|
||||
return None
|
||||
k_in, k_out = k_arg_shape
|
||||
# bias in sum op
|
||||
params = k_in * k_out + 1
|
||||
flops = k_in * k_out
|
||||
|
||||
elif one_op.type in ['sigmoid', 'tanh', 'relu', 'leaky_relu', 'prelu']:
|
||||
in_data_shape = block_vars[one_op.input("X")[0]].shape
|
||||
out_data_shape = block_vars[one_op.output("Out")[0]].shape
|
||||
params = 0
|
||||
if one_op.type == 'prelu':
|
||||
params = 1
|
||||
flops = 1
|
||||
for one_dim in in_data_shape:
|
||||
flops *= one_dim
|
||||
|
||||
elif one_op.type == 'batch_norm':
|
||||
in_data_shape = block_vars[one_op.input("X")[0]].shape
|
||||
out_data_shape = block_vars[one_op.output("Y")[0]].shape
|
||||
_, c_in, h_out, w_out = in_data_shape
|
||||
# gamma, beta
|
||||
params = c_in * 2
|
||||
# compute mean and std
|
||||
flops = h_out * w_out * c_in * 2
|
||||
|
||||
else:
|
||||
return None
|
||||
|
||||
return in_data_shape, out_data_shape, params, flops
|
||||
|
||||
|
||||
def _format_summary(collected_ops_list):
|
||||
'''
|
||||
Format summary report.
|
||||
Args:
|
||||
collected_ops_list: the collected operator with summary
|
||||
Returns:
|
||||
summary_table: summary report format
|
||||
total: sum param and flops
|
||||
'''
|
||||
summary_table = PrettyTable(
|
||||
["No.", "TYPE", "INPUT", "OUTPUT", "PARAMs", "FLOPs"])
|
||||
summary_table.align = 'r'
|
||||
|
||||
total = {}
|
||||
total_params = []
|
||||
total_flops = []
|
||||
for i, one_op in enumerate(collected_ops_list):
|
||||
# notice the order
|
||||
table_row = [
|
||||
i,
|
||||
one_op['type'],
|
||||
one_op['input_shape'],
|
||||
one_op['out_shape'],
|
||||
int(one_op['PARAMs']),
|
||||
int(one_op['FLOPs']),
|
||||
]
|
||||
summary_table.add_row(table_row)
|
||||
total_params.append(int(one_op['PARAMs']))
|
||||
total_flops.append(int(one_op['FLOPs']))
|
||||
|
||||
total['params'] = total_params
|
||||
total['flops'] = total_flops
|
||||
|
||||
return summary_table, total
|
||||
|
||||
|
||||
def _print_summary(summary_table, total):
|
||||
'''
|
||||
Print all the summary on terminal.
|
||||
Args:
|
||||
summary_table: summary report format
|
||||
total: sum param and flops
|
||||
'''
|
||||
parmas = total['params']
|
||||
flops = total['flops']
|
||||
print(summary_table)
|
||||
print('Total PARAMs: {}({:.4f}M)'.format(
|
||||
sum(parmas), sum(parmas) / (10**6)))
|
||||
print('Total FLOPs: {}({:.2f}G)'.format(sum(flops), sum(flops) / 10**9))
|
||||
print(
|
||||
"Notice: \n now supported ops include [Conv, DepthwiseConv, FC(mul), BatchNorm, Pool, Activation(sigmoid, tanh, relu, leaky_relu, prelu)]"
|
||||
)
|
Loading…
Reference in new issue