|
|
|
@ -28,6 +28,7 @@ from .quantization_pass import AddQuantDequantPass
|
|
|
|
|
from .quantization_pass import _out_scale_op_list
|
|
|
|
|
from .quantization_pass import _get_op_input_var_names
|
|
|
|
|
from .quantization_pass import _get_op_output_var_names
|
|
|
|
|
from .quantization_pass import _get_output_name_index
|
|
|
|
|
|
|
|
|
|
__all__ = ['PostTrainingQuantization', 'WeightQuantization']
|
|
|
|
|
|
|
|
|
@ -405,6 +406,10 @@ class PostTrainingQuantization(object):
|
|
|
|
|
model_filename=self._model_filename,
|
|
|
|
|
params_filename=self._params_filename)
|
|
|
|
|
|
|
|
|
|
if self._program.num_blocks > 1:
|
|
|
|
|
_logger.error("The post training quantization requires that the "
|
|
|
|
|
"program only has one block.")
|
|
|
|
|
|
|
|
|
|
if self._optimize_model:
|
|
|
|
|
self._optimize_fp32_model()
|
|
|
|
|
|
|
|
|
@ -450,6 +455,9 @@ class PostTrainingQuantization(object):
|
|
|
|
|
persistable_var_names = _all_persistable_var_names(self._program)
|
|
|
|
|
for op in self._program.global_block().ops:
|
|
|
|
|
op_type = op.type
|
|
|
|
|
if self._is_full_quantize and \
|
|
|
|
|
op_type not in self._quantizable_op_type:
|
|
|
|
|
_logger.warning(op_type + " is not supported for quantization.")
|
|
|
|
|
# For quantized ops, sample inputs and outputs
|
|
|
|
|
if op_type in self._quantizable_op_type:
|
|
|
|
|
collect_var_name(
|
|
|
|
@ -685,13 +693,25 @@ class PostTrainingQuantization(object):
|
|
|
|
|
op._set_attr("quantization_type", quantized_type)
|
|
|
|
|
|
|
|
|
|
def analysis_and_save_info(op_node, out_var_name):
|
|
|
|
|
argname_index = _get_output_name_index(op_node, out_var_name)
|
|
|
|
|
assert argname_index is not None, \
|
|
|
|
|
out_var_name + " is not the output of the op"
|
|
|
|
|
if self._algo == "KL":
|
|
|
|
|
# For compatibility, the output threshold is saved twice: once as "out_threshold" and once under a name built from the output's argument name and index.
|
|
|
|
|
save_info(op_node, out_var_name,
|
|
|
|
|
self._quantized_var_kl_threshold, "out_threshold",
|
|
|
|
|
"post_kl")
|
|
|
|
|
save_info(
|
|
|
|
|
op_node, out_var_name, self._quantized_var_kl_threshold,
|
|
|
|
|
argname_index[0] + str(argname_index[1]) + "_threshold",
|
|
|
|
|
"post_kl")
|
|
|
|
|
elif self._algo == "abs_max":
|
|
|
|
|
save_info(op_node, out_var_name, self._quantized_var_abs_max,
|
|
|
|
|
"out_threshold", "post_abs_max")
|
|
|
|
|
save_info(
|
|
|
|
|
op_node, out_var_name, self._quantized_var_abs_max,
|
|
|
|
|
argname_index[0] + str(argname_index[1]) + "_threshold",
|
|
|
|
|
"post_kl")
|
|
|
|
|
elif self._algo == "min_max":
|
|
|
|
|
save_info(op_node, out_var_name, self._quantized_var_min,
|
|
|
|
|
"out_min", "post_min_max")
|
|
|
|
|