Commit 40aa14ec77 (parent 1747bbdbab), merged via GitHub (#23629):
Weight quantization: support the channel_wise_abs_max method to achieve higher accuracy.
Branch context: revert-22778-infer_var_type. Signature: GPG Key ID 4AEE18F83AFDEB23 (no known key found for this signature in database).

@ -43,7 +43,8 @@ class TestWeightQuantization(unittest.TestCase):
os.system(cmd)
def run_test(self, model_name, model_data_url, model_data_md5, weight_bits,
quantizable_op_type, threshold_rate):
quantizable_op_type, weight_quantize_type, generate_test_model,
threshold_rate):
model_dir = self.download_model(model_name, model_data_url,
model_data_md5)
@ -57,6 +58,8 @@ class TestWeightQuantization(unittest.TestCase):
save_model_dir=save_model_dir,
weight_bits=weight_bits,
quantizable_op_type=quantizable_op_type,
weight_quantize_type=weight_quantize_type,
generate_test_model=generate_test_model,
threshold_rate=threshold_rate)
print("finish weight quantization for " + model_name + "\n")
@ -72,19 +75,45 @@ class TestWeightQuantizationMobilenetv1(TestWeightQuantization):
model_data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz"
model_data_md5 = "13892b0716d26443a8cdea15b3c6438b"
def test_weight_quantization_mobilenetv1_8bit_abs_max(self):
    """Quantize MobileNetV1 weights to 8 bits using per-tensor ``abs_max``
    scaling, generating a test model, and delegate verification to
    ``run_test`` (which downloads the model data and runs quantization).
    """
    weight_bits = 8
    quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
    weight_quantize_type = "abs_max"
    # Produce a dequantized test model alongside the quantized one so
    # run_test can exercise it; presumably used for accuracy comparison
    # (run_test body not fully visible here — confirm).
    generate_test_model = True
    threshold_rate = 0.0
    self.run_test(self.model_name, self.model_data_url, self.model_data_md5,
                  weight_bits, quantizable_op_type, weight_quantize_type,
                  generate_test_model, threshold_rate)
def test_weight_quantization_mobilenetv1_8bit_channel_wise_abs_max(self):
    """Quantize MobileNetV1 weights to 8 bits using per-channel
    ``channel_wise_abs_max`` scaling (the method this commit adds for
    higher accuracy), generating a test model, and delegate verification
    to ``run_test``.
    """
    weight_bits = 8
    quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
    weight_quantize_type = "channel_wise_abs_max"
    # Same configuration as the abs_max 8-bit case except for the
    # quantization type, so the two runs are directly comparable.
    generate_test_model = True
    threshold_rate = 0.0
    self.run_test(self.model_name, self.model_data_url, self.model_data_md5,
                  weight_bits, quantizable_op_type, weight_quantize_type,
                  generate_test_model, threshold_rate)
def test_weight_quantization_mobilenetv1_16bit_abs_max(self):
    """Quantize MobileNetV1 weights to 16 bits using per-tensor
    ``abs_max`` scaling and delegate verification to ``run_test``.
    """
    weight_bits = 16
    quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
    weight_quantize_type = "abs_max"
    # 16-bit runs skip test-model generation and instead use a small
    # non-zero threshold_rate; the exact meaning of threshold_rate is
    # defined in run_test (not fully visible here — confirm).
    generate_test_model = False
    threshold_rate = 1e-9
    self.run_test(self.model_name, self.model_data_url, self.model_data_md5,
                  weight_bits, quantizable_op_type, weight_quantize_type,
                  generate_test_model, threshold_rate)
def test_weight_quantization_mobilenetv1_16bit_channel_wise_abs_max(self):
    """Quantize MobileNetV1 weights to 16 bits using per-channel
    ``channel_wise_abs_max`` scaling and delegate verification to
    ``run_test``.
    """
    weight_bits = 16
    quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
    weight_quantize_type = "channel_wise_abs_max"
    # Mirrors the 16-bit abs_max case except for the quantization type,
    # keeping the two 16-bit runs directly comparable.
    generate_test_model = False
    threshold_rate = 1e-9
    self.run_test(self.model_name, self.model_data_url, self.model_data_md5,
                  weight_bits, quantizable_op_type, weight_quantize_type,
                  generate_test_model, threshold_rate)
if __name__ == '__main__':

Loading…
Cancel
Save