Integrate INT8 MKL-DNN v2 into slim (#17634)
* refactor PR 16865
* delete mergetool files
* create dir for int8 model before call SaveOptimModel
* mkldnn int8 only support linux; test=develop
* refine code; test=develop
* remove comment; test=develop
* fix bug; test=develop
* add exception for mkldnn_post_training_strategy
* reuse int8v2 CAPI dataset; test=develop
* fix accuracy check bug; test=develop
* remove tab
* convert files to unix format
* reduce CI time; test=develop
* reduce CI time and refine code; test=develop
* refine comment; test=develop
* add cmake FLAGS; test=develop
* remove predict_num; test=develop
parent 6a1df46991
commit 993c703bcc

@@ -0,0 +1,120 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import logging
import six
import numpy as np
from .... import core
from ..core.strategy import Strategy

__all__ = ['MKLDNNPostTrainingQuantStrategy']

logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)


class MKLDNNPostTrainingQuantStrategy(Strategy):
    """
    The MKL-DNN post-training quantization strategy.
    """

    def __init__(self,
                 int8_model_save_path=None,
                 fp32_model_path=None,
                 cpu_math_library_num_threads=1):
        """
        Args:
            int8_model_save_path(str): The path used to save an INT8 ProgramDesc
                with FP32 weights, which is used for MKL-DNN INT8 inference. For
                post-training quantization, MKLDNNPostTrainingQuantStrategy currently
                only supports converting an FP32 ProgramDesc with FP32 weights into
                an INT8 ProgramDesc with FP32 weights. The saved INT8 ProgramDesc
                with FP32 weights can only be executed with MKL-DNN enabled.
                None means the INT8 ProgramDesc with FP32 weights is not saved.
                Default: None.
            fp32_model_path(str): The path used to load the original FP32 ProgramDesc
                with FP32 weights. None means there is no FP32 ProgramDesc with FP32
                weights. Default: None.
            cpu_math_library_num_threads(int): The number of CPU math library threads
                used by MKLDNNPostTrainingQuantStrategy. 1 means only one CPU math
                library thread is used. Default: 1.
        """

        super(MKLDNNPostTrainingQuantStrategy, self).__init__(0, 0)
        self.int8_model_save_path = int8_model_save_path
        if fp32_model_path is None:
            raise Exception("fp32_model_path is None")
        self.fp32_model_path = fp32_model_path
        self.cpu_math_library_num_threads = cpu_math_library_num_threads

    def on_compression_begin(self, context):
        """
        Prepare the warm-up data and quantize the model.
        """

        super(MKLDNNPostTrainingQuantStrategy,
              self).on_compression_begin(context)
        _logger.info('MKLDNNPostTrainingQuantStrategy::on_compression_begin')

        # Prepare the AnalysisConfig
        infer_config = core.AnalysisConfig("AnalysisConfig")
        infer_config.switch_ir_optim(True)
        infer_config.disable_gpu()
        infer_config.set_model(self.fp32_model_path)
        infer_config.enable_mkldnn()
        infer_config.set_cpu_math_library_num_threads(
            self.cpu_math_library_num_threads)

        # Prepare the data for calculating the quantization scales
        warmup_reader = context.eval_reader()
        if six.PY2:
            data = warmup_reader.next()

        if six.PY3:
            data = warmup_reader.__next__()

        # TODO (Intel): Remove the limitation that MKLDNNPostTrainingQuantStrategy
        # only supports image classification
        num_images = len(data)
        images = core.PaddleTensor()
        images.name = "x"
        images.shape = [num_images, ] + list(data[0][0].shape)
        images.dtype = core.PaddleDType.FLOAT32
        image_data = [img.tolist() for (img, _) in data]
        image_data = np.array(image_data).astype("float32")
        image_data = image_data.ravel()
        images.data = core.PaddleBuf(image_data.tolist())

        labels = core.PaddleTensor()
        labels.name = "y"
        labels.shape = [num_images, 1]
        labels.dtype = core.PaddleDType.INT64
        label_data = [label for (_, label) in data]
        labels.data = core.PaddleBuf(label_data)

        warmup_data = [images, labels]

        # Enable INT8 quantization
        infer_config.enable_quantizer()
        infer_config.quantizer_config().set_quant_data(warmup_data)
        infer_config.quantizer_config().set_quant_batch_size(num_images)

        # Run INT8 MKL-DNN quantization
        predictor = core.create_paddle_predictor(infer_config)
        if self.int8_model_save_path:
            if not os.path.exists(self.int8_model_save_path):
                os.makedirs(self.int8_model_save_path)
            predictor.SaveOptimModel(self.int8_model_save_path)

        _logger.info(
            'Finish MKLDNNPostTrainingQuantStrategy::on_compression_begin')
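The warm-up tensors above pack an entire evaluation batch into flat buffers. A minimal numpy sketch of the same flattening, with dummy data and illustrative shapes (the real shapes come from the evaluation reader):

import numpy as np

# Dummy warm-up batch: four (3, 224, 224) float32 images with int64 labels.
data = [(np.zeros((3, 224, 224), dtype="float32"), 7) for _ in range(4)]

num_images = len(data)
shape = [num_images] + list(data[0][0].shape)   # [4, 3, 224, 224]
flat = np.array([img.tolist() for (img, _) in data]).astype("float32").ravel()
labels = [label for (_, label) in data]         # length num_images
assert flat.size == int(np.prod(shape))         # one value per pixel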
@@ -1,11 +1,63 @@
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")

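# Registers a py_test named <target> that runs <filename> against
# <model_dir>/model, feeding <data_dir>/data.bin (see ARGS below).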
function(inference_analysis_python_api_int8_test target model_dir data_dir filename)
    py_test(${target} SRCS ${filename}
        ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
        ARGS --infer_model ${model_dir}/model
             --infer_data ${data_dir}/data.bin
             --int8_model_save_path int8_models/${target}
             --warmup_batch_size 100
             --batch_size 50)
endfunction()

# NOTE: TODO
# Temporarily disable test_distillation_strategy since it always fails on a
# specific machine with 4 GPUs. Need to figure out the root cause and then
# add it back.
list(REMOVE_ITEM TEST_OPS test_distillation_strategy)

# int8 image classification python api test
if(LINUX AND WITH_MKLDNN)
    set(INT8_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2")
    set(MKLDNN_INT8_TEST_FILE "test_mkldnn_int8_quantization_strategy.py")

    # googlenet int8
    set(INT8_GOOGLENET_MODEL_DIR "${INT8_DATA_DIR}/googlenet")
    inference_analysis_python_api_int8_test(test_slim_int8_googlenet ${INT8_GOOGLENET_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE})

    # mobilenet int8
    set(INT8_MOBILENET_MODEL_DIR "${INT8_DATA_DIR}/mobilenet")
    inference_analysis_python_api_int8_test(test_slim_int8_mobilenet ${INT8_MOBILENET_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE})

    # The WITH_SLIM_MKLDNN_FULL_TEST flag temporarily gates the following UTs,
    # so that QA can run them locally; they cost too much time on CI.
    if (WITH_SLIM_MKLDNN_FULL_TEST)
        # resnet50 int8
        set(INT8_RESNET50_MODEL_DIR "${INT8_DATA_DIR}/resnet50")
        inference_analysis_python_api_int8_test(test_slim_int8_resnet50 ${INT8_RESNET50_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE})

        # mobilenetv2 int8
        set(INT8_MOBILENETV2_MODEL_DIR "${INT8_DATA_DIR}/mobilenetv2")
        inference_analysis_python_api_int8_test(test_slim_int8_mobilenetv2 ${INT8_MOBILENETV2_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE})

        # resnet101 int8
        set(INT8_RESNET101_MODEL_DIR "${INT8_DATA_DIR}/resnet101")
        inference_analysis_python_api_int8_test(test_slim_int8_resnet101 ${INT8_RESNET101_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE})

        # vgg16 int8
        set(INT8_VGG16_MODEL_DIR "${INT8_DATA_DIR}/vgg16")
        inference_analysis_python_api_int8_test(test_slim_int8_vgg16 ${INT8_VGG16_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE})

        # vgg19 int8
        set(INT8_VGG19_MODEL_DIR "${INT8_DATA_DIR}/vgg19")
        inference_analysis_python_api_int8_test(test_slim_int8_vgg19 ${INT8_VGG19_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE})
    endif()
endif()

# test_mkldnn_int8_quantization_strategy only supports testing on Linux with
# MKL-DNN, so remove it from TEST_OPS here to avoid running it twice there and
# running it at all on other systems.
list(REMOVE_ITEM TEST_OPS test_mkldnn_int8_quantization_strategy)

foreach(src ${TEST_OPS})
    py_test(${src} SRCS ${src}.py)
endforeach()

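The UTs gated behind WITH_SLIM_MKLDNN_FULL_TEST can be enabled for a local run by turning the flag on at CMake configure time (e.g. -DWITH_SLIM_MKLDNN_FULL_TEST=ON) in a Linux build with WITH_MKLDNN.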
@@ -0,0 +1,28 @@
# int8_model_save_path(str): The path used to save an INT8 ProgramDesc with FP32
#     weights, which is used for MKL-DNN INT8 inference. For post-training
#     quantization, MKLDNNPostTrainingQuantStrategy currently only supports
#     converting an FP32 ProgramDesc with FP32 weights into an INT8 ProgramDesc
#     with FP32 weights. The saved INT8 ProgramDesc with FP32 weights can only be
#     executed with MKL-DNN enabled. None means the INT8 ProgramDesc with FP32
#     weights is not saved. Default: None.
#
# fp32_model_path(str): The path used to load the original FP32 ProgramDesc with
#     FP32 weights. None means there is no FP32 ProgramDesc with FP32 weights.
#     Default: None.
#
# cpu_math_library_num_threads(int): The number of CPU math library threads used
#     by MKLDNNPostTrainingQuantStrategy. 1 means only one CPU math library
#     thread is used. Default: 1.
#     Note: Here we set cpu_math_library_num_threads to 4, which is the maximum
#     number of CPU math library threads on the CI machine.
#
version: 1.0
strategies:
    mkldnn_post_training_strategy:
        class: 'MKLDNNPostTrainingQuantStrategy'
        int8_model_save_path: 'OUTPUT_PATH'
        fp32_model_path: 'MODEL_PATH'
        cpu_math_library_num_threads: 4
compressor:
    epoch: 0
    checkpoint_path: ''
    strategies:
        - mkldnn_post_training_strategy
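The 'MODEL_PATH' and 'OUTPUT_PATH' values are placeholders; the test below copies this config and substitutes the actual FP32 model path and INT8 output path into them (see _update_config_file).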
@@ -0,0 +1,216 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import os
import sys
import argparse
import shutil
import logging
import struct
import six
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.core import Compressor

logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=1, help='batch size')
    parser.add_argument(
        '--infer_model',
        type=str,
        default='',
        help='infer_model is used to load an original fp32 ProgramDesc with fp32 weights'
    )
    parser.add_argument('--infer_data', type=str, default='', help='data file')
    parser.add_argument(
        '--int8_model_save_path',
        type=str,
        default='./output',
        help='int8_model_save_path is used to save an int8 ProgramDesc with fp32 weights')
    parser.add_argument(
        '--warmup_batch_size',
        type=int,
        default=100,
        help='batch size for quantization warmup')
    parser.add_argument(
        '--accuracy_diff_threshold',
        type=float,
        default=0.01,
        help='accepted accuracy drop threshold.')

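    # parse_known_args(namespace=unittest) stores the recognized flags as
    # attributes on the unittest module; the unrecognized remainder is returned
    # so it can be forwarded to unittest.main() in __main__.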
    test_args, args = parser.parse_known_args(namespace=unittest)

    return test_args, sys.argv[:1] + args


class TestMKLDNNPostTrainingQuantStrategy(unittest.TestCase):
    """
    Test the API of the post-training quantization strategy for INT8 with MKL-DNN.
    """

    def _reader_creator(self, data_file='data.bin', cycle=False):
        def reader():
            with open(data_file, 'rb') as fp:
                num = fp.read(8)
                num = struct.unpack('q', num)[0]
                imgs_offset = 8
                img_ch = 3
                img_w = 224
                img_h = 224
                img_pixel_size = 4
                img_size = img_ch * img_h * img_w * img_pixel_size
                label_size = 8
                labels_offset = imgs_offset + num * img_size
                step = 0

                while step < num:
                    fp.seek(imgs_offset + img_size * step)
                    img = fp.read(img_size)
                    img = struct.unpack_from(
                        '{}f'.format(img_ch * img_w * img_h), img)
                    img = np.array(img)
                    img.shape = (img_ch, img_w, img_h)
                    fp.seek(labels_offset + label_size * step)
                    label = fp.read(label_size)
                    label = struct.unpack('q', label)[0]
                    yield img, int(label)
                    step += 1
                    if cycle and step == num:
                        step = 0

        return reader
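
    # Layout of data.bin as consumed by the reader above:
    #   int64 image count, then num float32 CHW images (3 * 224 * 224 each),
    #   then num int64 labels.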
    def _update_config_file(self, fp32_model_path, output_path):
        config_path = './quantization/config_mkldnn_int8.yaml'
        new_config_path = './quantization/temp.yaml'
        shutil.copy(config_path, new_config_path)

        with open(new_config_path, 'r+') as fp:
            data = fp.read()
            data = data.replace('MODEL_PATH', fp32_model_path)
            data = data.replace('OUTPUT_PATH', output_path)
        with open(new_config_path, 'w') as fp:
            fp.write(data)

        return new_config_path

    def _predict(self, test_reader=None, model_path=None):
        place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        inference_scope = fluid.executor.global_scope()
        with fluid.scope_guard(inference_scope):
            if os.path.exists(os.path.join(model_path, '__model__')):
                [inference_program, feed_target_names,
                 fetch_targets] = fluid.io.load_inference_model(model_path, exe)
            else:
                [inference_program, feed_target_names,
                 fetch_targets] = fluid.io.load_inference_model(
                     model_path, exe, 'model', 'params')

            dshape = [3, 224, 224]
            top1 = 0.0
            top5 = 0.0
            total_samples = 0
            for _, data in enumerate(test_reader()):
                if six.PY2:
                    images = map(lambda x: x[0].reshape(dshape), data)
                if six.PY3:
                    images = list(map(lambda x: x[0].reshape(dshape), data))
                images = np.array(images).astype('float32')
                labels = np.array([x[1] for x in data]).astype("int64")
                labels = labels.reshape([-1, 1])
                out = exe.run(inference_program,
                              feed={
                                  feed_target_names[0]: images,
                                  feed_target_names[1]: labels
                              },
                              fetch_list=fetch_targets)
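                # out[1] and out[2] hold the batch's top-1/top-5 accuracy, so
                # weight each batch by its sample count; the division below
                # then yields a per-sample average.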
                top1 += np.sum(out[1]) * len(data)
                top5 += np.sum(out[2]) * len(data)
                total_samples += len(data)
            return top1 / total_samples, top5 / total_samples

    def _warmup(self, reader=None, config_path=''):
        com_pass = Compressor(
            place=None,
            scope=None,
            train_program=None,
            train_reader=None,
            train_feed_list=[],
            train_fetch_list=[],
            eval_program=None,
            eval_reader=reader,
            eval_feed_list=[],
            eval_fetch_list=[],
            teacher_programs=[],
            checkpoint_path='',
            train_optimizer=None,
            distiller_optimizer=None)
        com_pass.config(config_path)
        com_pass.run()

    def test_compression(self):
        if not fluid.core.is_compiled_with_mkldnn():
            return

        int8_model_path = test_case_args.int8_model_save_path
        data_path = test_case_args.infer_data
        fp32_model_path = test_case_args.infer_model
        batch_size = test_case_args.batch_size
        warmup_batch_size = test_case_args.warmup_batch_size
        accuracy_diff_threshold = test_case_args.accuracy_diff_threshold

        _logger.info(
            'FP32 & INT8 prediction run: batch_size {0}, warmup batch size {1}.'.
            format(batch_size, warmup_batch_size))

        # Warm-up dataset; only the first batch of data is used
        warmup_reader = paddle.batch(
            self._reader_creator(data_path, False),
            batch_size=warmup_batch_size)
        config_path = self._update_config_file(fp32_model_path, int8_model_path)
        self._warmup(warmup_reader, config_path)

        _logger.info('--- INT8 prediction start ---')
        val_reader = paddle.batch(
            self._reader_creator(data_path, False), batch_size=batch_size)
        int8_model_result = self._predict(val_reader, int8_model_path)
        _logger.info('--- FP32 prediction start ---')
        val_reader = paddle.batch(
            self._reader_creator(data_path, False), batch_size=batch_size)
        fp32_model_result = self._predict(val_reader, fp32_model_path)

        _logger.info('--- comparing outputs ---')
        _logger.info('Avg top1 INT8 accuracy: {0:.4f}'.format(
            int8_model_result[0]))
        _logger.info('Avg top1 FP32 accuracy: {0:.4f}'.format(
            fp32_model_result[0]))
        _logger.info('Accepted accuracy drop threshold: {0}'.format(
            accuracy_diff_threshold))
        assert fp32_model_result[0] - int8_model_result[
            0] <= accuracy_diff_threshold


if __name__ == '__main__':
    global test_case_args
    test_case_args, remaining_args = parse_args()
    unittest.main(argv=remaining_args)
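For reference, the CMake helper above runs this file with arguments of the form: python test_mkldnn_int8_quantization_strategy.py --infer_model <model_dir>/model --infer_data <data_dir>/data.bin --int8_model_save_path int8_models/<target> --warmup_batch_size 100 --batch_size 50.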