Collect weight threshold for lstm op in post_training_quantization (#28701)

* Collect weight threshold of lstm, test=develop
revert-31068-fix_conv3d_windows
cc 4 years ago committed by GitHub
parent 11e78ebaa3
commit 5d8d463cf7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -29,6 +29,7 @@ from .quantization_pass import _out_scale_op_list
from .quantization_pass import _get_op_input_var_names
from .quantization_pass import _get_op_output_var_names
from .quantization_pass import _get_output_name_index
from .quantization_pass import _get_input_name_index
from .quantization_pass import _channelwise_quant_axis1_ops
__all__ = ['PostTrainingQuantization', 'WeightQuantization']
@ -253,9 +254,11 @@ class PostTrainingQuantization(object):
]
self._support_weight_quantize_type = ['abs_max', 'channel_wise_abs_max']
self._support_algo_type = ['KL', 'abs_max', 'min_max']
self._dynamic_quantize_op_type = ['lstm']
self._support_quantize_op_type = \
list(set(QuantizationTransformPass._supported_quantizable_op_type +
AddQuantDequantPass._supported_quantizable_op_type))
AddQuantDequantPass._supported_quantizable_op_type +
self._dynamic_quantize_op_type))
# Check inputs
assert executor is not None, "The executor cannot be None."
@ -381,6 +384,10 @@ class PostTrainingQuantization(object):
self._save_input_threhold()
self._save_output_threshold()
if any(op_type in self._quantizable_op_type
for op_type in self._dynamic_quantize_op_type):
self._collect_dynamic_quantize_op_threshold(
self._dynamic_quantize_op_type)
return self._program
def save_quantized_model(self,
@ -776,6 +783,34 @@ class PostTrainingQuantization(object):
for var_name in out_var_names:
analysis_and_save_info(op, var_name)
def _collect_dynamic_quantize_op_threshold(self, target_ops_type):
    """
    Collect and save the weight threshold for dynamic quantize ops,
    such as lstm and gru.
    Args:
        target_ops_type(list): the op type of target ops
    Returns:
        None
    """
    # Gather every op in every block of the program whose type is targeted.
    matched_ops = [
        op
        for block_id in range(self._program.num_blocks)
        for op in self._program.block(block_id).ops
        if op.type in target_ops_type
    ]

    quant_type_tag = str("post_" + self._algo).lower()
    persistable_names = _all_persistable_var_names(self._program)

    for target_op in matched_ops:
        for input_name in _get_op_input_var_names(target_op):
            # Only persistable inputs are weights; skip activations.
            if input_name not in persistable_names:
                continue
            weight_data = _load_variable_data(self._scope, input_name)
            # abs-max threshold over the whole weight tensor.
            abs_max_val = float(np.max(np.abs(weight_data)))
            arg_name, arg_index = _get_input_name_index(target_op, input_name)
            target_op._set_attr(arg_name + str(arg_index) + "_threshold",
                                abs_max_val)
            target_op._set_attr("quantization_type", quant_type_tag)
            target_op._set_attr("bit_length", self._weight_bits)
def _get_kl_scaling_factor(self, hist, hist_edeges, num_quantized_bins=255):
'''
Using the KL-divergenc method to get the more precise scaling factor.

@ -120,6 +120,7 @@ _op_real_in_out_name = {
"hard_swish": [["X"], ["Out"]],
"hard_sigmoid": [["X"], ["Out"]],
"gru": [["Input", "Weight"], ["Hidden"]],
"lstm": [["Input", "Weight"], ["Hidden"]],
}
_conv_ops = ['conv2d', 'depthwise_conv2d', 'conv2d_transpose']
@ -144,6 +145,21 @@ def _get_op_input_var_names(op):
return var_names
def _get_input_name_index(op, input_var_name):
    """Get the input name and index of the var_name in the op"""
    assert isinstance(op, (IrNode, Operator)), \
        "The input op should be IrNode or Operator."
    # IrNode exposes the type via name(); Operator via the .type attribute.
    if isinstance(op, IrNode):
        op_name = op.name()
    else:
        op_name = op.type

    # Scan every declared input slot; if the var appears more than once,
    # the last occurrence wins (matches the original scan order).
    found = None
    for arg_name in _op_real_in_out_name[op_name][0]:
        for position, candidate in enumerate(op.input(arg_name)):
            if candidate == input_var_name:
                found = (arg_name, position)
    return found
def _get_op_output_var_names(op):
""" """
assert isinstance(op, (IrNode, Operator)), \

@ -124,6 +124,7 @@ if(WIN32)
list(REMOVE_ITEM TEST_OPS test_post_training_quantization_mnist)
list(REMOVE_ITEM TEST_OPS test_post_training_quantization_mobilenetv1)
list(REMOVE_ITEM TEST_OPS test_post_training_quantization_resnet50)
list(REMOVE_ITEM TEST_OPS test_post_training_quantization_lstm_model)
list(REMOVE_ITEM TEST_OPS test_weight_quantization_mobilenetv1)
list(REMOVE_ITEM TEST_OPS test_quantize_transpiler_v2)
endif()
@ -300,6 +301,7 @@ endforeach()
# setting timeout value for old unittests
if(NOT WIN32)
set_tests_properties(test_post_training_quantization_lstm_model PROPERTIES TIMEOUT 120)
set_tests_properties(test_post_training_quantization_mobilenetv1 PROPERTIES TIMEOUT 400 LABELS "RUN_TYPE=NIGHTLY")
set_tests_properties(test_post_training_quantization_resnet50 PROPERTIES TIMEOUT 400 LABELS "RUN_TYPE=NIGHTLY")
set_tests_properties(test_post_training_quantization_mnist PROPERTIES TIMEOUT 120)

Loading…
Cancel
Save