Post_training_quantizaion supports min_max methon (#23078)

* Post_training_quantizaion supports min_max methon
revert-23830-2.0-beta
cc 6 years ago committed by GitHub
parent 194a22c5a8
commit 589cd8782f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -264,7 +264,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
ptq.save_quantized_model(self.int8_model)
def run_test(self, model, algo, data_urls, data_md5s, quantizable_op_type,
is_full_quantize, is_use_cache_file):
is_full_quantize, is_use_cache_file, diff_threshold):
infer_iterations = self.infer_iterations
batch_size = self.batch_size
sample_iterations = self.sample_iterations
@ -296,11 +296,11 @@ class TestPostTrainingQuantization(unittest.TestCase):
sys.stdout.flush()
delta_value = fp32_acc1 - int8_acc1
self.assertLess(delta_value, 0.025)
self.assertLess(delta_value, diff_threshold)
class TestPostTrainingForMobilenetv1(TestPostTrainingQuantization):
def test_post_training_mobilenetv1(self):
class TestPostTrainingKLForMobilenetv1(TestPostTrainingQuantization):
def test_post_training_kl_mobilenetv1(self):
model = "MobileNet-V1"
algo = "KL"
data_urls = [
@ -310,10 +310,29 @@ class TestPostTrainingForMobilenetv1(TestPostTrainingQuantization):
quantizable_op_type = [
"conv2d", "depthwise_conv2d", "mul", "pool2d", "elementwise_add"
]
is_full_quantize = True
is_full_quantize = False
is_use_cache_file = False
diff_threshold = 0.025
self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type,
is_full_quantize, is_use_cache_file)
is_full_quantize, is_use_cache_file, diff_threshold)
class TestPostTrainingAbsMaxForMobilenetv1(TestPostTrainingQuantization):
def test_post_training_abs_max_mobilenetv1(self):
model = "MobileNet-V1"
algo = "abs_max"
data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
]
data_md5s = ['13892b0716d26443a8cdea15b3c6438b']
quantizable_op_type = [
"conv2d", "depthwise_conv2d", "mul", "pool2d", "elementwise_add"
]
is_full_quantize = False
is_use_cache_file = False
diff_threshold = 0.05
self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type,
is_full_quantize, is_use_cache_file, diff_threshold)
if __name__ == '__main__':

@ -20,7 +20,7 @@ from test_post_training_quantization_mobilenetv1 import TestPostTrainingQuantiza
class TestPostTrainingForResnet50(TestPostTrainingQuantization):
def test_post_training_resnet50(self):
model = "ResNet-50"
algo = "direct"
algo = "min_max"
data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model.tar.gz'
]
@ -28,8 +28,9 @@ class TestPostTrainingForResnet50(TestPostTrainingQuantization):
quantizable_op_type = ["conv2d", "mul"]
is_full_quantize = False
is_use_cache_file = False
diff_threshold = 0.025
self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type,
is_full_quantize, is_use_cache_file)
is_full_quantize, is_use_cache_file, diff_threshold)
if __name__ == '__main__':

Loading…
Cancel
Save