|
|
|
@ -49,11 +49,14 @@ class Quant2Int8MkldnnPass(object):
|
|
|
|
|
self._fake_quantize_types = [
|
|
|
|
|
'fake_quantize_moving_average_abs_max',
|
|
|
|
|
'fake_quantize_range_abs_max',
|
|
|
|
|
'fake_quantize_dequantize_moving_average_abs_max'
|
|
|
|
|
]
|
|
|
|
|
self._fake_dequantize_types = [
|
|
|
|
|
'fake_dequantize_max_abs', 'fake_channel_wise_dequantize_max_abs'
|
|
|
|
|
]
|
|
|
|
|
self._fake_quantize_dequantize_types = [
|
|
|
|
|
'fake_quantize_dequantize_abs_max',
|
|
|
|
|
'fake_quantize_dequantize_moving_average_abs_max'
|
|
|
|
|
]
|
|
|
|
|
self._ops_to_quantize = _ops_to_quantize
|
|
|
|
|
self._op_ids_to_skip = _op_ids_to_skip if _op_ids_to_skip is not None else set(
|
|
|
|
|
[-1])
|
|
|
|
@ -137,8 +140,12 @@ class Quant2Int8MkldnnPass(object):
|
|
|
|
|
for var_name in var_names:
|
|
|
|
|
scales[var_name] = (use_unsigned_int, lod_tensor)
|
|
|
|
|
|
|
|
|
|
# fake_quantize_dequantize_abs_max doesn't have a scale value
|
|
|
|
|
fake_ops = ['fake_quantize_dequantize_moving_average_abs_max']
|
|
|
|
|
fake_ops.extend(self._fake_quantize_types)
|
|
|
|
|
|
|
|
|
|
for op in graph.all_op_nodes():
|
|
|
|
|
if op.name() in self._fake_quantize_types:
|
|
|
|
|
if op.name() in fake_ops:
|
|
|
|
|
bit_length = op.op().attr("bit_length")
|
|
|
|
|
assert bit_length == 8, 'Unsupported number quantization bits ({}). Only 8 is supported now.'.format(
|
|
|
|
|
bit_length)
|
|
|
|
@ -164,14 +171,14 @@ class Quant2Int8MkldnnPass(object):
|
|
|
|
|
if op.op().has_attr("max_range"):
|
|
|
|
|
_max_range = np.array(op.op().attr("max_range")).astype(
|
|
|
|
|
np.float64)
|
|
|
|
|
self._weight_scales[input_name] = _max_range
|
|
|
|
|
self._weight_scales[input_name] = np.array(
|
|
|
|
|
self._s8_max * self._s8_max /
|
|
|
|
|
_max_range).astype(np.float64)
|
|
|
|
|
else:
|
|
|
|
|
scale_name = op.input("Scales")[0]
|
|
|
|
|
scales = np.array(
|
|
|
|
|
self._s8_max * self._s8_max / self._load_param(
|
|
|
|
|
self._scope, scale_name)).astype(np.float64)
|
|
|
|
|
scales[scales == np.Inf] = 0.0
|
|
|
|
|
self._weight_scales[input_name] = scales
|
|
|
|
|
self._weight_scales[input_name] = np.array(
|
|
|
|
|
self._load_param(self._scope, scale_name)).astype(
|
|
|
|
|
np.float64)
|
|
|
|
|
|
|
|
|
|
return graph
|
|
|
|
|
|
|
|
|
@ -243,9 +250,9 @@ class Quant2Int8MkldnnPass(object):
|
|
|
|
|
for op in graph.all_op_nodes():
|
|
|
|
|
if op.name() in self._fake_quantize_types:
|
|
|
|
|
self._remove_fake_quantize(graph, op)
|
|
|
|
|
|
|
|
|
|
for op in graph.all_op_nodes():
|
|
|
|
|
if op.name() in self._fake_dequantize_types:
|
|
|
|
|
elif op.name() in self._fake_dequantize_types:
|
|
|
|
|
self._remove_fake_dequantize(graph, op)
|
|
|
|
|
elif op.name() in self._fake_quantize_dequantize_types:
|
|
|
|
|
self._remove_fake_dequantize(graph, op)
|
|
|
|
|
|
|
|
|
|
return graph
|
|
|
|
@ -290,10 +297,15 @@ class Quant2Int8MkldnnPass(object):
|
|
|
|
|
])
|
|
|
|
|
|
|
|
|
|
def _dequantize_weights(self, graph):
    """Dequantize the weights of conv and mul ops that still hold integer values.

    Every op node of ``graph`` is visited. For an op whose name is in
    ``self._conv_ops`` the ``"Filter"`` input is inspected; for an op whose
    name is in ``self._mul_ops`` the ``"Y"`` input is inspected. The weights
    are handed to ``self._dequantize_op_weights`` only when they look
    quantized (see ``_is_int8_weights``), so weights that are already
    fractional floating-point values are left untouched.

    Args:
        graph: The graph object to process; it must provide
            ``all_op_nodes()``.

    Returns:
        The same ``graph`` object that was passed in.
    """

    def _is_int8_weights(op_node, weight_name):
        """Return whether every element of the op's weight tensor is a whole number."""
        weight_var_name = op_node.input(weight_name)[0]
        weight = self._load_param(self._scope, weight_var_name)
        # Heuristic: a remainder of zero modulo 1 for every element means the
        # tensor holds only integral values. A float tensor whose entries all
        # happen to be whole numbers passes this check as well.
        return np.all(np.mod(weight, 1) == 0)

    for op in graph.all_op_nodes():
        op_type = op.name()
        # The membership test comes first so that the weight tensor is only
        # loaded for ops that actually expose the inspected input.
        if op_type in self._conv_ops and _is_int8_weights(op, "Filter"):
            self._dequantize_op_weights(graph, op, "Filter", "Output")
        elif op_type in self._mul_ops and _is_int8_weights(op, "Y"):
            self._dequantize_op_weights(graph, op, "Y", "Out")
    return graph
|
|
|
|
|
|
|
|
|
@ -304,9 +316,9 @@ class Quant2Int8MkldnnPass(object):
|
|
|
|
|
scales = self._weight_scales[output_var_name]
|
|
|
|
|
weight = self._load_param(self._scope, weight_var_name)
|
|
|
|
|
if scales.size == 1 or scales.size == weight.shape[0]:
|
|
|
|
|
w_fp32 = np.divide(np.multiply(weight, self._s8_max).T, scales.T).T
|
|
|
|
|
w_fp32 = np.multiply(np.divide(weight, self._s8_max).T, scales.T).T
|
|
|
|
|
elif len(weight.shape) > 1 and scales.size == weight.shape[1]:
|
|
|
|
|
w_fp32 = np.divide(np.multiply(weight, self._s8_max), scales)
|
|
|
|
|
w_fp32 = np.multiply(np.divide(weight, self._s8_max), scales)
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"The size of weight scales vector ({}) does not match the dimensions ({}) of the weights tensor {}."
|
|
|
|
|