add hapi api flops (#28755)

* add hapi api flops

* fix bug

* fix some bug

* add unit test

* fix unit test

* solve ci coverage

* fix doc

* fix doc

* fix static flops

* delete the comment

* fix some grammar problem in doc

* fix some bug

* fix some doc

* fix some doc
musl/disable_test_yolov3_temporarily
yukavio 4 years ago committed by GitHub
parent db85f4cf8f
commit 63e90ee331
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -275,6 +275,7 @@ from . import onnx
from .hapi import Model from .hapi import Model
from .hapi import callbacks from .hapi import callbacks
from .hapi import summary from .hapi import summary
from .hapi import flops
import paddle.text import paddle.text
import paddle.vision import paddle.vision

@ -13,13 +13,15 @@
# limitations under the License. # limitations under the License.
from . import logger from . import logger
from . import callbacks from . import callbacks #DEFINE_ALIAS
from . import model_summary from . import model_summary
from . import model from . import model
from .model import * from .model import *
from .model_summary import summary from .model_summary import summary #DEFINE_ALIAS
from .dynamic_flops import flops #DEFINE_ALIAS
logger.setup_logger() logger.setup_logger()
__all__ = ['callbacks'] + model.__all__ + ['summary'] __all__ = ['callbacks'] + model.__all__ + ['summary']
__all__ = model.__all__ + ['flops']

File diff suppressed because it is too large Load Diff

@ -0,0 +1,204 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import numpy as np
import paddle
from prettytable import PrettyTable
from collections import OrderedDict
from paddle.static import Program, program_guard, Variable
class VarWrapper(object):
def __init__(self, var, graph):
assert isinstance(var, Variable)
assert isinstance(graph, GraphWrapper)
self._var = var
self._graph = graph
def name(self):
"""
Get the name of the variable.
"""
return self._var.name
def shape(self):
"""
Get the shape of the varibale.
"""
return self._var.shape
class OpWrapper(object):
def __init__(self, op, graph):
assert isinstance(graph, GraphWrapper)
self._op = op
self._graph = graph
def type(self):
"""
Get the type of this operator.
"""
return self._op.type
def inputs(self, name):
"""
Get all the varibales by the input name.
"""
if name in self._op.input_names:
return [
self._graph.var(var_name) for var_name in self._op.input(name)
]
else:
return []
def outputs(self, name):
"""
Get all the varibales by the output name.
"""
return [self._graph.var(var_name) for var_name in self._op.output(name)]
class GraphWrapper(object):
"""
It is a wrapper of paddle.fluid.framework.IrGraph with some special functions
for paddle slim framework.
Args:
program(framework.Program): A program with
in_nodes(dict): A dict to indicate the input nodes of the graph.
The key is user-defined and human-readable name.
The value is the name of Variable.
out_nodes(dict): A dict to indicate the input nodes of the graph.
The key is user-defined and human-readable name.
The value is the name of Variable.
"""
def __init__(self, program=None, in_nodes=[], out_nodes=[]):
"""
"""
super(GraphWrapper, self).__init__()
self.program = Program() if program is None else program
self.persistables = {}
self.teacher_persistables = {}
for var in self.program.list_vars():
if var.persistable:
self.persistables[var.name] = var
self.compiled_graph = None
in_nodes = [] if in_nodes is None else in_nodes
out_nodes = [] if out_nodes is None else out_nodes
self.in_nodes = OrderedDict(in_nodes)
self.out_nodes = OrderedDict(out_nodes)
self._attrs = OrderedDict()
def ops(self):
"""
Return all operator nodes included in the graph as a set.
"""
ops = []
for block in self.program.blocks:
for op in block.ops:
ops.append(OpWrapper(op, self))
return ops
def var(self, name):
"""
Get the variable by variable name.
"""
for block in self.program.blocks:
if block.has_var(name):
return VarWrapper(block.var(name), self)
return None
def count_convNd(op):
filter_shape = op.inputs("Filter")[0].shape()
filter_ops = np.product(filter_shape[1:])
bias_ops = 1 if len(op.inputs("Bias")) > 0 else 0
output_numel = np.product(op.outputs("Output")[0].shape()[1:])
total_ops = output_numel * (filter_ops + bias_ops)
return total_ops
def count_leaky_relu(op):
total_ops = np.product(op.outputs("Output")[0].shape()[1:])
return total_ops
def count_bn(op):
output_numel = np.product(op.outputs("Y")[0].shape()[1:])
total_ops = 2 * output_numel
return total_ops
def count_linear(op):
total_mul = op.inputs("Y")[0].shape()[0]
numel = np.product(op.outputs("Out")[0].shape()[1:])
total_ops = total_mul * numel
return total_ops
def count_pool2d(op):
input_shape = op.inputs("X")[0].shape()
output_shape = op.outputs('Out')[0].shape()
kernel = np.array(input_shape[2:]) // np.array(output_shape[2:])
total_add = np.product(kernel)
total_div = 1
kernel_ops = total_add + total_div
num_elements = np.product(output_shape[1:])
total_ops = kernel_ops * num_elements
return total_ops
def count_element_op(op):
input_shape = op.inputs("X")[0].shape()
total_ops = np.product(input_shape[1:])
return total_ops
def _graph_flops(graph, detail=False):
assert isinstance(graph, GraphWrapper)
flops = 0
table = PrettyTable(["OP Type", 'Param name', "Flops"])
for op in graph.ops():
param_name = ''
if op.type() in ['conv2d', 'depthwise_conv2d']:
op_flops = count_convNd(op)
flops += op_flops
param_name = op.inputs("Filter")[0].name()
elif op.type() == 'pool2d':
op_flops = count_pool2d(op)
flops += op_flops
elif op.type() in ['mul', 'matmul']:
op_flops = count_linear(op)
flops += op_flops
param_name = op.inputs("Y")[0].name()
elif op.type() == 'batch_norm':
op_flops = count_bn(op)
flops += op_flops
elif op.type().startswith('element'):
op_flops = count_element_op(op)
flops += op_flops
if op_flops != 0:
table.add_row([op.type(), param_name, op_flops])
op_flops = 0
if detail:
print(table)
return flops
def static_flops(program, print_detail=False):
graph = GraphWrapper(program)
return _graph_flops(graph, detail=print_detail)

@ -33,6 +33,8 @@ from paddle.nn.layer.loss import CrossEntropyLoss
from paddle.metric import Accuracy from paddle.metric import Accuracy
from paddle.vision.datasets import MNIST from paddle.vision.datasets import MNIST
from paddle.vision.models import LeNet from paddle.vision.models import LeNet
import paddle.vision.models as models
import paddle.fluid.dygraph.jit as jit
from paddle.io import DistributedBatchSampler, Dataset from paddle.io import DistributedBatchSampler, Dataset
from paddle.hapi.model import prepare_distributed_context from paddle.hapi.model import prepare_distributed_context
from paddle.fluid.dygraph.jit import declarative from paddle.fluid.dygraph.jit import declarative
@ -546,6 +548,24 @@ class TestModelFunction(unittest.TestCase):
gt_params = _get_param_from_state_dict(rnn.state_dict()) gt_params = _get_param_from_state_dict(rnn.state_dict())
np.testing.assert_allclose(params_info['total_params'], gt_params / 2.0) np.testing.assert_allclose(params_info['total_params'], gt_params / 2.0)
def test_static_flops(self):
paddle.disable_static()
net = models.__dict__['mobilenet_v2'](pretrained=False)
inputs = paddle.randn([1, 3, 224, 224])
static_program = jit._trace(net, inputs=[inputs])[1]
paddle.flops(static_program, [1, 3, 224, 224], print_detail=True)
def test_dynamic_flops(self):
net = models.__dict__['mobilenet_v2'](pretrained=False)
def customize_dropout(m, x, y):
m.total_ops += 0
paddle.flops(
net, [1, 3, 224, 224],
custom_ops={paddle.nn.Dropout: customize_dropout},
print_detail=True)
def test_summary_dtype(self): def test_summary_dtype(self):
input_shape = (3, 1) input_shape = (3, 1)
net = paddle.nn.Embedding(10, 3, sparse=True) net = paddle.nn.Embedding(10, 3, sparse=True)

Loading…
Cancel
Save