@@ -56,7 +56,8 @@ class Quant2Int8MkldnnPass(object):
         ]
         self._fake_quantize_dequantize_types = [
             'fake_quantize_dequantize_abs_max',
-            'fake_quantize_dequantize_moving_average_abs_max'
+            'fake_quantize_dequantize_moving_average_abs_max',
+            'fake_channel_wise_quantize_dequantize_abs_max'
         ]
         self._ops_to_quantize = _ops_to_quantize
         self._op_ids_to_skip = _op_ids_to_skip if _op_ids_to_skip is not None else set(
@@ -71,7 +72,7 @@ class Quant2Int8MkldnnPass(object):
         self._relu_ops = ['relu', 'relu6']
         self._matmul_ops = ['matmul']
         self._gru_ops = ['fusion_gru', 'multi_gru']
-        self._weight_scales = {}
+        self._weight_thresholds = {}
         # Collect the Input and Output scales from Fake quant models
         self._var_quant_scales = {}
         self._max_range = {}
@@ -84,7 +85,8 @@ class Quant2Int8MkldnnPass(object):
             IrGraph), 'graph must be the instance of IrGraph.'
 
         self._reset_pass_idx_and_group('int8')
-        graph = self._gather_weight_scales_from_fake(graph)
+        graph = self._label_skip_quantized_op(graph)
+        graph = self._gather_weight_thresholds_from_fake(graph)
         graph = self._gather_output_scales_from_attr(graph)
         graph = self._gather_input_scales_from_fake(graph)
         graph = self._remove_fake_ops(graph)
@@ -135,6 +137,30 @@ class Quant2Int8MkldnnPass(object):
     def _is_fc_quantized(self, graph):
         return self._is_any_of_op_types_quantized(self._fc_ops, graph)
 
+    def _label_skip_quantized_op(self, graph):
+        """
+        For some ops (conv2d, depthwise_conv2d, mul, matmul), find and label
+        the ops that should skip quantization; cpu_quantize_placement_pass
+        uses the label to identify them.
+        For static models, the ops to skip already carry the `skip_quant`
+        attr. Therefore, this method only needs to find and label such ops
+        for dygraph models, in which the quantized ops don't have the
+        `quantization_type` attr.
+        """
+        target_ops = self._conv_ops + self._mul_ops + self._matmul_ops
+        for op_node in graph.all_op_nodes():
+            if op_node.name() in target_ops and \
+               not op_node.op().has_attr("quantization_type"):
+                is_quantized_op = True
+                for var_node in op_node.inputs:
+                    for front_op_node in var_node.inputs:
+                        if "fake_quantize_dequantize_" not in front_op_node.name(
+                        ):
+                            is_quantized_op = False
+                if not is_quantized_op:
+                    op_node.op()._set_attr("skip_quant", True)
+        return graph
+
     def _gather_input_scales_from_fake(self, graph):
         def _add_scale_for_vars(var_names, use_unsigned_int, lod_tensor):
             scales = self._var_quant_scales
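The new `_label_skip_quantized_op` pass treats an op as quantized only when every one of its input variables is produced by a `fake_quantize_dequantize_*` op; a single unquantized producer is enough to mark it `skip_quant`. Below is a minimal sketch of that predicate using hypothetical stub nodes in place of Paddle's `IrGraph` nodes, just to illustrate the traversal:

```python
# Stub standing in for an IrGraph node; `inputs` links an op node to its
# input var nodes, and a var node to its producing op nodes.
class Node:
    def __init__(self, name, inputs=None):
        self._name = name
        self.inputs = inputs or []

    def name(self):
        return self._name


def should_skip_quant(op_node):
    # Mirror of the pass's check: the op counts as quantized only if
    # every input var is produced by a fake_quantize_dequantize_* op.
    for var_node in op_node.inputs:
        for front_op_node in var_node.inputs:
            if "fake_quantize_dequantize_" not in front_op_node.name():
                return True  # one unquantized producer -> label skip_quant
    return False


fake_q = Node("fake_quantize_dequantize_abs_max")
plain = Node("relu")
conv_quantized = Node("conv2d", inputs=[Node("x", inputs=[fake_q])])
conv_unquantized = Node("conv2d", inputs=[Node("y", inputs=[plain])])
assert not should_skip_quant(conv_quantized)
assert should_skip_quant(conv_unquantized)
```

The pass itself keeps scanning after it finds a non-quantized producer instead of returning early; the final label is the same either way.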
@@ -165,19 +191,19 @@ class Quant2Int8MkldnnPass(object):
 
         return graph
 
-    def _gather_weight_scales_from_fake(self, graph):
+    def _gather_weight_thresholds_from_fake(self, graph):
         for op in graph.all_op_nodes():
             if op.name() in self._fake_dequantize_types:
                 input_name = op.input("X")[0]
                 if op.op().has_attr("max_range"):
                     _max_range = np.array(op.op().attr("max_range")).astype(
                         np.float64)
-                    self._weight_scales[input_name] = np.array(
+                    self._weight_thresholds[input_name] = np.array(
                         self._s8_max * self._s8_max /
                         _max_range).astype(np.float64)
                 else:
                     scale_name = op.input("Scales")[0]
-                    self._weight_scales[input_name] = np.array(
+                    self._weight_thresholds[input_name] = np.array(
                         self._load_param(self._scope, scale_name)).astype(
                             np.float64)
 
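In the `max_range` branch, the threshold is recovered by inverting the attribute that the quantization freeze step stored on the fake dequantize op; the division above implies the convention `max_range == s8_max**2 / threshold`. A quick numpy round-trip under that assumed convention, with a made-up threshold value:

```python
import numpy as np

s8_max = 127.0   # largest magnitude of the signed int8 weight range
threshold = 0.8  # hypothetical abs-max of the original fp32 weights

# Assumed convention: the fake dequantize op stores
# max_range = s8_max^2 / threshold ...
max_range = s8_max * s8_max / threshold

# ... so the gather pass recovers the threshold as s8_max^2 / max_range.
recovered = np.array(s8_max * s8_max / max_range).astype(np.float64)
assert np.isclose(recovered, threshold)
```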
@@ -314,7 +340,7 @@ class Quant2Int8MkldnnPass(object):
         weight_var_name = op_node.input(weight_name)[0]
         output_var_name = op_node.output(output_name)[0]
         # Convert int8 range weights to fp32 range weights
-        scales = self._weight_scales[output_var_name]
+        scales = self._weight_thresholds[output_var_name]
         weight = self._load_param(self._scope, weight_var_name)
         if scales.size == 1 or scales.size == weight.shape[0]:
             w_fp32 = np.multiply(np.divide(weight, self._s8_max).T, scales.T).T
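The double transpose in the last line is a broadcasting trick: `weight.T` moves the output-channel axis to the end so that a per-channel `scales` vector (or a scalar, when `scales.size == 1`) multiplies each channel's weights, and the final `.T` restores the original layout. A small numpy illustration with made-up shapes:

```python
import numpy as np

s8_max = 127.0
# Hypothetical conv weights in int8 range, laid out as
# (out_channels, in_channels, kernel_h, kernel_w).
weight = np.random.randint(-127, 128, size=(2, 3, 3, 3)).astype(np.float64)
scales = np.array([0.5, 2.0])  # one threshold per output channel

# weight.T has shape (3, 3, 3, 2), so `scales` broadcasts along the
# trailing out_channels axis; the final .T restores (2, 3, 3, 3).
w_fp32 = np.multiply(np.divide(weight, s8_max).T, scales.T).T

# Equivalent explicit form for comparison.
expected = weight / s8_max * scales[:, None, None, None]
assert np.allclose(w_fp32, expected)
```

On a 1-D or 0-D array `.T` is a no-op, so the same expression handles both the per-channel and per-tensor cases without a branch.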