|
|
@ -169,13 +169,7 @@ def count_element_op(op):
|
|
|
|
def _graph_flops(graph, detail=False):
|
|
|
|
def _graph_flops(graph, detail=False):
|
|
|
|
assert isinstance(graph, GraphWrapper)
|
|
|
|
assert isinstance(graph, GraphWrapper)
|
|
|
|
flops = 0
|
|
|
|
flops = 0
|
|
|
|
try:
|
|
|
|
table = Table(["OP Type", 'Param name', "Flops"])
|
|
|
|
from prettytable import PrettyTable
|
|
|
|
|
|
|
|
except ImportError:
|
|
|
|
|
|
|
|
raise ImportError(
|
|
|
|
|
|
|
|
"paddle.flops() requires package `prettytable`, place install it firstly using `pip install prettytable`. "
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
table = PrettyTable(["OP Type", 'Param name', "Flops"])
|
|
|
|
|
|
|
|
for op in graph.ops():
|
|
|
|
for op in graph.ops():
|
|
|
|
param_name = ''
|
|
|
|
param_name = ''
|
|
|
|
if op.type() in ['conv2d', 'depthwise_conv2d']:
|
|
|
|
if op.type() in ['conv2d', 'depthwise_conv2d']:
|
|
|
@ -200,10 +194,55 @@ def _graph_flops(graph, detail=False):
|
|
|
|
table.add_row([op.type(), param_name, op_flops])
|
|
|
|
table.add_row([op.type(), param_name, op_flops])
|
|
|
|
op_flops = 0
|
|
|
|
op_flops = 0
|
|
|
|
if detail:
|
|
|
|
if detail:
|
|
|
|
print(table)
|
|
|
|
table.print_table()
|
|
|
|
return flops
|
|
|
|
return flops
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def static_flops(program, print_detail=False):
|
|
|
|
def static_flops(program, print_detail=False):
|
|
|
|
graph = GraphWrapper(program)
|
|
|
|
graph = GraphWrapper(program)
|
|
|
|
return _graph_flops(graph, detail=print_detail)
|
|
|
|
return _graph_flops(graph, detail=print_detail)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Table(object):
|
|
|
|
|
|
|
|
def __init__(self, table_heads):
|
|
|
|
|
|
|
|
self.table_heads = table_heads
|
|
|
|
|
|
|
|
self.table_len = []
|
|
|
|
|
|
|
|
self.data = []
|
|
|
|
|
|
|
|
self.col_num = len(table_heads)
|
|
|
|
|
|
|
|
for head in table_heads:
|
|
|
|
|
|
|
|
self.table_len.append(len(head))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def add_row(self, row_str):
|
|
|
|
|
|
|
|
if not isinstance(row_str, list):
|
|
|
|
|
|
|
|
print('The row_str should be a list')
|
|
|
|
|
|
|
|
if len(row_str) != self.col_num:
|
|
|
|
|
|
|
|
print(
|
|
|
|
|
|
|
|
'The length of row data should be equal the length of table heads, but the data: {} is not equal table heads {}'.
|
|
|
|
|
|
|
|
format(len(row_str), self.col_num))
|
|
|
|
|
|
|
|
for i in range(self.col_num):
|
|
|
|
|
|
|
|
if len(str(row_str[i])) > self.table_len[i]:
|
|
|
|
|
|
|
|
self.table_len[i] = len(str(row_str[i]))
|
|
|
|
|
|
|
|
self.data.append(row_str)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def print_row(self, row):
|
|
|
|
|
|
|
|
string = ''
|
|
|
|
|
|
|
|
for i in range(self.col_num):
|
|
|
|
|
|
|
|
string += '|' + str(row[i]).center(self.table_len[i] + 2)
|
|
|
|
|
|
|
|
string += '|'
|
|
|
|
|
|
|
|
print(string)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def print_shelf(self):
|
|
|
|
|
|
|
|
string = ''
|
|
|
|
|
|
|
|
for length in self.table_len:
|
|
|
|
|
|
|
|
string += '+'
|
|
|
|
|
|
|
|
string += '-' * (length + 2)
|
|
|
|
|
|
|
|
string += '+'
|
|
|
|
|
|
|
|
print(string)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def print_table(self):
|
|
|
|
|
|
|
|
self.print_shelf()
|
|
|
|
|
|
|
|
self.print_row(self.table_heads)
|
|
|
|
|
|
|
|
self.print_shelf()
|
|
|
|
|
|
|
|
for data in self.data:
|
|
|
|
|
|
|
|
self.print_row(data)
|
|
|
|
|
|
|
|
self.print_shelf()
|
|
|
|