Add weight quantization in post_training_quantization (#22445)
* support weight quantization in post_training_quantization, test=develop
* add test for weight quantization, test=develop
revert-22710-feature/integrated_ps_api
parent
dcfb603897
commit
197913ebe1
@ -0,0 +1,91 @@
|
|||||||
|
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from paddle.dataset.common import download, DATA_HOME
|
||||||
|
from paddle.fluid.contrib.slim.quantization import WeightQuantization
|
||||||
|
|
||||||
|
|
||||||
|
class TestWeightQuantization(unittest.TestCase):
    """Shared fixture for weight-quantization smoke tests.

    Subclasses call ``run_test`` with a model archive URL; the archive is
    downloaded, unpacked into a cache folder and passed to
    ``WeightQuantization.quantize_weight_to_int``. No assertion is made on
    the quantized output: a test passes when the whole pipeline completes
    without raising.
    """

    def setUp(self):
        # Sub-directory name handed to ``download`` and used for the cache.
        self.weight_quantization_dir = 'weight_quantization'
        # NOTE(review): assumes ``download`` stores files under
        # DATA_HOME/<weight_quantization_dir>/ -- confirm against
        # paddle.dataset.common.
        self.cache_folder = os.path.join(DATA_HOME,
                                         self.weight_quantization_dir)

    def download_model(self, model_name, data_url, data_md5):
        """Download the archive at ``data_url`` and unpack it in the cache.

        Args:
            model_name: Name of the directory (inside the cache folder) the
                archive is unpacked into; also used in the log messages.
            data_url: URL of the model archive. Its last path component is
                taken as the local file name.
            data_md5: Checksum forwarded to ``download``.

        Returns:
            Path of the directory the archive was unpacked into.
        """
        download(data_url, self.weight_quantization_dir, data_md5)
        file_name = data_url.split('/')[-1]
        file_path = os.path.join(self.cache_folder, file_name)
        print(model_name + ' is downloaded at ' + file_path)

        unziped_path = os.path.join(self.cache_folder, model_name)
        self.cache_unzipping(unziped_path, file_path)
        print(model_name + ' is unziped at ' + unziped_path)
        return unziped_path

    def cache_unzipping(self, target_folder, zip_path):
        """Extract the tar archive ``zip_path`` into ``target_folder`` once.

        Nothing happens when ``target_folder`` already exists; it is then
        treated as a valid cache entry. Otherwise the folder is created
        (including missing parents) and the archive is extracted into it.

        Args:
            target_folder: Directory to create and extract into.
            zip_path: Path of the (optionally compressed) tar archive.

        Raises:
            Exception: Whatever ``tarfile`` raises when the archive is
                missing or unreadable. The freshly created folder is
                removed first so a failed extraction is retried next time
                instead of being mistaken for a populated cache.
        """
        import shutil
        import tarfile

        if os.path.exists(target_folder):
            return

        os.makedirs(target_folder)
        try:
            # Default mode detects gzip/bz2/xz compression transparently,
            # matching what ``tar xf`` did.
            with tarfile.open(zip_path) as archive:
                if hasattr(tarfile, 'data_filter'):
                    # Interpreters with extraction filters: refuse members
                    # that would escape ``target_folder``.
                    archive.extractall(target_folder, filter='data')
                else:
                    archive.extractall(target_folder)
        except Exception:
            # Roll back so the existence check above does not skip
            # extraction on the next run.
            shutil.rmtree(target_folder, ignore_errors=True)
            raise

    def run_test(self, model_name, model_data_url, model_data_md5,
                 quantize_weight_bits, quantizable_op_type, threshold_rate):
        """Download a model, quantize its weights and clean up the output.

        Args:
            model_name: Name used for the cache directory, the output
                directory and the log messages.
            model_data_url: URL of the model archive.
            model_data_md5: Checksum of the archive.
            quantize_weight_bits: Forwarded to ``quantize_weight_to_int``;
                also embedded in the output directory name.
            quantizable_op_type: Forwarded to ``quantize_weight_to_int``.
            threshold_rate: Forwarded to ``quantize_weight_to_int``.
        """
        import shutil

        model_dir = self.download_model(model_name, model_data_url,
                                        model_data_md5)

        # The timestamp keeps output directories of separate runs apart.
        timestamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
        save_model_dir = os.path.join(
            os.getcwd(),
            model_name + "_wq_" + str(quantize_weight_bits) + "_" + timestamp)

        try:
            # The model files live in the "model" sub-directory of the
            # unpacked archive.
            weight_quant = WeightQuantization(
                model_dir=os.path.join(model_dir, "model"))
            weight_quant.quantize_weight_to_int(
                save_model_dir=save_model_dir,
                quantize_weight_bits=quantize_weight_bits,
                quantizable_op_type=quantizable_op_type,
                threshold_rate=threshold_rate)
            print("finish weight quantization for " + model_name + "\n")
        finally:
            # Remove the output even when quantization failed, so aborted
            # runs do not litter the working directory.
            if os.path.exists(save_model_dir):
                try:
                    shutil.rmtree(save_model_dir)
                except OSError as e:
                    print("Failed to delete {} due to {}".format(
                        save_model_dir, str(e)))
|
||||||
|
|
||||||
|
|
||||||
|
class TestWeightQuantizationMobilenetv1(TestWeightQuantization):
    """Weight-quantization smoke tests for the MobileNetV1 model archive."""

    # Archive coordinates shared by every test case in this class.
    model_name = "mobilenetv1"
    model_data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz"
    model_data_md5 = "13892b0716d26443a8cdea15b3c6438b"

    def test_weight_quantization_mobilenetv1_8bit(self):
        """Run the pipeline with 8-bit weights and a zero threshold rate."""
        self.run_test(
            model_name=self.model_name,
            model_data_url=self.model_data_url,
            model_data_md5=self.model_data_md5,
            quantize_weight_bits=8,
            quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul'],
            threshold_rate=0.0)

    def test_weight_quantization_mobilenetv1_16bit(self):
        """Run the pipeline with 16-bit weights and a 1e-9 threshold rate."""
        self.run_test(
            model_name=self.model_name,
            model_data_url=self.model_data_url,
            model_data_md5=self.model_data_md5,
            quantize_weight_bits=16,
            quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul'],
            threshold_rate=1e-9)
|
||||||
|
|
||||||
|
|
||||||
|
# Run every test case in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
|
||||||
Loading…
Reference in new issue