|
|
|
@ -26,14 +26,31 @@ __all__ = [
|
|
|
|
|
'AddQuantDequantPass'
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
_quantizable_op_list = ['conv2d', 'depthwise_conv2d', 'mul', 'pool2d']
|
|
|
|
|
|
|
|
|
|
_fake_quant_op_list = [
|
|
|
|
|
'fake_quantize_abs_max', 'fake_quantize_range_abs_max',
|
|
|
|
|
'fake_quantize_moving_average_abs_max', 'fake_channel_wise_quantize_abs_max'
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
_fake_dequant_op_list = [
|
|
|
|
|
'fake_dequantize_max_abs', 'fake_channel_wise_dequantize_max_abs'
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
_out_scale_op_list = [
|
|
|
|
|
"mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid", "depthwise_conv2d",
|
|
|
|
|
"batch_norm", "concat", "tanh", "pad", "elementwise_add", "elementwise_mul",
|
|
|
|
|
"dropout", "split", "prelu", "conv2d_transpose", "leaky_relu"
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _init_var_node(var_node, value, scope, place):
|
|
|
|
|
assert isinstance(value,
|
|
|
|
|
np.ndarray), 'The type of value should be numpy array.'
|
|
|
|
|
assert scope is not None, \
|
|
|
|
|
'The scope cannot be set None.'
|
|
|
|
|
'The scope cannot be set None.'
|
|
|
|
|
assert place is not None, \
|
|
|
|
|
'The place cannot be set None.'
|
|
|
|
|
'The place cannot be set None.'
|
|
|
|
|
tensor = scope.var(var_node.name()).get_tensor()
|
|
|
|
|
tensor.set(value, place)
|
|
|
|
|
|
|
|
|
@ -47,7 +64,8 @@ class QuantizationTransformPass(object):
|
|
|
|
|
activation_quantize_type='abs_max',
|
|
|
|
|
weight_quantize_type='abs_max',
|
|
|
|
|
window_size=10000,
|
|
|
|
|
moving_rate=0.9):
|
|
|
|
|
moving_rate=0.9,
|
|
|
|
|
skip_pattern='skip_quant'):
|
|
|
|
|
"""
|
|
|
|
|
Convert and rewrite the IrGraph according to weight and
|
|
|
|
|
activation quantization type.
|
|
|
|
@ -92,6 +110,7 @@ class QuantizationTransformPass(object):
|
|
|
|
|
self._place = place
|
|
|
|
|
self._weight_bits = weight_bits
|
|
|
|
|
self._activation_bits = activation_bits
|
|
|
|
|
self._skip_pattern = skip_pattern
|
|
|
|
|
|
|
|
|
|
quant_type = [
|
|
|
|
|
'abs_max', 'channel_wise_abs_max', 'range_abs_max',
|
|
|
|
@ -114,7 +133,7 @@ class QuantizationTransformPass(object):
|
|
|
|
|
self._window_size = window_size
|
|
|
|
|
self._moving_rate = moving_rate
|
|
|
|
|
|
|
|
|
|
self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
|
|
|
|
|
self._quantizable_ops = _quantizable_op_list
|
|
|
|
|
self._conv_ops = ['conv2d', 'depthwise_conv2d']
|
|
|
|
|
self._quantizable_grad_ops = [
|
|
|
|
|
'%s_grad' % (op) for op in self._quantizable_ops
|
|
|
|
@ -138,6 +157,16 @@ class QuantizationTransformPass(object):
|
|
|
|
|
dequantized_vars = collections.OrderedDict()
|
|
|
|
|
persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
|
|
|
|
|
|
|
|
|
|
def _quant_preprocess(op_node):
|
|
|
|
|
pool_skipped = op_node.op().has_attr("pooling_type") and \
|
|
|
|
|
op_node.op().attr("pooling_type") == 'avg'
|
|
|
|
|
user_skipped = isinstance(self._skip_pattern, str) and \
|
|
|
|
|
op_node.op().has_attr("op_namescope") and \
|
|
|
|
|
op_node.op().attr("op_namescope").find(self._skip_pattern) != -1
|
|
|
|
|
|
|
|
|
|
if pool_skipped or user_skipped:
|
|
|
|
|
op_node.op()._set_attr("skip_quant", True)
|
|
|
|
|
|
|
|
|
|
def _transform_forward(graph, op):
|
|
|
|
|
for var_node in op.inputs:
|
|
|
|
|
if var_node.name() not in op.input_arg_names():
|
|
|
|
@ -188,14 +217,28 @@ class QuantizationTransformPass(object):
|
|
|
|
|
if not self._is_test:
|
|
|
|
|
self._create_global_step(graph)
|
|
|
|
|
ops = graph.all_op_nodes()
|
|
|
|
|
# Do the preproccess of quantization, such as skipping some ops
|
|
|
|
|
# for not being quantized.
|
|
|
|
|
for op in ops:
|
|
|
|
|
if op.name() in self._quantizable_ops or \
|
|
|
|
|
op.name() in self._quantizable_grad_ops:
|
|
|
|
|
_quant_preprocess(op)
|
|
|
|
|
# The process of _transform_forward and _transform_backward is needed in two for loops.
|
|
|
|
|
# The loop for transforming the forward graph:
|
|
|
|
|
for op in ops:
|
|
|
|
|
if op.name() in self._quantizable_ops:
|
|
|
|
|
skipped = op.op().has_attr("skip_quant") and \
|
|
|
|
|
op.op().attr("skip_quant")
|
|
|
|
|
if skipped:
|
|
|
|
|
continue
|
|
|
|
|
_transform_forward(graph, op)
|
|
|
|
|
# The loop for renaming the inputs of backward op.
|
|
|
|
|
for op in ops:
|
|
|
|
|
if op.name() in self._quantizable_grad_ops:
|
|
|
|
|
skipped = op.op().has_attr("skip_quant") and \
|
|
|
|
|
op.op().attr("skip_quant")
|
|
|
|
|
if skipped:
|
|
|
|
|
continue
|
|
|
|
|
_transform_backward(graph, op)
|
|
|
|
|
graph.resolve_hazard()
|
|
|
|
|
return graph
|
|
|
|
@ -571,16 +614,10 @@ class QuantizationFreezePass(object):
|
|
|
|
|
self._weight_bits = weight_bits
|
|
|
|
|
self._activation_bits = activation_bits
|
|
|
|
|
self._weight_quantize_type = weight_quantize_type
|
|
|
|
|
self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
|
|
|
|
|
self._quantizable_ops = _quantizable_op_list
|
|
|
|
|
self._conv_ops = ['conv2d', 'depthwise_conv2d']
|
|
|
|
|
self._fake_quant_op_names = [
|
|
|
|
|
'fake_quantize_abs_max', 'fake_quantize_range_abs_max',
|
|
|
|
|
'fake_quantize_moving_average_abs_max',
|
|
|
|
|
'fake_channel_wise_quantize_abs_max'
|
|
|
|
|
]
|
|
|
|
|
self._fake_dequant_op_names = [
|
|
|
|
|
'fake_dequantize_max_abs', 'fake_channel_wise_dequantize_max_abs'
|
|
|
|
|
]
|
|
|
|
|
self._fake_quant_op_names = _fake_quant_op_list
|
|
|
|
|
self._fake_dequant_op_names = _fake_dequant_op_list
|
|
|
|
|
self._op_input_rename_map = collections.OrderedDict()
|
|
|
|
|
self._op_output_rename_map = collections.OrderedDict()
|
|
|
|
|
self._var_scale_map = collections.OrderedDict()
|
|
|
|
@ -635,6 +672,10 @@ class QuantizationFreezePass(object):
|
|
|
|
|
for op_node in ops:
|
|
|
|
|
op_name = op_node.name()
|
|
|
|
|
if op_name in self._quantizable_ops:
|
|
|
|
|
skipped = op_node.op().has_attr("skip_quant") and \
|
|
|
|
|
op_node.op().attr("skip_quant")
|
|
|
|
|
if skipped:
|
|
|
|
|
continue
|
|
|
|
|
if self._weight_quantize_type == 'channel_wise_abs_max' and op_name in self._conv_ops:
|
|
|
|
|
self._insert_post_channel_dequant_op(graph, op_node)
|
|
|
|
|
else:
|
|
|
|
@ -727,6 +768,13 @@ class QuantizationFreezePass(object):
|
|
|
|
|
|
|
|
|
|
def _insert_post_dequant_op(self, graph, op_node):
|
|
|
|
|
persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
|
|
|
|
|
if len(op_node.input_arg_names()) >= 2 and len(persistable_vars) == 0:
|
|
|
|
|
raise ValueError("The op %s has more than one inputs "
|
|
|
|
|
"and all of them are not persistable. "
|
|
|
|
|
"Now, it is not supported!" % (op_node.name()))
|
|
|
|
|
max_range = 1
|
|
|
|
|
param_range = (1 << (self._weight_bits - 1)) - 1
|
|
|
|
|
act_range = (1 << (self._activation_bits - 1)) - 1
|
|
|
|
|
for var_node in op_node.inputs:
|
|
|
|
|
name = var_node.name()
|
|
|
|
|
if name not in op_node.input_arg_names():
|
|
|
|
@ -739,13 +787,12 @@ class QuantizationFreezePass(object):
|
|
|
|
|
original_var_name = self._original_var_name(name)
|
|
|
|
|
scale_v = self._var_scale_map[original_var_name]
|
|
|
|
|
if original_var_name in persistable_vars:
|
|
|
|
|
param_range = (1 << (self._weight_bits - 1)) - 1
|
|
|
|
|
act_range = (1 << (self._activation_bits - 1)) - 1
|
|
|
|
|
assert self._is_float(
|
|
|
|
|
scale_v), 'The scale of parameter %s is not a float.' % (
|
|
|
|
|
original_var_name)
|
|
|
|
|
max_range = param_range * act_range / scale_v
|
|
|
|
|
max_range *= param_range / scale_v
|
|
|
|
|
else:
|
|
|
|
|
max_range *= act_range
|
|
|
|
|
assert isinstance(scale_v, IrNode)
|
|
|
|
|
scale_var_node = self._var_scale_map[original_var_name]
|
|
|
|
|
|
|
|
|
@ -850,7 +897,7 @@ class ConvertToInt8Pass(object):
|
|
|
|
|
'The place cannot be set None.'
|
|
|
|
|
self._scope = scope
|
|
|
|
|
self._place = place
|
|
|
|
|
self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
|
|
|
|
|
self._quantizable_ops = _quantizable_op_list
|
|
|
|
|
|
|
|
|
|
def apply(self, graph):
|
|
|
|
|
"""
|
|
|
|
@ -866,6 +913,10 @@ class ConvertToInt8Pass(object):
|
|
|
|
|
for op_node in ops:
|
|
|
|
|
op_name = op_node.name()
|
|
|
|
|
if op_name in self._quantizable_ops:
|
|
|
|
|
skipped = op_node.op().has_attr("skip_quant") and \
|
|
|
|
|
op_node.op().attr("skip_quant")
|
|
|
|
|
if skipped:
|
|
|
|
|
continue
|
|
|
|
|
for var_node in op_node.inputs:
|
|
|
|
|
name = var_node.name()
|
|
|
|
|
if name in persistable_vars:
|
|
|
|
@ -924,14 +975,8 @@ class TransformForMobilePass(object):
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self):
|
|
|
|
|
self._fake_quant_op_names = [
|
|
|
|
|
'fake_quantize_abs_max', 'fake_quantize_range_abs_max',
|
|
|
|
|
'fake_quantize_moving_average_abs_max',
|
|
|
|
|
'fake_channel_wise_quantize_abs_max'
|
|
|
|
|
]
|
|
|
|
|
self._fake_dequant_op_names = [
|
|
|
|
|
'fake_dequantize_max_abs', 'fake_channel_wise_dequantize_max_abs'
|
|
|
|
|
]
|
|
|
|
|
self._fake_quant_op_names = _fake_quant_op_list
|
|
|
|
|
self._fake_dequant_op_names = _fake_dequant_op_list
|
|
|
|
|
|
|
|
|
|
def apply(self, graph):
|
|
|
|
|
"""
|
|
|
|
@ -980,12 +1025,7 @@ class ScaleForTrainingPass(object):
|
|
|
|
|
self._place = place
|
|
|
|
|
self._moving_rate = moving_rate
|
|
|
|
|
self._is_test = None
|
|
|
|
|
self._teller_set = [
|
|
|
|
|
"mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid",
|
|
|
|
|
"depthwise_conv2d", "batch_norm", "concat", "tanh", "pad",
|
|
|
|
|
"elementwise_add", "elementwise_mul", "dropout", "split", "prelu",
|
|
|
|
|
"conv2d_transpose", "leaky_relu"
|
|
|
|
|
]
|
|
|
|
|
self._teller_set = _out_scale_op_list
|
|
|
|
|
|
|
|
|
|
def apply(self, graph):
|
|
|
|
|
"""
|
|
|
|
@ -1087,12 +1127,7 @@ class ScaleForInferencePass(object):
|
|
|
|
|
scope(fluid.Scope): The scope is used to initialize these new parameters.
|
|
|
|
|
"""
|
|
|
|
|
self._scope = scope
|
|
|
|
|
self._teller_set = [
|
|
|
|
|
"mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid",
|
|
|
|
|
"depthwise_conv2d", "batch_norm", "concat", "tanh", "pad",
|
|
|
|
|
"elementwise_add", "elementwise_mul", "dropout", "split", "prelu",
|
|
|
|
|
"conv2d_transpose", "leaky_relu"
|
|
|
|
|
]
|
|
|
|
|
self._teller_set = _out_scale_op_list
|
|
|
|
|
|
|
|
|
|
def apply(self, graph):
|
|
|
|
|
"""
|
|
|
|
@ -1135,7 +1170,7 @@ class AddQuantDequantPass(object):
|
|
|
|
|
self._moving_rate = moving_rate
|
|
|
|
|
self._quant_bits = quant_bits
|
|
|
|
|
self._is_test = None
|
|
|
|
|
self._target_ops = ["elementwise_add", "pool2d"]
|
|
|
|
|
self._target_ops = ["elementwise_add"]
|
|
|
|
|
|
|
|
|
|
def apply(self, graph):
|
|
|
|
|
"""
|
|
|
|
|