Add weight quantization in post_training_quanzitaion (#22445)
* support weight quantization in post_training_quanzitaion, test=develop * add test for weight quantization, test=developrevert-22710-feature/integrated_ps_api
parent
dcfb603897
commit
197913ebe1
@ -0,0 +1,91 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
import os
|
||||
import time
|
||||
from paddle.dataset.common import download, DATA_HOME
|
||||
from paddle.fluid.contrib.slim.quantization import WeightQuantization
|
||||
|
||||
|
||||
class TestWeightQuantization(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.weight_quantization_dir = 'weight_quantization'
|
||||
self.cache_folder = os.path.join(DATA_HOME,
|
||||
self.weight_quantization_dir)
|
||||
|
||||
def download_model(self, model_name, data_url, data_md5):
|
||||
download(data_url, self.weight_quantization_dir, data_md5)
|
||||
file_name = data_url.split('/')[-1]
|
||||
file_path = os.path.join(self.cache_folder, file_name)
|
||||
print(model_name + ' is downloaded at ' + file_path)
|
||||
|
||||
unziped_path = os.path.join(self.cache_folder, model_name)
|
||||
self.cache_unzipping(unziped_path, file_path)
|
||||
print(model_name + ' is unziped at ' + unziped_path)
|
||||
return unziped_path
|
||||
|
||||
def cache_unzipping(self, target_folder, zip_path):
|
||||
if not os.path.exists(target_folder):
|
||||
cmd = 'mkdir {0} && tar xf {1} -C {0}'.format(target_folder,
|
||||
zip_path)
|
||||
os.system(cmd)
|
||||
|
||||
def run_test(self, model_name, model_data_url, model_data_md5,
|
||||
quantize_weight_bits, quantizable_op_type, threshold_rate):
|
||||
|
||||
model_dir = self.download_model(model_name, model_data_url,
|
||||
model_data_md5)
|
||||
|
||||
timestamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
|
||||
save_model_dir = os.path.join(
|
||||
os.getcwd(),
|
||||
model_name + "_wq_" + str(quantize_weight_bits) + "_" + timestamp)
|
||||
weight_quant = WeightQuantization(model_dir=model_dir + "/model")
|
||||
weight_quant.quantize_weight_to_int(
|
||||
save_model_dir=save_model_dir,
|
||||
quantize_weight_bits=quantize_weight_bits,
|
||||
quantizable_op_type=quantizable_op_type,
|
||||
threshold_rate=threshold_rate)
|
||||
print("finish weight quantization for " + model_name + "\n")
|
||||
|
||||
try:
|
||||
os.system("rm -rf {}".format(save_model_dir))
|
||||
except Exception as e:
|
||||
print("Failed to delete {} due to {}".format(save_model_dir, str(
|
||||
e)))
|
||||
|
||||
|
||||
class TestWeightQuantizationMobilenetv1(TestWeightQuantization):
|
||||
model_name = "mobilenetv1"
|
||||
model_data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz"
|
||||
model_data_md5 = "13892b0716d26443a8cdea15b3c6438b"
|
||||
|
||||
def test_weight_quantization_mobilenetv1_8bit(self):
|
||||
quantize_weight_bits = 8
|
||||
quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
|
||||
threshold_rate = 0.0
|
||||
self.run_test(self.model_name, self.model_data_url, self.model_data_md5,
|
||||
quantize_weight_bits, quantizable_op_type, threshold_rate)
|
||||
|
||||
def test_weight_quantization_mobilenetv1_16bit(self):
|
||||
quantize_weight_bits = 16
|
||||
quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
|
||||
threshold_rate = 1e-9
|
||||
self.run_test(self.model_name, self.model_data_url, self.model_data_md5,
|
||||
quantize_weight_bits, quantizable_op_type, threshold_rate)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Loading…
Reference in new issue