add resnet50 test for post trainint quantization, test=develop (#21272)

revert-21172-masked_select_api
juncaipeng 5 years ago committed by Tao Luo
parent 9a7832f8be
commit 84865b806b

@ -99,7 +99,7 @@ class PostTrainingQuantization(object):
params_filename = None params_filename = None
save_model_path = path/to/save_model_path save_model_path = path/to/save_model_path
# prepare the sample generator according to the model, and the # prepare the sample generator according to the model, and the
# sample generator must return a simple every time. The reference # sample generator must return a sample every time. The reference
# document: https://www.paddlepaddle.org.cn/documentation/docs/zh # document: https://www.paddlepaddle.org.cn/documentation/docs/zh
# /user_guides/howto/prepare_data/use_py_reader.html # /user_guides/howto/prepare_data/use_py_reader.html
sample_generator = your_sample_generator sample_generator = your_sample_generator

@ -48,7 +48,8 @@ endfunction()
if(WIN32) if(WIN32)
list(REMOVE_ITEM TEST_OPS test_light_nas) list(REMOVE_ITEM TEST_OPS test_light_nas)
list(REMOVE_ITEM TEST_OPS test_post_training_quantization) list(REMOVE_ITEM TEST_OPS test_post_training_quantization_mobilenetv1)
list(REMOVE_ITEM TEST_OPS test_post_training_quantization_resnet50)
endif() endif()
# int8 image classification python api test # int8 image classification python api test

@ -110,10 +110,9 @@ class TestPostTrainingQuantization(unittest.TestCase):
self.int8_download = 'int8/download' self.int8_download = 'int8/download'
self.cache_folder = os.path.expanduser('~/.cache/paddle/dataset/' + self.cache_folder = os.path.expanduser('~/.cache/paddle/dataset/' +
self.int8_download) self.int8_download)
self.data_cache_folder = ''
data_urls = [] data_urls = []
data_md5s = [] data_md5s = []
self.data_cache_folder = ''
if os.environ.get('DATASET') == 'full': if os.environ.get('DATASET') == 'full':
data_urls.append( data_urls.append(
'https://paddle-inference-dist.bj.bcebos.com/int8/ILSVRC2012_img_val.tar.gz.partaa' 'https://paddle-inference-dist.bj.bcebos.com/int8/ILSVRC2012_img_val.tar.gz.partaa'
@ -145,7 +144,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
'DATASET') == 'full' else 1 'DATASET') == 'full' else 1
self.timestamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime()) self.timestamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
self.int8_model = '' self.int8_model = os.path.join(os.getcwd(),
"post_training_" + self.timestamp)
def tearDown(self): def tearDown(self):
try: try:
@ -191,14 +191,14 @@ class TestPostTrainingQuantization(unittest.TestCase):
def download_model(self): def download_model(self):
pass pass
def run_program(self, model_path): def run_program(self, model_path, batch_size, infer_iterations):
image_shape = [3, 224, 224] image_shape = [3, 224, 224]
place = fluid.CPUPlace() place = fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
[infer_program, feed_dict, fetch_targets] = \ [infer_program, feed_dict, fetch_targets] = \
fluid.io.load_inference_model(model_path, exe) fluid.io.load_inference_model(model_path, exe)
val_reader = paddle.batch(val(), self.batch_size) val_reader = paddle.batch(val(), batch_size)
iterations = self.infer_iterations iterations = infer_iterations
test_info = [] test_info = []
cnt = 0 cnt = 0
@ -237,8 +237,6 @@ class TestPostTrainingQuantization(unittest.TestCase):
model_path, model_path,
algo="KL", algo="KL",
is_full_quantize=False): is_full_quantize=False):
self.int8_model = os.path.join(os.getcwd(),
"post_training_" + self.timestamp)
try: try:
os.system("mkdir " + self.int8_model) os.system("mkdir " + self.int8_model)
except Exception as e: except Exception as e:
@ -264,52 +262,50 @@ class TestPostTrainingQuantization(unittest.TestCase):
ptq.quantize() ptq.quantize()
ptq.save_quantized_model(self.int8_model) ptq.save_quantized_model(self.int8_model)
def run_test(self, model, algo, data_urls, data_md5s):
infer_iterations = self.infer_iterations
batch_size = self.batch_size
sample_iterations = self.sample_iterations
class TestPostTrainingForMobilenetv1(TestPostTrainingQuantization): model_cache_folder = self.download_data(data_urls, data_md5s, model)
def download_model(self):
# mobilenetv1 fp32 data
data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
]
data_md5s = ['13892b0716d26443a8cdea15b3c6438b']
self.model_cache_folder = self.download_data(data_urls, data_md5s,
"mobilenetv1_fp32")
self.model = "MobileNet-V1"
self.algo = "KL"
def test_post_training_mobilenetv1(self):
self.download_model()
print("Start FP32 inference for {0} on {1} images ...".format( print("Start FP32 inference for {0} on {1} images ...".format(
self.model, self.infer_iterations * self.batch_size)) model, infer_iterations * batch_size))
(fp32_throughput, fp32_latency, (fp32_throughput, fp32_latency, fp32_acc1) = self.run_program(
fp32_acc1) = self.run_program(self.model_cache_folder + "/model") model_cache_folder + "/model", batch_size, infer_iterations)
print("Start INT8 post training quantization for {0} on {1} images ...". print("Start INT8 post training quantization for {0} on {1} images ...".
format(self.model, self.sample_iterations * self.batch_size)) format(model, sample_iterations * batch_size))
self.generate_quantized_model( self.generate_quantized_model(
self.model_cache_folder + "/model", model_cache_folder + "/model", algo=algo, is_full_quantize=True)
algo=self.algo,
is_full_quantize=True)
print("Start INT8 inference for {0} on {1} images ...".format( print("Start INT8 inference for {0} on {1} images ...".format(
self.model, self.infer_iterations * self.batch_size)) model, infer_iterations * batch_size))
(int8_throughput, int8_latency, (int8_throughput, int8_latency, int8_acc1) = self.run_program(
int8_acc1) = self.run_program(self.int8_model) self.int8_model, batch_size, infer_iterations)
print( print(
"FP32 {0}: batch_size {1}, throughput {2} images/second, latency {3} second, accuracy {4}". "FP32 {0}: batch_size {1}, throughput {2} images/second, latency {3} second, accuracy {4}".
format(self.model, self.batch_size, fp32_throughput, fp32_latency, format(model, batch_size, fp32_throughput, fp32_latency, fp32_acc1))
fp32_acc1))
print( print(
"INT8 {0}: batch_size {1}, throughput {2} images/second, latency {3} second, accuracy {4}". "INT8 {0}: batch_size {1}, throughput {2} images/second, latency {3} second, accuracy {4}".
format(self.model, self.batch_size, int8_throughput, int8_latency, format(model, batch_size, int8_throughput, int8_latency, int8_acc1))
int8_acc1))
sys.stdout.flush() sys.stdout.flush()
delta_value = fp32_acc1 - int8_acc1 delta_value = fp32_acc1 - int8_acc1
self.assertLess(delta_value, 0.025) self.assertLess(delta_value, 0.025)
class TestPostTrainingForMobilenetv1(TestPostTrainingQuantization):
def test_post_training_mobilenetv1(self):
model = "MobileNet-V1"
algo = "KL"
data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
]
data_md5s = ['13892b0716d26443a8cdea15b3c6438b']
self.run_test(model, algo, data_urls, data_md5s)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

@ -0,0 +1,32 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
from test_post_training_quantization_mobilenetv1 import TestPostTrainingQuantization
class TestPostTrainingForResnet50(TestPostTrainingQuantization):
def test_post_training_resnet50(self):
model = "ResNet-50"
algo = "direct"
data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model.tar.gz'
]
data_md5s = ['4a5194524823d9b76da6e738e1367881']
self.run_test(model, algo, data_urls, data_md5s)
if __name__ == '__main__':
unittest.main()
Loading…
Cancel
Save