Add support for the Ernie NLP model to Slim QAT (#22506)

* Add a test for the Ernie QAT INT8 accuracy check

test=develop

* Remove NLP comparison test to split PRs

test=develop

* Fix typo and tabs, delete commented-out lines

test=develop

* Re-combine the two PRs, test=develop

Co-authored-by: Michał Gallus <sand3r@interia.eu>
Co-authored-by: bingyanghuang <33643817+bingyanghuang@users.noreply.github.com>

@@ -226,15 +226,18 @@ if(WITH_MKLDNN)
set(INT8_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2")
### Image classification tests
set(IMAGENET_DATA_PATH "${INT8_DATA_DIR}/data.bin")
set(INT8_IMG_CLASS_TEST_APP "test_analyzer_int8_image_classification")
set(INT8_IMG_CLASS_TEST_APP_SRC "analyzer_int8_image_classification_tester.cc")
## Image classification models
# download dataset if necessary
download_int8_data(${INT8_DATA_DIR} "imagenet_val_100_tail.tar.gz")
# ImageNet small dataset
# May be already downloaded for INT8 QAT unit tests
set(IMAGENET_DATA_ARCHIVE "imagenet_val_100_tail.tar.gz")
set(IMAGENET_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/imagenet")
set(IMAGENET_DATA_PATH "${IMAGENET_DATA_DIR}/data.bin")
download_int8_data(${IMAGENET_DATA_DIR} ${IMAGENET_DATA_ARCHIVE})
# build test binary to be used in subsequent tests
set(INT8_IMG_CLASS_TEST_APP "test_analyzer_int8_image_classification")
set(INT8_IMG_CLASS_TEST_APP_SRC "analyzer_int8_image_classification_tester.cc")
inference_analysis_api_test_build(${INT8_IMG_CLASS_TEST_APP} ${INT8_IMG_CLASS_TEST_APP_SRC})
# resnet50 int8
@@ -296,7 +299,7 @@ if(WITH_MKLDNN)
### optimized FP32 vs. QAT INT8 tests
set(QAT_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2")
set(QAT_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/qat")
set(QAT_IMG_CLASS_TEST_APP "test_analyzer_qat_image_classification")
set(QAT_IMG_CLASS_TEST_APP_SRC "analyzer_qat_image_classification_tester.cc")
@@ -304,8 +307,8 @@ if(WITH_MKLDNN)
inference_analysis_api_test_build(${QAT_IMG_CLASS_TEST_APP} ${QAT_IMG_CLASS_TEST_APP_SRC})
# MobileNet FP32 vs. QAT INT8
# The FP32 model should already be downloaded for slim QAT unit tests
set(QAT2_MobileNet_MODEL_DIR "${QAT_DATA_DIR}/MobileNet_qat_perf")
download_qat_data(${QAT2_MobileNet_MODEL_DIR} "MobileNet_qat_perf.tar.gz")
set(QAT2_INT8_MobileNet_MODEL_DIR "${QAT_DATA_DIR}/MobileNet_qat_perf_int8")
download_qat_data(${QAT2_INT8_MobileNet_MODEL_DIR} "MobileNet_qat_perf_int8.tar.gz")
inference_analysis_api_qat_test_run(test_analyzer_qat_performance_benchmark ${QAT_IMG_CLASS_TEST_APP} ${QAT2_MobileNet_MODEL_DIR}/MobileNet_qat_perf/float ${QAT2_INT8_MobileNet_MODEL_DIR}/MobileNet_qat_perf_int8 ${IMAGENET_DATA_PATH})

@@ -479,7 +479,7 @@ GetPrimitiveFactory(const MKLDNNDeviceContext& dev_ctx,
const Tensor* weights,
const mkldnn::engine& mkldnn_engine) {
const std::string key = platform::CreateKey(
platform::ThreadIDasStr(), input->format(),
platform::ThreadIDasStr(), input->format(), input->dims()[0],
framework::vectorize<int>(weights->dims()), ctx.OutputName("Out"));
auto prim_creator =

File diff suppressed because it is too large

@@ -24,8 +24,8 @@ import time
import paddle
import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import FakeQAT2MkldnnINT8KernelPass
from paddle.fluid.contrib.slim.quantization import FakeQAT2MkldnnINT8PerfPass
from paddle.fluid.contrib.slim.quantization import QatInt8MkldnnPass
from paddle.fluid.contrib.slim.quantization import Qat2Int8MkldnnPass
from paddle.fluid import core
logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
@@ -53,10 +53,6 @@ def parse_args():
action='store_true',
help='If used, the QAT model is treated as a second generation model for performance optimization.'
)
parser.add_argument(
'--save_model',
action='store_true',
help='If used, the QAT model will be saved after all transformations')
parser.add_argument('--infer_data', type=str, default='', help='Data file.')
parser.add_argument(
'--batch_num',
@@ -68,15 +64,20 @@ def parse_args():
type=float,
default=0.01,
help='Accepted accuracy difference threshold.')
parser.add_argument(
'--quantized_ops',
type=str,
default='',
help='A comma separated list of quantized operators.')
test_args, args = parser.parse_known_args(namespace=unittest)
return test_args, sys.argv[:1] + args
class TestQatInt8Comparison(unittest.TestCase):
class QatInt8ImageClassificationComparisonTest(unittest.TestCase):
"""
Test for accuracy comparison of QAT FP32 and INT8 inference.
Test for accuracy comparison of QAT FP32 and INT8 Image Classification inference.
"""
def _reader_creator(self, data_file='data.bin'):
@@ -182,14 +183,15 @@ class TestQatInt8Comparison(unittest.TestCase):
graph.draw('.', 'qat_orig', graph.all_op_nodes())
if (transform_to_int8):
if (test_case_args.qat2):
transform_to_mkldnn_int8_pass = FakeQAT2MkldnnINT8PerfPass(
transform_to_mkldnn_int8_pass = Qat2Int8MkldnnPass(
self._quantized_ops,
_scope=inference_scope,
_place=place,
_core=core,
_debug=self._debug)
graph = transform_to_mkldnn_int8_pass.apply(graph)
else:
mkldnn_int8_pass = FakeQAT2MkldnnINT8KernelPass(
mkldnn_int8_pass = QatInt8MkldnnPass(
_scope=inference_scope, _place=place)
graph = mkldnn_int8_pass.apply(graph)
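
Note on the renamed passes above: FakeQAT2MkldnnINT8PerfPass becomes Qat2Int8MkldnnPass and now takes the set of quantized operator types as its first argument, while FakeQAT2MkldnnINT8KernelPass becomes QatInt8MkldnnPass with its scope/place signature unchanged. A minimal sketch of how the comparison test chooses between them; inference_program, inference_scope, quantized_ops and use_qat2 are stand-ins for state the test already holds, not names introduced by this PR:

import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import QatInt8MkldnnPass
from paddle.fluid.contrib.slim.quantization import Qat2Int8MkldnnPass

# Sketch only: inference_program, inference_scope, quantized_ops and use_qat2
# stand in for the state the test builds up before this point.
place = fluid.CPUPlace()
graph = IrGraph(core.Graph(inference_program.desc), for_test=True)
if use_qat2:
    # Second-generation pass: the set of op types to quantize comes first.
    int8_pass = Qat2Int8MkldnnPass(
        quantized_ops, _scope=inference_scope, _place=place, _core=core)
else:
    # First-generation pass keeps the scope/place-only signature.
    int8_pass = QatInt8MkldnnPass(_scope=inference_scope, _place=place)
graph = int8_pass.apply(graph)
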
@@ -256,12 +258,6 @@ class TestQatInt8Comparison(unittest.TestCase):
_logger.info('Total inference run time: {:.2f} s'.format(
infer_total_time))
if test_case_args.save_model:
with fluid.scope_guard(inference_scope):
fluid.io.save_inference_model(
'transformed_qat_int8_model', feed_target_names,
fetch_targets, exe, inference_program)
return outputs, acc1_avg, acc5_avg, fps_avg, latency_avg
def _summarize_performance(self, fp32_fps, fp32_lat, int8_fps, int8_lat):
@@ -298,6 +294,7 @@ class TestQatInt8Comparison(unittest.TestCase):
skip_batch_num = test_case_args.skip_batch_num
acc_diff_threshold = test_case_args.acc_diff_threshold
self._debug = test_case_args.debug
self._quantized_ops = set(test_case_args.quantized_ops.split(','))
_logger.info('QAT FP32 & INT8 prediction run.')
_logger.info('QAT model: {0}'.format(qat_model_path))
@@ -305,6 +302,7 @@ class TestQatInt8Comparison(unittest.TestCase):
_logger.info('Batch size: {0}'.format(batch_size))
_logger.info('Batch number: {0}'.format(batch_num))
_logger.info('Accuracy drop threshold: {0}.'.format(acc_diff_threshold))
_logger.info('Quantized ops: {0}.'.format(self._quantized_ops))
_logger.info('--- QAT FP32 prediction start ---')
val_reader = paddle.batch(
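
The new --quantized_ops flag arrives as a comma-separated string and is split into a set (self._quantized_ops above) before being handed to Qat2Int8MkldnnPass. A small self-contained sketch of that parsing; the operator names are illustrative only, not a list mandated by this PR:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    '--quantized_ops',
    type=str,
    default='',
    help='A comma separated list of quantized operators.')
# Hypothetical value, for illustration only.
args = parser.parse_args(['--quantized_ops', 'fc,reshape2,transpose2'])

# The comma-separated string becomes a set of op types, as in the test above.
quantized_ops = set(args.quantized_ops.split(','))
print(quantized_ops)  # e.g. {'fc', 'reshape2', 'transpose2'}
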

@@ -24,7 +24,7 @@ import time
import paddle
import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import FakeQAT2MkldnnINT8PerfPass
from paddle.fluid.contrib.slim.quantization import Qat2Int8MkldnnPass
from paddle.fluid import core
@@ -42,6 +42,11 @@ def parse_args():
type=str,
default='',
help='Saved optimized and quantized INT8 model')
parser.add_argument(
'--quantized_ops',
type=str,
default='',
help='A comma separated list of quantized operators.')
test_args, args = parser.parse_known_args(namespace=unittest)
return test_args, sys.argv[:1] + args
@@ -60,8 +65,9 @@ def transform_and_save_model(original_path, save_path, save_type):
fetch_targets] = fluid.io.load_inference_model(original_path, exe,
'model', 'params')
transform_to_mkldnn_int8_pass = FakeQAT2MkldnnINT8PerfPass(
_scope=inference_scope, _place=place, _core=core)
quantized_ops = set(test_args.quantized_ops.split(','))
transform_to_mkldnn_int8_pass = Qat2Int8MkldnnPass(
quantized_ops, _scope=inference_scope, _place=place, _core=core)
graph = IrGraph(core.Graph(inference_program.desc), for_test=True)
if save_type == 'FP32':
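
The standalone transform-and-save script follows the same pattern: load the QAT inference model, wrap it in an IrGraph, apply Qat2Int8MkldnnPass with the requested op set, and save the result. A hedged end-to-end sketch with placeholder paths and an illustrative op list; only the INT8 path is shown, and converting the graph back with graph.to_program() before fluid.io.save_inference_model is one plausible way to persist it, not necessarily the script's exact code:

import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import Qat2Int8MkldnnPass

# Placeholder paths and an illustrative op list; substitute real values.
original_path = 'path/to/qat_model'
save_path = 'path/to/int8_model'
quantized_ops = set('fc,reshape2,transpose2'.split(','))

place = fluid.CPUPlace()
exe = fluid.Executor(place)
inference_scope = fluid.global_scope()
with fluid.scope_guard(inference_scope):
    [inference_program, feed_target_names,
     fetch_targets] = fluid.io.load_inference_model(original_path, exe,
                                                    'model', 'params')
    graph = IrGraph(core.Graph(inference_program.desc), for_test=True)
    int8_pass = Qat2Int8MkldnnPass(
        quantized_ops, _scope=inference_scope, _place=place, _core=core)
    graph = int8_pass.apply(graph)
    # Convert the transformed graph back to a Program and save it.
    inference_program = graph.to_program()
    fluid.io.save_inference_model(save_path, feed_target_names, fetch_targets,
                                  exe, inference_program)
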

@@ -22,7 +22,7 @@ import paddle
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
from paddle.fluid.contrib.slim.quantization import FakeQAT2MkldnnINT8KernelPass
from paddle.fluid.contrib.slim.quantization import QatInt8MkldnnPass
from paddle.fluid import core
os.environ["CPU_NUM"] = "1"
@@ -149,8 +149,7 @@ class TestMKLDNNTransformBasedFreezePass(unittest.TestCase):
freeze_pass.apply(test_graph)
# Transform quantized graph for MKL-DNN INT8 inference
mkldnn_int8_pass = FakeQAT2MkldnnINT8KernelPass(
_scope=scope, _place=place)
mkldnn_int8_pass = QatInt8MkldnnPass(_scope=scope, _place=place)
mkldnn_int8_pass.apply(test_graph)
dev_name = '_cpu_'
if not for_ci:
