|
|
|
@ -18,12 +18,13 @@ from ..... import compat as cpt
|
|
|
|
|
from .... import core
|
|
|
|
|
from ....framework import IrGraph
|
|
|
|
|
from ....framework import IrNode
|
|
|
|
|
from ....framework import Operator
|
|
|
|
|
from .... import unique_name
|
|
|
|
|
|
|
|
|
|
__all__ = [
|
|
|
|
|
'QuantizationTransformPass', 'QuantizationFreezePass', 'ConvertToInt8Pass',
|
|
|
|
|
'TransformForMobilePass', 'ScaleForTrainingPass', 'ScaleForInferencePass',
|
|
|
|
|
'AddQuantDequantPass'
|
|
|
|
|
'TransformForMobilePass', 'OutScaleForTrainingPass',
|
|
|
|
|
'OutScaleForInferencePass', 'AddQuantDequantPass'
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
_fake_quant_op_list = [
|
|
|
|
@ -40,9 +41,9 @@ _fake_quant_dequant_op_list = [
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
_out_scale_op_list = [
|
|
|
|
|
"mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid", "depthwise_conv2d",
|
|
|
|
|
"batch_norm", "concat", "tanh", "pad", "elementwise_add", "elementwise_mul",
|
|
|
|
|
"dropout", "split", "prelu", "conv2d_transpose", "leaky_relu"
|
|
|
|
|
"conv2d", "depthwise_conv2d", "mul", "matmul", "relu", "leaky_relu",
|
|
|
|
|
"relu6", "sigmoid", "tanh", "prelu", "swish", "softmax", "batch_norm",
|
|
|
|
|
"elementwise_add", "pool2d", "reshape2", "transpose2"
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
# List the real input and output names of each op, to avoid processing inputs such as AxisTensor.
|
|
|
|
@ -67,6 +68,7 @@ _op_real_in_out_name = {
|
|
|
|
|
"not_equal": [["X", "Y"], ["Out"]],
|
|
|
|
|
"reshape": [["X"], ["Out"]],
|
|
|
|
|
"reshape2": [["X"], ["Out"]],
|
|
|
|
|
"transpose2": [["X"], ["Out"]],
|
|
|
|
|
"bilinear_interp": [["X"], ["Out"]],
|
|
|
|
|
"nearest_interp": [["X"], ["Out"]],
|
|
|
|
|
"trilinear_interp": [["X"], ["Out"]],
|
|
|
|
@ -76,11 +78,49 @@ _op_real_in_out_name = {
|
|
|
|
|
"relu": [["X"], ["Out"]],
|
|
|
|
|
"relu6": [["X"], ["Out"]],
|
|
|
|
|
"leaky_relu": [["X"], ["Out"]],
|
|
|
|
|
"prelu": [["X"], ["Out"]],
|
|
|
|
|
"tanh": [["X"], ["Out"]],
|
|
|
|
|
"swish": [["X"], ["Out"]],
|
|
|
|
|
"dropout": [["X"], ["Out"]],
|
|
|
|
|
"batch_norm": [["X"], ["Y"]],
|
|
|
|
|
"sigmoid": [["X"], ["Y"]],
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_op_input_var_names(op):
    """
    Collect the variable names bound to the real input slots of an op.

    The slot names are read from the first entry of
    `_op_real_in_out_name[<op name>]`, where the op name is `op.name()` for
    an IrNode and `op.type` for an Operator.

    Args:
        op(IrNode|Operator): The op whose input variable names are gathered.

    Returns:
        list: The variable names of every listed input slot, in slot order.

    Raises:
        AssertionError: If `op` is neither an IrNode nor an Operator.
        KeyError: If the op name has no entry in `_op_real_in_out_name`.
    """
    assert isinstance(op, (IrNode, Operator)), \
        "The input op should be IrNode or Operator."
    if isinstance(op, IrNode):
        op_type = op.name()
    else:
        op_type = op.type
    input_slots = _op_real_in_out_name[op_type][0]
    collected = []
    for slot in input_slots:
        slot_vars = op.input(slot)
        # A slot may yield a list of names or a single value; normalize to
        # a list so both cases are merged the same way.
        if not isinstance(slot_vars, list):
            slot_vars = [slot_vars]
        collected.extend(slot_vars)
    return collected
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_op_output_var_names(op):
    """
    Collect the variable names bound to the real output slots of an op.

    The slot names are read from the second entry of
    `_op_real_in_out_name[<op name>]`, where the op name is `op.name()` for
    an IrNode and `op.type` for an Operator.

    Args:
        op(IrNode|Operator): The op whose output variable names are gathered.

    Returns:
        list: The variable names of every listed output slot, in slot order.

    Raises:
        AssertionError: If `op` is neither an IrNode nor an Operator.
        KeyError: If the op name has no entry in `_op_real_in_out_name`.
    """
    assert isinstance(op, (IrNode, Operator)), \
        "The input op should be IrNode or Operator."
    if isinstance(op, IrNode):
        op_type = op.name()
    else:
        op_type = op.type
    output_slots = _op_real_in_out_name[op_type][1]
    collected = []
    for slot in output_slots:
        slot_vars = op.output(slot)
        # A slot may yield a list of names or a single value; normalize to
        # a list so both cases are merged the same way.
        if not isinstance(slot_vars, list):
            slot_vars = [slot_vars]
        collected.extend(slot_vars)
    return collected
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _init_var_node(var_node, value, scope, place):
|
|
|
|
|
assert isinstance(value,
|
|
|
|
|
np.ndarray), 'The type of value should be numpy array.'
|
|
|
|
@ -97,17 +137,18 @@ def _is_input_all_not_persistable(graph, op_node):
|
|
|
|
|
Analyse the real inputs of the op node are all not persistable.
|
|
|
|
|
'''
|
|
|
|
|
is_input_all_not_persistable = True
|
|
|
|
|
op_node_name = op_node.name()
|
|
|
|
|
input_name_list = _op_real_in_out_name[op_node_name][0]
|
|
|
|
|
for input_name in input_name_list:
|
|
|
|
|
for arg_name in op_node.input(input_name):
|
|
|
|
|
in_node = graph._find_node_by_name(op_node.inputs, arg_name)
|
|
|
|
|
is_input_all_not_persistable = (is_input_all_not_persistable and \
|
|
|
|
|
(not in_node.persistable()))
|
|
|
|
|
for var_name in _get_op_input_var_names(op_node):
|
|
|
|
|
in_node = graph._find_node_by_name(op_node.inputs, var_name)
|
|
|
|
|
is_input_all_not_persistable = (is_input_all_not_persistable and \
|
|
|
|
|
(not in_node.persistable()))
|
|
|
|
|
return is_input_all_not_persistable
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class QuantizationTransformPass(object):
|
|
|
|
|
"""
|
|
|
|
|
Quantize the ops that have weights. Add quant and dequant ops for the quantized
|
|
|
|
|
ops's inputs.
|
|
|
|
|
"""
|
|
|
|
|
_supported_quantizable_op_type = [
|
|
|
|
|
'conv2d', 'depthwise_conv2d', 'mul', 'matmul'
|
|
|
|
|
]
|
|
|
|
@ -124,8 +165,7 @@ class QuantizationTransformPass(object):
|
|
|
|
|
skip_pattern=['skip_quant'],
|
|
|
|
|
quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul']):
|
|
|
|
|
"""
|
|
|
|
|
Convert and rewrite the IrGraph according to weight and
|
|
|
|
|
activation quantization type.
|
|
|
|
|
Constructor.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
scope(fluid.Scope): When activation use 'range_abs_max' as the quantize
|
|
|
|
@ -1088,7 +1128,7 @@ class TransformForMobilePass(object):
|
|
|
|
|
return graph
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ScaleForTrainingPass(object):
|
|
|
|
|
class OutScaleForTrainingPass(object):
|
|
|
|
|
def __init__(self, scope=None, place=None, moving_rate=0.9):
|
|
|
|
|
"""
|
|
|
|
|
This pass is used for calculating output scales of some operators.
|
|
|
|
@ -1195,7 +1235,7 @@ class ScaleForTrainingPass(object):
|
|
|
|
|
return "%s@scale" % (var_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ScaleForInferencePass(object):
|
|
|
|
|
class OutScaleForInferencePass(object):
|
|
|
|
|
def __init__(self, scope=None):
|
|
|
|
|
"""
|
|
|
|
|
This pass is used for setting output scales of some operators.
|
|
|
|
@ -1226,7 +1266,7 @@ class ScaleForInferencePass(object):
|
|
|
|
|
scale_name = self._scale_name(op_node.output_arg_names()[0])
|
|
|
|
|
scale_v = np.array(
|
|
|
|
|
self._scope.find_var(scale_name).get_tensor())[0]
|
|
|
|
|
op_node.op()._set_attr("out_scale", float(scale_v))
|
|
|
|
|
op_node.op()._set_attr("out_threshold", float(scale_v))
|
|
|
|
|
graph.resolve_hazard()
|
|
|
|
|
return graph
|
|
|
|
|
|
|
|
|
@ -1238,6 +1278,10 @@ class ScaleForInferencePass(object):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AddQuantDequantPass(object):
|
|
|
|
|
"""
|
|
|
|
|
Quantize the ops that do not have weights, and add quant_dequant op for the
|
|
|
|
|
quantized ops's inputs.
|
|
|
|
|
"""
|
|
|
|
|
_supported_quantizable_op_type = [
|
|
|
|
|
"pool2d", "elementwise_add", "concat", "softmax", "argmax", "transpose",
|
|
|
|
|
"equal", "gather", "greater_equal", "greater_than", "less_equal",
|
|
|
|
@ -1259,9 +1303,7 @@ class AddQuantDequantPass(object):
|
|
|
|
|
quantizable_op_type=["elementwise_add", "pool2d"],
|
|
|
|
|
is_full_quantized=False):
|
|
|
|
|
"""
|
|
|
|
|
This pass add quant_dequant op for some ops, of which all the inputs must be
|
|
|
|
|
not persistable.
|
|
|
|
|
The input scales can be obtained from the quant_dequant op.
|
|
|
|
|
Constructor.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
scope(fluid.Scope): The scope is used to initialize these new parameters.
|
|
|
|
@ -1338,10 +1380,7 @@ class AddQuantDequantPass(object):
|
|
|
|
|
op_node.op()._set_attr("quantization_type",
|
|
|
|
|
"qat_without_weight")
|
|
|
|
|
op_node.op()._set_attr("activation_bits", self._quant_bits)
|
|
|
|
|
input_name_list = _op_real_in_out_name[op_node.name()][0]
|
|
|
|
|
arg_names = []
|
|
|
|
|
for input_name in input_name_list:
|
|
|
|
|
arg_names.extend(op_node.input(input_name))
|
|
|
|
|
arg_names = _get_op_input_var_names(op_node)
|
|
|
|
|
for arg_name in arg_names:
|
|
|
|
|
in_node = graph._find_node_by_name(op_node.inputs, arg_name)
|
|
|
|
|
if arg_name in dequantized_vars_map:
|
|
|
|
|