@@ -45,16 +45,14 @@ class Qat2Int8MkldnnPass(object):
         self._place = _place
         self._core = _core
         self._debug = _debug
-        self._quantize_types = [
+        self._fake_quantize_types = [
             'fake_quantize_moving_average_abs_max',
             'fake_quantize_range_abs_max',
             'fake_quantize_dequantize_moving_average_abs_max'
         ]
-        self._fake_quantize_types = [
-            'fake_quantize_moving_average_abs_max',
-            'fake_quantize_dequantize_moving_average_abs_max'
-        ]
-        self._fake_dequantize_types = ['fake_dequantize_max_abs']
+        self._fake_dequantize_types = [
+            'fake_dequantize_max_abs', 'fake_channel_wise_dequantize_max_abs'
+        ]
         self._quantized_ops = _quantized_ops
         self._scale_immutable_ops = [
             'transpose2', 'reshape2', 'pool2d', 'scale'
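
The hunk above merges the overlapping `_quantize_types` and `_fake_quantize_types` lists into a single `_fake_quantize_types` list and registers `fake_channel_wise_dequantize_max_abs`, whose scales arrive as one value per output channel rather than a single `max_range` attribute. As a reminder of what the fake quantize ops compute, here is a minimal numpy sketch of the quantize-dequantize round trip; it illustrates the abs_max scheme only, not the kernels' exact code:

    import numpy as np

    def fake_quantize_dequantize(x, abs_max, s8_max=127.0):
        # Quantize into the int8 range and immediately dequantize, so the
        # fp32 graph sees values that carry the quantization error.
        q = np.round(np.clip(x / abs_max, -1.0, 1.0) * s8_max)
        return q * abs_max / s8_max

    x = np.array([0.30, -0.72, 0.05])
    print(fake_quantize_dequantize(x, abs_max=0.8))
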
@@ -74,7 +72,9 @@ class Qat2Int8MkldnnPass(object):
         assert isinstance(graph,
                           IrGraph), 'graph must be the instance of IrGraph.'

-        graph = self._gather_scales(graph)
+        graph = self._gather_weight_scales_from_fake(graph)
+        graph = self._gather_output_scales_from_attr(graph)
+        graph = self._gather_input_scales_from_fake(graph)
         graph = self._remove_fake_ops(graph)
         graph = self._dequantize_weights(graph)
         graph = self._optimize_fp32_graph(graph)
@@ -83,6 +83,7 @@ class Qat2Int8MkldnnPass(object):
         graph = self._propagate_scales(graph)
-        graph = self._set_dummy_out_scales(graph)
         graph = self._quantize_fp32_graph(graph)
+        graph = self._optimize_int8_graph(graph)
+        graph = self._cleanup(graph)
         return graph
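
Taken together, the two hunks above turn `apply` into an explicit pipeline: gather weight scales from the fake dequantize ops, output scales from `out_threshold` attributes, and input scales from the fake quantize ops; then strip the fake ops, restore fp32 weights, optimize the fp32 graph, quantize, and clean up. The `_set_dummy_out_scales` workaround becomes unnecessary once output scales are collected directly from the graph. A rough usage sketch follows, assuming the constructor signature implied by the attribute assignments in the first hunk; the import path and argument names may differ between Paddle versions:

    import paddle.fluid as fluid
    from paddle.fluid import core
    from paddle.fluid.framework import IrGraph
    from paddle.fluid.contrib.slim.quantization import Qat2Int8MkldnnPass

    # Stand-in program; in practice this is the QAT-trained inference
    # program that still contains the fake quantize/dequantize ops.
    program = fluid.Program()
    graph = IrGraph(core.Graph(program.desc), for_test=True)

    quant_pass = Qat2Int8MkldnnPass(
        {'conv2d', 'fc'},            # _quantized_ops
        _scope=fluid.global_scope(),
        _place=fluid.CPUPlace(),
        _core=core,
        _debug=False)
    graph = quant_pass.apply(graph)
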
@@ -90,9 +91,6 @@ class Qat2Int8MkldnnPass(object):
         assert isinstance(graph,
                           IrGraph), 'graph must be the instance of IrGraph.'

-        graph = self._gather_scales(graph)
-        graph = self._remove_fake_ops(graph)
-        graph = self._dequantize_weights(graph)
         graph = self._optimize_fp32_graph(graph)
         graph = self._cleanup(graph)
         return graph
@@ -108,29 +106,61 @@ class Qat2Int8MkldnnPass(object):
     def _is_fc_quantized(self):
         return 'fc' in self._quantized_ops

-    def _gather_scales(self, graph):
+    def _gather_input_scales_from_fake(self, graph):
+        def _add_scale_for_vars(var_names, use_unsigned_int, lod_tensor):
+            scales = self._var_quant_scales
+            for var_name in var_names:
+                scales[var_name] = (use_unsigned_int, lod_tensor)
+
         for op in graph.all_op_nodes():
-            if op.name() in self._quantize_types:
+            if op.name() in self._fake_quantize_types:
                 bit_length = op.op().attr("bit_length")
                 assert bit_length == 8, 'Unsupported number quantization bits ({}). Only 8 is supported now.'.format(
                     bit_length)

                 input_name = op.input("X")[0]
                 scale_name = op.input("InScale")[0]
+                output_name = op.output("Out")[0]
                 # Gather new weights scale after folding batchnorm in convolution
                 scale = np.array(1.0 / self._load_param(
                     self._scope, scale_name)[0]).astype(np.float64)
                 lod_tensor = self._convert_scale2tensor(scale)
                 use_unsigned_int = False
-                self._var_quant_scales[input_name] = (use_unsigned_int,
-                                                      lod_tensor)
-                self._var_quant_scales[scale_name.replace(".scale", "")] = (
-                    use_unsigned_int, lod_tensor)
+                _add_scale_for_vars([input_name, output_name], use_unsigned_int,
+                                    lod_tensor)

+        return graph
+
+    def _gather_weight_scales_from_fake(self, graph):
+        for op in graph.all_op_nodes():
             if op.name() in self._fake_dequantize_types:
                 input_name = op.input("X")[0]
-                _max_range = op.op().attr("max_range")
-                self._weight_scales[input_name] = _max_range
+                if op.op().has_attr("max_range"):
+                    _max_range = np.array(op.op().attr("max_range")).astype(
+                        np.float64)
+                    self._weight_scales[input_name] = _max_range
+                else:
+                    scale_name = op.input("Scales")[0]
+                    scale = np.array(
+                        self._s8_max * self._s8_max / self._load_param(
+                            self._scope, scale_name)).astype(np.float64)
+                    self._weight_scales[input_name] = scale

         return graph

+    def _gather_output_scales_from_attr(self, graph):
+        for op in graph.all_op_nodes():
+            if op.op().has_attr("out_threshold"):
+                attr_scale = op.op().attr("out_threshold")
+                if attr_scale == 0.0: continue
+                scale = np.array(1.0 / attr_scale).astype(np.float64)
+                scale_lod_tensor = self._convert_scale2tensor(scale)
+                use_unsigned_int = False
+                for output_name in op.op().outputs():
+                    for out_var_name in op.op().output(output_name):
+                        self._var_quant_scales[out_var_name] = (
+                            use_unsigned_int, scale_lod_tensor)
+
+        return graph
+
     def _propagate_scales(self, graph):
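
The monolithic `_gather_scales` is split into three focused collectors, and the scales dictionary is written through the `_add_scale_for_vars` helper so the input and output variables of a fake quantize op share one entry. Numerically, the pass stores reciprocals: `1 / InScale[0]` for activations, `1 / out_threshold` for attribute-annotated outputs, and `s8_max ** 2 / Scales` for channel-wise weights. A small numpy sketch of the three conversions; the tensor values are hypothetical:

    import numpy as np

    s8_max = 127.0

    # From a fake_quantize op: InScale holds the activation's abs_max.
    in_scale = np.array([2.5])                    # hypothetical InScale tensor
    input_scale = np.array(1.0 / in_scale[0]).astype(np.float64)

    # From an op annotated with an out_threshold attribute.
    out_threshold = 4.0                           # hypothetical attribute value
    output_scale = np.array(1.0 / out_threshold).astype(np.float64)

    # From fake_channel_wise_dequantize_max_abs: Scales holds one abs_max
    # per output channel.
    channel_abs_max = np.array([0.5, 1.0, 2.0])   # hypothetical Scales tensor
    weight_scales = (s8_max * s8_max / channel_abs_max).astype(np.float64)
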
@@ -274,29 +304,24 @@ class Qat2Int8MkldnnPass(object):
     def _dequantize_weights(self, graph):
         for op in graph.all_op_nodes():
             if op.name() in self._conv_ops:
-                self._dequantize_conv_weights(graph, op)
+                self._dequantize_op_weights(graph, op, "Filter", "Output")
             elif self._is_fc_quantized() and op.name() in self._mul_ops:
-                self._dequantize_mul_weights(graph, op)
+                self._dequantize_op_weights(graph, op, "Y", "Out")
         return graph

-    def _dequantize_conv_weights(self, graph, op_node):
-        weight_name = op_node.input("Filter")[0]
-        output_name = op_node.output("Output")[0]
+    def _dequantize_op_weights(self, graph, op_node, weight_name, output_name):
+        weight_var_name = op_node.input(weight_name)[0]
+        output_var_name = op_node.output(output_name)[0]
         # Convert int8 range weights to fp32 range weights
-        scales = self._weight_scales[output_name]
-        weight = self._load_param(self._scope, weight_name)
-        w_fp32 = np.divide(np.multiply(weight, self._s8_max), scales)
-        w_fp32 = w_fp32.reshape(weight.shape)
-        self._restore_var(weight_name, w_fp32)
-
-    def _dequantize_mul_weights(self, graph, op_node):
-        weight_name = op_node.input("Y")[0]
-        output_name = op_node.output("Out")[0]
-        scales = self._weight_scales[output_name]
-        weight = self._load_param(self._scope, weight_name)
-        w_fp32 = np.divide(np.multiply(weight, self._s8_max), scales)
-        w_fp32 = w_fp32.reshape(weight.shape)
-        self._restore_var(weight_name, w_fp32)
+        scales = self._weight_scales[output_var_name]
+        weight = self._load_param(self._scope, weight_var_name)
+        assert scales.size == 1 or scales.size == len(
+            weight
+        ), "The size of weight scales vector ({}) does not match the number of output channels ({}) in the weights tensor {}.".format(
+            scales.size, len(weight), weight_var_name)
+        w_fp32 = np.divide(np.multiply(weight, self._s8_max).T, scales.T).T
+        w_fp32 = w_fp32.reshape(weight.shape).astype(np.float32)
+        self._restore_var(weight_var_name, w_fp32)

     def _restore_var(self, name, array):
         tensor = self._scope.find_var(name).get_tensor()
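
Folding `_dequantize_conv_weights` and `_dequantize_mul_weights` into a single `_dequantize_op_weights` also upgrades the arithmetic from one per-tensor scale to an optional per-output-channel vector, guarded by the new size assert. The double transpose is what makes a 1-D `scales` vector broadcast along the output-channel (first) axis of an OIHW filter; a scalar passes through the same expression unchanged, which is why the assert accepts both sizes. A self-contained numpy check of the trick, with made-up values:

    import numpy as np

    s8_max = 127.0
    # OIHW conv filter with O = 4 output channels, int8-range values.
    weight = np.random.randint(-127, 128, size=(4, 3, 3, 3)).astype(np.float64)
    scales = np.array([100.0, 200.0, 50.0, 400.0])  # one per output channel

    # Transposing moves the output-channel axis last, where a 1-D vector
    # broadcasts; transposing back restores the original layout.
    w_fp32 = np.divide(np.multiply(weight, s8_max).T, scales.T).T

    # Equivalent explicit broadcast over the first axis:
    w_ref = weight * s8_max / scales.reshape(-1, 1, 1, 1)
    assert np.allclose(w_fp32, w_ref)
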
@@ -366,11 +391,14 @@ class Qat2Int8MkldnnPass(object):
         self._remove_unused_var_nodes(graph)
         return graph

-    def _cleanup(self, graph):
+    def _optimize_int8_graph(self, graph):
         # remove dropout ops
         graph = self._apply_pass(graph, 'simplify_with_basic_ops_pass')
         # make some MKL-DNN ops working inplace
         graph = self._apply_pass(graph, 'mkldnn_inplace_pass')
+        return graph
+
+    def _cleanup(self, graph):
         graph = self._remove_unused_var_nodes(graph)
+        graph = self._set_op_role_forward(graph)
         return graph
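
With this last hunk, `_optimize_int8_graph` owns the INT8-specific graph passes while `_cleanup` keeps the final bookkeeping, so `apply` can invoke them as separate stages. The `_apply_pass` helper itself is outside this diff; for orientation only, here is a minimal sketch of what such a helper typically does with the core pass registry. The body below is an assumption, not this PR's code:

    def _apply_pass(self, graph, pass_name):
        # Look the pass up in the C++ pass registry and run it on the
        # underlying core.Graph; IrGraph is only a Python wrapper.
        ir_pass = self._core.get_pass(pass_name)
        ir_pass.apply(graph.graph)
        if self._debug:
            # Dump the graph after each pass when debugging is enabled.
            graph.draw('.', 'qat2_int8_' + pass_name, graph.all_op_nodes())
        return graph
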