|
|
|
@ -24,8 +24,8 @@ import time
|
|
|
|
|
import paddle
|
|
|
|
|
import paddle.fluid as fluid
|
|
|
|
|
from paddle.fluid.framework import IrGraph
|
|
|
|
|
from paddle.fluid.contrib.slim.quantization import FakeQAT2MkldnnINT8KernelPass
|
|
|
|
|
from paddle.fluid.contrib.slim.quantization import FakeQAT2MkldnnINT8PerfPass
|
|
|
|
|
from paddle.fluid.contrib.slim.quantization import QatInt8MkldnnPass
|
|
|
|
|
from paddle.fluid.contrib.slim.quantization import Qat2Int8MkldnnPass
|
|
|
|
|
from paddle.fluid import core
|
|
|
|
|
|
|
|
|
|
logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
|
|
|
|
@ -53,10 +53,6 @@ def parse_args():
|
|
|
|
|
action='store_true',
|
|
|
|
|
help='If used, the QAT model is treated as a second generation model for performance optimization.'
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
'--save_model',
|
|
|
|
|
action='store_true',
|
|
|
|
|
help='If used, the QAT model will be saved after all transformations')
|
|
|
|
|
parser.add_argument('--infer_data', type=str, default='', help='Data file.')
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
'--batch_num',
|
|
|
|
@ -68,15 +64,20 @@ def parse_args():
|
|
|
|
|
type=float,
|
|
|
|
|
default=0.01,
|
|
|
|
|
help='Accepted accuracy difference threshold.')
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
'--quantized_ops',
|
|
|
|
|
type=str,
|
|
|
|
|
default='',
|
|
|
|
|
help='A comma separated list of quantized operators.')
|
|
|
|
|
|
|
|
|
|
test_args, args = parser.parse_known_args(namespace=unittest)
|
|
|
|
|
|
|
|
|
|
return test_args, sys.argv[:1] + args
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestQatInt8Comparison(unittest.TestCase):
|
|
|
|
|
class QatInt8ImageClassificationComparisonTest(unittest.TestCase):
|
|
|
|
|
"""
|
|
|
|
|
Test for accuracy comparison of QAT FP32 and INT8 inference.
|
|
|
|
|
Test for accuracy comparison of QAT FP32 and INT8 Image Classification inference.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def _reader_creator(self, data_file='data.bin'):
|
|
|
|
@ -182,14 +183,15 @@ class TestQatInt8Comparison(unittest.TestCase):
|
|
|
|
|
graph.draw('.', 'qat_orig', graph.all_op_nodes())
|
|
|
|
|
if (transform_to_int8):
|
|
|
|
|
if (test_case_args.qat2):
|
|
|
|
|
transform_to_mkldnn_int8_pass = FakeQAT2MkldnnINT8PerfPass(
|
|
|
|
|
transform_to_mkldnn_int8_pass = Qat2Int8MkldnnPass(
|
|
|
|
|
self._quantized_ops,
|
|
|
|
|
_scope=inference_scope,
|
|
|
|
|
_place=place,
|
|
|
|
|
_core=core,
|
|
|
|
|
_debug=self._debug)
|
|
|
|
|
graph = transform_to_mkldnn_int8_pass.apply(graph)
|
|
|
|
|
else:
|
|
|
|
|
mkldnn_int8_pass = FakeQAT2MkldnnINT8KernelPass(
|
|
|
|
|
mkldnn_int8_pass = QatInt8MkldnnPass(
|
|
|
|
|
_scope=inference_scope, _place=place)
|
|
|
|
|
graph = mkldnn_int8_pass.apply(graph)
|
|
|
|
|
|
|
|
|
@ -256,12 +258,6 @@ class TestQatInt8Comparison(unittest.TestCase):
|
|
|
|
|
_logger.info('Total inference run time: {:.2f} s'.format(
|
|
|
|
|
infer_total_time))
|
|
|
|
|
|
|
|
|
|
if test_case_args.save_model:
|
|
|
|
|
with fluid.scope_guard(inference_scope):
|
|
|
|
|
fluid.io.save_inference_model(
|
|
|
|
|
'transformed_qat_int8_model', feed_target_names,
|
|
|
|
|
fetch_targets, exe, inference_program)
|
|
|
|
|
|
|
|
|
|
return outputs, acc1_avg, acc5_avg, fps_avg, latency_avg
|
|
|
|
|
|
|
|
|
|
def _summarize_performance(self, fp32_fps, fp32_lat, int8_fps, int8_lat):
|
|
|
|
@ -298,6 +294,7 @@ class TestQatInt8Comparison(unittest.TestCase):
|
|
|
|
|
skip_batch_num = test_case_args.skip_batch_num
|
|
|
|
|
acc_diff_threshold = test_case_args.acc_diff_threshold
|
|
|
|
|
self._debug = test_case_args.debug
|
|
|
|
|
self._quantized_ops = set(test_case_args.quantized_ops.split(','))
|
|
|
|
|
|
|
|
|
|
_logger.info('QAT FP32 & INT8 prediction run.')
|
|
|
|
|
_logger.info('QAT model: {0}'.format(qat_model_path))
|
|
|
|
@ -305,6 +302,7 @@ class TestQatInt8Comparison(unittest.TestCase):
|
|
|
|
|
_logger.info('Batch size: {0}'.format(batch_size))
|
|
|
|
|
_logger.info('Batch number: {0}'.format(batch_num))
|
|
|
|
|
_logger.info('Accuracy drop threshold: {0}.'.format(acc_diff_threshold))
|
|
|
|
|
_logger.info('Quantized ops: {0}.'.format(self._quantized_ops))
|
|
|
|
|
|
|
|
|
|
_logger.info('--- QAT FP32 prediction start ---')
|
|
|
|
|
val_reader = paddle.batch(
|