You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
210 lines
6.3 KiB
210 lines
6.3 KiB
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import copy
|
|
import numpy as np
|
|
import paddle
|
|
from collections import OrderedDict
|
|
from paddle.static import Program, program_guard, Variable
|
|
|
|
|
|
class VarWrapper(object):
    """
    Read-only accessor around a static-graph ``Variable``.

    Args:
        var(Variable): The variable being wrapped.
        graph(GraphWrapper): The graph wrapper passed alongside ``var``;
            it is stored but not read by the accessors below.
    """

    def __init__(self, var, graph):
        assert isinstance(var, Variable)
        assert isinstance(graph, GraphWrapper)
        self._var, self._graph = var, graph

    def name(self):
        """Return the name of the wrapped variable."""
        wrapped = self._var
        return wrapped.name

    def shape(self):
        """Return the shape of the wrapped variable, as stored on it."""
        wrapped = self._var
        return wrapped.shape
|
|
|
|
|
|
class OpWrapper(object):
    """
    Read-only accessor around a static-graph operator.

    Args:
        op: The wrapped operator. It is assumed to expose ``type``,
            ``input_names``, ``output_names``, ``input(name)`` and
            ``output(name)`` (as ``paddle.static`` operators do -- TODO
            confirm for other op-like objects).
        graph(GraphWrapper): Graph used to resolve variable names into
            ``VarWrapper`` objects via ``graph.var``.
    """

    def __init__(self, op, graph):
        assert isinstance(graph, GraphWrapper)
        self._op = op
        self._graph = graph

    def type(self):
        """Return the type string of this operator."""
        return self._op.type

    def inputs(self, name):
        """
        Return the variables bound to the input slot ``name``.

        Returns an empty list when the operator has no such input slot.
        An element is None if ``graph.var`` cannot resolve its name.
        """
        if name not in self._op.input_names:
            return []
        return [self._graph.var(var_name) for var_name in self._op.input(name)]

    def outputs(self, name):
        """
        Return the variables bound to the output slot ``name``.

        Returns an empty list when the operator has no such output slot,
        mirroring ``inputs``. An element is None if ``graph.var`` cannot
        resolve its name.
        """
        # Guard like ``inputs``: asking the underlying operator for a slot it
        # does not have may raise, which would abort the whole FLOPs count.
        if name not in self._op.output_names:
            return []
        return [self._graph.var(var_name) for var_name in self._op.output(name)]
|
|
|
|
|
|
class GraphWrapper(object):
    """
    A lightweight wrapper around a static ``Program`` that exposes its
    operators and variables as ``OpWrapper`` / ``VarWrapper`` objects for
    the FLOPs-counting helpers in this module.

    Args:
        program(Program|None): The program to wrap. A new empty ``Program``
            is created when it is None.
        in_nodes(dict|list|None): A dict (or sequence of key/value pairs)
            indicating the input nodes of the graph. The key is a
            user-defined, human-readable name; the value is the name of the
            Variable. None means no input nodes.
        out_nodes(dict|list|None): A dict (or sequence of key/value pairs)
            indicating the output nodes of the graph, with the same key/value
            convention as ``in_nodes``. None means no output nodes.
    """

    def __init__(self, program=None, in_nodes=None, out_nodes=None):
        """Wrap ``program`` and record its persistable variables by name."""
        super(GraphWrapper, self).__init__()
        self.program = Program() if program is None else program
        self.persistables = {}
        self.teacher_persistables = {}
        for var in self.program.list_vars():
            if var.persistable:
                self.persistables[var.name] = var
        self.compiled_graph = None
        # None defaults (instead of shared mutable ``[]``) mean "no nodes".
        in_nodes = [] if in_nodes is None else in_nodes
        out_nodes = [] if out_nodes is None else out_nodes
        self.in_nodes = OrderedDict(in_nodes)
        self.out_nodes = OrderedDict(out_nodes)
        self._attrs = OrderedDict()

    def ops(self):
        """
        Return every operator of every block in the program, in block order,
        as a list of ``OpWrapper``.
        """
        return [
            OpWrapper(op, self) for block in self.program.blocks for op in block.ops
        ]

    def var(self, name):
        """
        Return a ``VarWrapper`` for ``name`` taken from the first block that
        defines it, or None if no block has a variable with that name.
        """
        for block in self.program.blocks:
            if block.has_var(name):
                return VarWrapper(block.var(name), self)
        return None
|
|
|
|
|
|
def count_convNd(op):
    """
    Count the operations of a convolution operator for one sample.

    Each output element costs ``prod(filter_shape[1:])`` multiply-accumulates
    plus one extra operation when the op has a ``Bias`` input. The leading
    (batch) dimension of the output shape is excluded.

    Args:
        op(OpWrapper): A convolution operator with ``Filter`` and ``Output``
            slots and an optional ``Bias`` input.

    Returns:
        The operation count (a NumPy integer for integer shapes).
    """
    filter_shape = op.inputs("Filter")[0].shape()
    # np.prod replaces the np.product alias, which NumPy 2.0 removed.
    filter_ops = np.prod(filter_shape[1:])
    bias_ops = 1 if op.inputs("Bias") else 0
    output_numel = np.prod(op.outputs("Output")[0].shape()[1:])
    return output_numel * (filter_ops + bias_ops)
|
|
|
|
|
|
def count_leaky_relu(op):
    """
    Count the operations of a ``leaky_relu`` operator: one per output
    element, excluding the leading (batch) dimension.

    Args:
        op(OpWrapper): A ``leaky_relu`` operator.

    Returns:
        The operation count.
    """
    # leaky_relu, like other activation ops, writes its result to the ``Out``
    # slot; the former ``Output`` lookup did not match any slot of this op.
    total_ops = np.prod(op.outputs("Out")[0].shape()[1:])
    return total_ops
|
|
|
|
|
|
def count_bn(op):
    """
    Count the operations of a ``batch_norm`` operator: two per element of
    the ``Y`` output, excluding the leading (batch) dimension.

    Args:
        op(OpWrapper): A ``batch_norm`` operator.

    Returns:
        The operation count.
    """
    # np.prod replaces the np.product alias, which NumPy 2.0 removed.
    output_numel = np.prod(op.outputs("Y")[0].shape()[1:])
    return 2 * output_numel
|
|
|
|
|
|
def count_linear(op):
    """
    Count the operations of a ``mul``/``matmul`` operator for one sample.

    The cost is ``Y.shape[0]`` times the number of ``Out`` elements,
    excluding the leading (batch) dimension of ``Out``.

    NOTE(review): using ``Y.shape[0]`` as the reduction size assumes ``Y`` is
    a 2-D ``[K, N]`` weight; confirm for transposed or batched ``matmul``.

    Args:
        op(OpWrapper): A ``mul`` or ``matmul`` operator.

    Returns:
        The operation count.
    """
    total_mul = op.inputs("Y")[0].shape()[0]
    # np.prod replaces the np.product alias, which NumPy 2.0 removed.
    numel = np.prod(op.outputs("Out")[0].shape()[1:])
    return total_mul * numel
|
|
|
|
|
|
def count_pool2d(op):
    """
    Count the operations of a ``pool2d`` operator for one sample.

    The kernel size is inferred as ``input_spatial // output_spatial``. This
    is exact for non-overlapping pooling and presumably an approximation
    for other stride/padding settings. Each output element then costs
    ``prod(kernel)`` additions plus one division.

    Args:
        op(OpWrapper): A ``pool2d`` operator with ``X`` and ``Out`` slots.

    Returns:
        The operation count, excluding the leading (batch) dimension.
    """
    input_shape = op.inputs("X")[0].shape()
    output_shape = op.outputs('Out')[0].shape()
    kernel = np.array(input_shape[2:]) // np.array(output_shape[2:])
    # np.prod replaces the np.product alias, which NumPy 2.0 removed.
    total_add = np.prod(kernel)
    total_div = 1
    kernel_ops = total_add + total_div
    num_elements = np.prod(output_shape[1:])
    return kernel_ops * num_elements
|
|
|
|
|
|
def count_element_op(op):
    """
    Count the operations of an elementwise operator: one per element of
    input ``X``, excluding the leading (batch) dimension.

    Args:
        op(OpWrapper): An ``elementwise_*`` operator with an ``X`` input.

    Returns:
        The operation count.
    """
    input_shape = op.inputs("X")[0].shape()
    # np.prod replaces the np.product alias, which NumPy 2.0 removed.
    return np.prod(input_shape[1:])
|
|
|
|
|
|
def _graph_flops(graph, detail=False):
    """
    Sum the operation counts of all supported operators in ``graph``.

    Supported op types are ``conv2d``/``depthwise_conv2d``, ``pool2d``,
    ``mul``/``matmul``, ``batch_norm`` and every type starting with
    ``element``. Other operators contribute nothing.

    Args:
        graph(GraphWrapper): The graph to analyse.
        detail(bool): If True, print a per-operator table.

    Returns:
        The total operation count.

    Raises:
        ImportError: If the ``prettytable`` package is not installed.
    """
    assert isinstance(graph, GraphWrapper)
    try:
        from prettytable import PrettyTable
    except ImportError as err:
        raise ImportError(
            "paddle.flops() requires package `prettytable`, please install it first using `pip install prettytable`. "
        ) from err

    flops = 0
    table = PrettyTable(["OP Type", 'Param name', "Flops"])
    for op in graph.ops():
        op_type = op.type()
        # Reset for every op: previously an unsupported first op (e.g. a feed
        # op) left ``op_flops`` unbound and raised UnboundLocalError below.
        op_flops = 0
        param_name = ''
        if op_type in ['conv2d', 'depthwise_conv2d']:
            op_flops = count_convNd(op)
            param_name = op.inputs("Filter")[0].name()
        elif op_type == 'pool2d':
            op_flops = count_pool2d(op)
        elif op_type in ['mul', 'matmul']:
            op_flops = count_linear(op)
            param_name = op.inputs("Y")[0].name()
        elif op_type == 'batch_norm':
            op_flops = count_bn(op)
        elif op_type.startswith('element'):
            op_flops = count_element_op(op)
        flops += op_flops
        if op_flops != 0:
            table.add_row([op_type, param_name, op_flops])
    if detail:
        print(table)
    return flops
|
|
|
|
|
|
def static_flops(program, print_detail=False):
    """
    Count the operations of a static-graph ``program``.

    Args:
        program(Program): The program to analyse.
        print_detail(bool): Whether to print the per-operator table.

    Returns:
        The total operation count reported by ``_graph_flops``.
    """
    wrapped_graph = GraphWrapper(program)
    total = _graph_flops(wrapped_graph, detail=print_detail)
    return total
|