|
|
@ -29,6 +29,7 @@ from .quantization_pass import _out_scale_op_list
|
|
|
|
from .quantization_pass import _get_op_input_var_names
|
|
|
|
from .quantization_pass import _get_op_input_var_names
|
|
|
|
from .quantization_pass import _get_op_output_var_names
|
|
|
|
from .quantization_pass import _get_op_output_var_names
|
|
|
|
from .quantization_pass import _get_output_name_index
|
|
|
|
from .quantization_pass import _get_output_name_index
|
|
|
|
|
|
|
|
from .quantization_pass import _get_input_name_index
|
|
|
|
from .quantization_pass import _channelwise_quant_axis1_ops
|
|
|
|
from .quantization_pass import _channelwise_quant_axis1_ops
|
|
|
|
|
|
|
|
|
|
|
|
__all__ = ['PostTrainingQuantization', 'WeightQuantization']
|
|
|
|
__all__ = ['PostTrainingQuantization', 'WeightQuantization']
|
|
|
@ -253,9 +254,11 @@ class PostTrainingQuantization(object):
|
|
|
|
]
|
|
|
|
]
|
|
|
|
self._support_weight_quantize_type = ['abs_max', 'channel_wise_abs_max']
|
|
|
|
self._support_weight_quantize_type = ['abs_max', 'channel_wise_abs_max']
|
|
|
|
self._support_algo_type = ['KL', 'abs_max', 'min_max']
|
|
|
|
self._support_algo_type = ['KL', 'abs_max', 'min_max']
|
|
|
|
|
|
|
|
self._dynamic_quantize_op_type = ['lstm']
|
|
|
|
self._support_quantize_op_type = \
|
|
|
|
self._support_quantize_op_type = \
|
|
|
|
list(set(QuantizationTransformPass._supported_quantizable_op_type +
|
|
|
|
list(set(QuantizationTransformPass._supported_quantizable_op_type +
|
|
|
|
AddQuantDequantPass._supported_quantizable_op_type))
|
|
|
|
AddQuantDequantPass._supported_quantizable_op_type +
|
|
|
|
|
|
|
|
self._dynamic_quantize_op_type))
|
|
|
|
|
|
|
|
|
|
|
|
# Check inputs
|
|
|
|
# Check inputs
|
|
|
|
assert executor is not None, "The executor cannot be None."
|
|
|
|
assert executor is not None, "The executor cannot be None."
|
|
|
@ -381,6 +384,10 @@ class PostTrainingQuantization(object):
|
|
|
|
self._save_input_threhold()
|
|
|
|
self._save_input_threhold()
|
|
|
|
|
|
|
|
|
|
|
|
self._save_output_threshold()
|
|
|
|
self._save_output_threshold()
|
|
|
|
|
|
|
|
if any(op_type in self._quantizable_op_type
|
|
|
|
|
|
|
|
for op_type in self._dynamic_quantize_op_type):
|
|
|
|
|
|
|
|
self._collect_dynamic_quantize_op_threshold(
|
|
|
|
|
|
|
|
self._dynamic_quantize_op_type)
|
|
|
|
return self._program
|
|
|
|
return self._program
|
|
|
|
|
|
|
|
|
|
|
|
def save_quantized_model(self,
|
|
|
|
def save_quantized_model(self,
|
|
|
@ -776,6 +783,34 @@ class PostTrainingQuantization(object):
|
|
|
|
for var_name in out_var_names:
|
|
|
|
for var_name in out_var_names:
|
|
|
|
analysis_and_save_info(op, var_name)
|
|
|
|
analysis_and_save_info(op, var_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _collect_dynamic_quantize_op_threshold(self, target_ops_type):
    """
    Record abs-max weight thresholds on the dynamic quantize ops.

    Every op of the program whose type is listed in ``target_ops_type``
    is visited. For each input of such an op whose name is among the
    program's persistable variable names, the variable data is loaded
    from the scope and its largest absolute value is stored on the op
    as the ``<argname><index>_threshold`` attribute. On the same
    occasion the op also receives the ``quantization_type``
    (``"post_" + algo``, lower-cased) and ``bit_length`` (the weight
    bits) attributes; an op without any persistable input is left
    untouched.

    Args:
        target_ops_type(list): the op types to process
    Returns:
        None
    """
    program = self._program

    # Gather the matching ops from every block before mutating any of them.
    matched_ops = [
        op
        for block_id in range(program.num_blocks)
        for op in program.block(block_id).ops
        if op.type in target_ops_type
    ]

    quant_type = str("post_" + self._algo).lower()
    persistable_names = _all_persistable_var_names(program)

    for op in matched_ops:
        for input_name in _get_op_input_var_names(op):
            # Only persistable inputs get a threshold.
            if input_name not in persistable_names:
                continue
            weight = _load_variable_data(self._scope, input_name)
            abs_max = float(np.max(np.abs(weight)))
            arg_name, arg_index = _get_input_name_index(op, input_name)
            op._set_attr(arg_name + str(arg_index) + "_threshold", abs_max)
            op._set_attr("quantization_type", quant_type)
            op._set_attr("bit_length", self._weight_bits)
|
|
|
|
|
|
|
|
|
|
|
|
def _get_kl_scaling_factor(self, hist, hist_edeges, num_quantized_bins=255):
|
|
|
|
def _get_kl_scaling_factor(self, hist, hist_edeges, num_quantized_bins=255):
|
|
|
|
'''
|
|
|
|
'''
|
|
|
|
Using the KL-divergenc method to get the more precise scaling factor.
|
|
|
|
Using the KL-divergenc method to get the more precise scaling factor.
|
|
|
|