|
|
|
@ -43,7 +43,8 @@ class TestWeightQuantization(unittest.TestCase):
|
|
|
|
|
os.system(cmd)
|
|
|
|
|
|
|
|
|
|
def run_test(self, model_name, model_data_url, model_data_md5, weight_bits,
|
|
|
|
|
quantizable_op_type, threshold_rate):
|
|
|
|
|
quantizable_op_type, weight_quantize_type, generate_test_model,
|
|
|
|
|
threshold_rate):
|
|
|
|
|
|
|
|
|
|
model_dir = self.download_model(model_name, model_data_url,
|
|
|
|
|
model_data_md5)
|
|
|
|
@ -57,6 +58,8 @@ class TestWeightQuantization(unittest.TestCase):
|
|
|
|
|
save_model_dir=save_model_dir,
|
|
|
|
|
weight_bits=weight_bits,
|
|
|
|
|
quantizable_op_type=quantizable_op_type,
|
|
|
|
|
weight_quantize_type=weight_quantize_type,
|
|
|
|
|
generate_test_model=generate_test_model,
|
|
|
|
|
threshold_rate=threshold_rate)
|
|
|
|
|
print("finish weight quantization for " + model_name + "\n")
|
|
|
|
|
|
|
|
|
@ -72,19 +75,45 @@ class TestWeightQuantizationMobilenetv1(TestWeightQuantization):
|
|
|
|
|
model_data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz"
|
|
|
|
|
model_data_md5 = "13892b0716d26443a8cdea15b3c6438b"
|
|
|
|
|
|
|
|
|
|
def test_weight_quantization_mobilenetv1_8bit(self):
|
|
|
|
|
def test_weight_quantization_mobilenetv1_8bit_abs_max(self):
|
|
|
|
|
weight_bits = 8
|
|
|
|
|
quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
|
|
|
|
|
weight_quantize_type = "abs_max"
|
|
|
|
|
generate_test_model = True
|
|
|
|
|
threshold_rate = 0.0
|
|
|
|
|
self.run_test(self.model_name, self.model_data_url, self.model_data_md5,
|
|
|
|
|
weight_bits, quantizable_op_type, threshold_rate)
|
|
|
|
|
weight_bits, quantizable_op_type, weight_quantize_type,
|
|
|
|
|
generate_test_model, threshold_rate)
|
|
|
|
|
|
|
|
|
|
def test_weight_quantization_mobilenetv1_16bit(self):
|
|
|
|
|
def test_weight_quantization_mobilenetv1_8bit_channel_wise_abs_max(self):
|
|
|
|
|
weight_bits = 8
|
|
|
|
|
quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
|
|
|
|
|
weight_quantize_type = "channel_wise_abs_max"
|
|
|
|
|
generate_test_model = True
|
|
|
|
|
threshold_rate = 0.0
|
|
|
|
|
self.run_test(self.model_name, self.model_data_url, self.model_data_md5,
|
|
|
|
|
weight_bits, quantizable_op_type, weight_quantize_type,
|
|
|
|
|
generate_test_model, threshold_rate)
|
|
|
|
|
|
|
|
|
|
def test_weight_quantization_mobilenetv1_16bit_abs_max(self):
|
|
|
|
|
weight_bits = 16
|
|
|
|
|
quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
|
|
|
|
|
weight_quantize_type = "abs_max"
|
|
|
|
|
generate_test_model = False
|
|
|
|
|
threshold_rate = 1e-9
|
|
|
|
|
self.run_test(self.model_name, self.model_data_url, self.model_data_md5,
|
|
|
|
|
weight_bits, quantizable_op_type, weight_quantize_type,
|
|
|
|
|
generate_test_model, threshold_rate)
|
|
|
|
|
|
|
|
|
|
def test_weight_quantization_mobilenetv1_16bit_channel_wise_abs_max(self):
|
|
|
|
|
weight_bits = 16
|
|
|
|
|
quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
|
|
|
|
|
weight_quantize_type = "channel_wise_abs_max"
|
|
|
|
|
generate_test_model = False
|
|
|
|
|
threshold_rate = 1e-9
|
|
|
|
|
self.run_test(self.model_name, self.model_data_url, self.model_data_md5,
|
|
|
|
|
weight_bits, quantizable_op_type, threshold_rate)
|
|
|
|
|
weight_bits, quantizable_op_type, weight_quantize_type,
|
|
|
|
|
generate_test_model, threshold_rate)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|