|
|
|
@ -38,7 +38,8 @@ class QuantizationTransformPass(object):
|
|
|
|
|
activation_bits=8,
|
|
|
|
|
activation_quantize_type='abs_max',
|
|
|
|
|
weight_quantize_type='abs_max',
|
|
|
|
|
window_size=10000):
|
|
|
|
|
window_size=10000,
|
|
|
|
|
moving_rate=0.9):
|
|
|
|
|
"""
|
|
|
|
|
Convert and rewrite the IrGraph according to weight and
|
|
|
|
|
activation quantization type.
|
|
|
|
@ -83,19 +84,22 @@ class QuantizationTransformPass(object):
|
|
|
|
|
self._weight_bits = weight_bits
|
|
|
|
|
self._activation_bits = activation_bits
|
|
|
|
|
|
|
|
|
|
quant_type = ['abs_max', 'range_abs_max']
|
|
|
|
|
quant_type = ['abs_max', 'range_abs_max', 'moving_average_abs_max']
|
|
|
|
|
if activation_quantize_type not in quant_type:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"Unknown activation_quantize_type : '%s'. It can only be ",
|
|
|
|
|
"'abs_max' or 'range_abs_max'.", str(activation_quantize_type))
|
|
|
|
|
"'abs_max' or 'range_abs_max' or 'moving_average_abs_max'.",
|
|
|
|
|
str(activation_quantize_type))
|
|
|
|
|
if weight_quantize_type not in quant_type:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"Unknown weight_quantize_type: '%s'. It can only be ",
|
|
|
|
|
"'abs_max' or 'range_abs_max'.", str(weight_quantize_type))
|
|
|
|
|
"'abs_max' or 'range_abs_max' or 'moving_average_abs_max'.",
|
|
|
|
|
str(weight_quantize_type))
|
|
|
|
|
|
|
|
|
|
self._activation_quantize_type = activation_quantize_type
|
|
|
|
|
self._weight_quantize_type = weight_quantize_type
|
|
|
|
|
self._window_size = window_size
|
|
|
|
|
self._moving_rate = moving_rate
|
|
|
|
|
|
|
|
|
|
self._need_initialized = collections.OrderedDict()
|
|
|
|
|
self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
|
|
|
|
@ -222,6 +226,9 @@ class QuantizationTransformPass(object):
|
|
|
|
|
elif quant_type == 'range_abs_max':
|
|
|
|
|
return self._insert_quant_range_abs_max_op(graph, var_node,
|
|
|
|
|
quant_bits)
|
|
|
|
|
elif quant_type == 'moving_average_abs_max':
|
|
|
|
|
return self._insert_quant_moving_average_abs_max_op(graph, var_node,
|
|
|
|
|
quant_bits)
|
|
|
|
|
|
|
|
|
|
def _insert_quant_abs_max_op(self, graph, var_node, quant_bits):
|
|
|
|
|
"""
|
|
|
|
@ -309,6 +316,74 @@ class QuantizationTransformPass(object):
|
|
|
|
|
|
|
|
|
|
return quant_var_node, scale_out_node
|
|
|
|
|
|
|
|
|
|
def _insert_quant_moving_average_abs_max_op(self, graph, var_node,
|
|
|
|
|
quant_bits):
|
|
|
|
|
"""Insert fake_quantize_moving_average_abs_max
|
|
|
|
|
"""
|
|
|
|
|
quant_var_node = graph.create_var_node(
|
|
|
|
|
name=self._quantized_var_name(var_node.name()),
|
|
|
|
|
var_type=var_node.type(),
|
|
|
|
|
shape=var_node.shape(),
|
|
|
|
|
var_dtype=var_node.dtype())
|
|
|
|
|
scale_in_node = graph.create_persistable_node(
|
|
|
|
|
name=self._quantized_scale_name(var_node.name()),
|
|
|
|
|
var_type=core.VarDesc.VarType.LOD_TENSOR,
|
|
|
|
|
shape=[1],
|
|
|
|
|
var_dtype=var_node.dtype())
|
|
|
|
|
self._need_initialized[scale_in_node.var()] = Constant(value=0.001)
|
|
|
|
|
|
|
|
|
|
scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
|
|
|
|
|
ins = {'X': var_node, 'InScale': scale_in_node}
|
|
|
|
|
outs = {'Out': quant_var_node, 'OutScale': scale_out_node}
|
|
|
|
|
if not self._is_test:
|
|
|
|
|
state_in_node = graph.create_persistable_node(
|
|
|
|
|
name=unique_name.generate('state'),
|
|
|
|
|
var_type=core.VarDesc.VarType.LOD_TENSOR,
|
|
|
|
|
var_dtype=var_node.dtype(),
|
|
|
|
|
shape=[1])
|
|
|
|
|
self._need_initialized[state_in_node.var()] = Constant(value=1)
|
|
|
|
|
accum_in_node = graph.create_persistable_node(
|
|
|
|
|
name=unique_name.generate('accum'),
|
|
|
|
|
var_type=core.VarDesc.VarType.LOD_TENSOR,
|
|
|
|
|
var_dtype=var_node.dtype(),
|
|
|
|
|
shape=[1])
|
|
|
|
|
self._need_initialized[accum_in_node.var()] = Constant(value=1)
|
|
|
|
|
state_out_node = graph.create_var_node_from_desc(state_in_node.var(
|
|
|
|
|
))
|
|
|
|
|
accum_out_node = graph.create_var_node_from_desc(accum_in_node.var(
|
|
|
|
|
))
|
|
|
|
|
|
|
|
|
|
ins['InState'] = state_in_node
|
|
|
|
|
ins['InAccum'] = accum_in_node
|
|
|
|
|
outs['OutState'] = state_out_node
|
|
|
|
|
outs['OutAccum'] = accum_out_node
|
|
|
|
|
|
|
|
|
|
attrs = {
|
|
|
|
|
'bit_length': quant_bits,
|
|
|
|
|
'moving_rate': self._moving_rate,
|
|
|
|
|
'is_test': self._is_test,
|
|
|
|
|
'op_role': core.op_proto_and_checker_maker.OpRole.Forward
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
quant_op_node = graph.create_op_node(
|
|
|
|
|
op_type='fake_quantize_moving_average_abs_max',
|
|
|
|
|
attrs=attrs,
|
|
|
|
|
inputs=ins,
|
|
|
|
|
outputs=outs)
|
|
|
|
|
|
|
|
|
|
graph.link_to(var_node, quant_op_node)
|
|
|
|
|
graph.link_to(scale_in_node, quant_op_node)
|
|
|
|
|
graph.link_to(quant_op_node, quant_var_node)
|
|
|
|
|
graph.link_to(quant_op_node, scale_out_node)
|
|
|
|
|
|
|
|
|
|
if not self._is_test:
|
|
|
|
|
graph.link_to(state_in_node, quant_op_node)
|
|
|
|
|
graph.link_to(accum_in_node, quant_op_node)
|
|
|
|
|
graph.link_to(quant_op_node, state_out_node)
|
|
|
|
|
graph.link_to(quant_op_node, accum_out_node)
|
|
|
|
|
|
|
|
|
|
return quant_var_node, scale_out_node
|
|
|
|
|
|
|
|
|
|
def _insert_dequant_op(self, graph, var_node, scale_var_node, quant_bits):
|
|
|
|
|
"""
|
|
|
|
|
Insert fake_dequantize_op in the graph.
|
|
|
|
@ -389,7 +464,8 @@ class QuantizationFreezePass(object):
|
|
|
|
|
self._weight_quantize_type = weight_quantize_type
|
|
|
|
|
self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
|
|
|
|
|
self._fake_quant_op_names = [
|
|
|
|
|
'fake_quantize_abs_max', 'fake_quantize_range_abs_max'
|
|
|
|
|
'fake_quantize_abs_max', 'fake_quantize_range_abs_max',
|
|
|
|
|
'fake_quantize_moving_average_abs_max'
|
|
|
|
|
]
|
|
|
|
|
self._fake_dequant_op_names = ['fake_dequantize_max_abs']
|
|
|
|
|
self._op_input_rename_map = collections.OrderedDict()
|
|
|
|
|