|
|
@ -48,9 +48,11 @@ def parse_args():
|
|
|
|
parser.add_argument(
|
|
|
|
parser.add_argument(
|
|
|
|
'--qat_model', type=str, default='', help='A path to a QAT model.')
|
|
|
|
'--qat_model', type=str, default='', help='A path to a QAT model.')
|
|
|
|
parser.add_argument(
|
|
|
|
parser.add_argument(
|
|
|
|
'--save_model',
|
|
|
|
'--fp32_model',
|
|
|
|
action='store_true',
|
|
|
|
type=str,
|
|
|
|
help='If used, the QAT model will be saved after all transformations')
|
|
|
|
default='',
|
|
|
|
|
|
|
|
help='A path to an FP32 model. If empty, the QAT model will be used for FP32 inference.'
|
|
|
|
|
|
|
|
)
|
|
|
|
parser.add_argument('--infer_data', type=str, default='', help='Data file.')
|
|
|
|
parser.add_argument('--infer_data', type=str, default='', help='Data file.')
|
|
|
|
parser.add_argument(
|
|
|
|
parser.add_argument(
|
|
|
|
'--labels', type=str, default='', help='File with labels.')
|
|
|
|
'--labels', type=str, default='', help='File with labels.')
|
|
|
@ -240,7 +242,10 @@ class QatInt8NLPComparisonTest(unittest.TestCase):
|
|
|
|
return
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
qat_model_path = test_case_args.qat_model
|
|
|
|
qat_model_path = test_case_args.qat_model
|
|
|
|
|
|
|
|
assert qat_model_path, 'The QAT model path cannot be empty. Please, use the --qat_model option.'
|
|
|
|
|
|
|
|
fp32_model_path = test_case_args.fp32_model if test_case_args.fp32_model else qat_model_path
|
|
|
|
data_path = test_case_args.infer_data
|
|
|
|
data_path = test_case_args.infer_data
|
|
|
|
|
|
|
|
assert data_path, 'The dataset path cannot be empty. Please, use the --infer_data option.'
|
|
|
|
labels_path = test_case_args.labels
|
|
|
|
labels_path = test_case_args.labels
|
|
|
|
batch_size = test_case_args.batch_size
|
|
|
|
batch_size = test_case_args.batch_size
|
|
|
|
batch_num = test_case_args.batch_num
|
|
|
|
batch_num = test_case_args.batch_num
|
|
|
@ -251,6 +256,7 @@ class QatInt8NLPComparisonTest(unittest.TestCase):
|
|
|
|
|
|
|
|
|
|
|
|
_logger.info('QAT FP32 & INT8 prediction run.')
|
|
|
|
_logger.info('QAT FP32 & INT8 prediction run.')
|
|
|
|
_logger.info('QAT model: {0}'.format(qat_model_path))
|
|
|
|
_logger.info('QAT model: {0}'.format(qat_model_path))
|
|
|
|
|
|
|
|
_logger.info('FP32 model: {0}'.format(fp32_model_path))
|
|
|
|
_logger.info('Dataset: {0}'.format(data_path))
|
|
|
|
_logger.info('Dataset: {0}'.format(data_path))
|
|
|
|
_logger.info('Labels: {0}'.format(labels_path))
|
|
|
|
_logger.info('Labels: {0}'.format(labels_path))
|
|
|
|
_logger.info('Batch size: {0}'.format(batch_size))
|
|
|
|
_logger.info('Batch size: {0}'.format(batch_size))
|
|
|
@ -263,11 +269,12 @@ class QatInt8NLPComparisonTest(unittest.TestCase):
|
|
|
|
self._reader_creator(data_path, labels_path), batch_size=batch_size)
|
|
|
|
self._reader_creator(data_path, labels_path), batch_size=batch_size)
|
|
|
|
fp32_acc, fp32_pps, fp32_lat = self._predict(
|
|
|
|
fp32_acc, fp32_pps, fp32_lat = self._predict(
|
|
|
|
val_reader,
|
|
|
|
val_reader,
|
|
|
|
qat_model_path,
|
|
|
|
fp32_model_path,
|
|
|
|
batch_size,
|
|
|
|
batch_size,
|
|
|
|
batch_num,
|
|
|
|
batch_num,
|
|
|
|
skip_batch_num,
|
|
|
|
skip_batch_num,
|
|
|
|
transform_to_int8=False)
|
|
|
|
transform_to_int8=False)
|
|
|
|
|
|
|
|
_logger.info('FP32: avg accuracy: {0:.6f}'.format(fp32_acc))
|
|
|
|
_logger.info('--- QAT INT8 prediction start ---')
|
|
|
|
_logger.info('--- QAT INT8 prediction start ---')
|
|
|
|
val_reader = paddle.batch(
|
|
|
|
val_reader = paddle.batch(
|
|
|
|
self._reader_creator(data_path, labels_path), batch_size=batch_size)
|
|
|
|
self._reader_creator(data_path, labels_path), batch_size=batch_size)
|
|
|
@ -278,6 +285,7 @@ class QatInt8NLPComparisonTest(unittest.TestCase):
|
|
|
|
batch_num,
|
|
|
|
batch_num,
|
|
|
|
skip_batch_num,
|
|
|
|
skip_batch_num,
|
|
|
|
transform_to_int8=True)
|
|
|
|
transform_to_int8=True)
|
|
|
|
|
|
|
|
_logger.info('INT8: avg accuracy: {0:.6f}'.format(int8_acc))
|
|
|
|
|
|
|
|
|
|
|
|
self._summarize_performance(fp32_pps, fp32_lat, int8_pps, int8_lat)
|
|
|
|
self._summarize_performance(fp32_pps, fp32_lat, int8_pps, int8_lat)
|
|
|
|
self._compare_accuracy(fp32_acc, int8_acc, acc_diff_threshold)
|
|
|
|
self._compare_accuracy(fp32_acc, int8_acc, acc_diff_threshold)
|
|
|
|