|
|
@ -68,6 +68,12 @@ def parse_args():
|
|
|
|
type=str,
|
|
|
|
type=str,
|
|
|
|
default='',
|
|
|
|
default='',
|
|
|
|
help='A comma separated list of operator ids to skip in quantization.')
|
|
|
|
help='A comma separated list of operator ids to skip in quantization.')
|
|
|
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
|
|
|
'--targets',
|
|
|
|
|
|
|
|
type=str,
|
|
|
|
|
|
|
|
default='quant,int8,fp32',
|
|
|
|
|
|
|
|
help='A comma separated list of inference types to run ("int8", "fp32", "quant"). Default: "quant,int8,fp32"'
|
|
|
|
|
|
|
|
)
|
|
|
|
parser.add_argument(
|
|
|
|
parser.add_argument(
|
|
|
|
'--debug',
|
|
|
|
'--debug',
|
|
|
|
action='store_true',
|
|
|
|
action='store_true',
|
|
|
@ -310,6 +316,12 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
|
|
|
|
assert int8_acc1 > 0.5
|
|
|
|
assert int8_acc1 > 0.5
|
|
|
|
assert quant_acc1 - int8_acc1 <= threshold
|
|
|
|
assert quant_acc1 - int8_acc1 <= threshold
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _strings_from_csv(self, string):
|
|
|
|
|
|
|
|
return set(s.strip() for s in string.split(','))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _ints_from_csv(self, string):
|
|
|
|
|
|
|
|
return set(map(int, string.split(',')))
|
|
|
|
|
|
|
|
|
|
|
|
def test_graph_transformation(self):
|
|
|
|
def test_graph_transformation(self):
|
|
|
|
if not fluid.core.is_compiled_with_mkldnn():
|
|
|
|
if not fluid.core.is_compiled_with_mkldnn():
|
|
|
|
return
|
|
|
|
return
|
|
|
@ -326,14 +338,19 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
|
|
|
|
self._debug = test_case_args.debug
|
|
|
|
self._debug = test_case_args.debug
|
|
|
|
|
|
|
|
|
|
|
|
self._quantized_ops = set()
|
|
|
|
self._quantized_ops = set()
|
|
|
|
if len(test_case_args.ops_to_quantize) > 0:
|
|
|
|
if test_case_args.ops_to_quantize:
|
|
|
|
self._quantized_ops = set(
|
|
|
|
self._quantized_ops = self._strings_from_csv(
|
|
|
|
op.strip() for op in test_case_args.ops_to_quantize.split(','))
|
|
|
|
test_case_args.ops_to_quantize)
|
|
|
|
|
|
|
|
|
|
|
|
self._op_ids_to_skip = set([-1])
|
|
|
|
self._op_ids_to_skip = set([-1])
|
|
|
|
if len(test_case_args.op_ids_to_skip) > 0:
|
|
|
|
if test_case_args.op_ids_to_skip:
|
|
|
|
self._op_ids_to_skip = set(
|
|
|
|
self._op_ids_to_skip = self._ints_from_csv(
|
|
|
|
map(int, test_case_args.op_ids_to_skip.split(',')))
|
|
|
|
test_case_args.op_ids_to_skip)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self._targets = self._strings_from_csv(test_case_args.targets)
|
|
|
|
|
|
|
|
assert self._targets.intersection(
|
|
|
|
|
|
|
|
{'quant', 'int8', 'fp32'}
|
|
|
|
|
|
|
|
), 'The --targets option, if used, must contain at least one of the targets: "quant", "int8", "fp32".'
|
|
|
|
|
|
|
|
|
|
|
|
_logger.info('Quant & INT8 prediction run.')
|
|
|
|
_logger.info('Quant & INT8 prediction run.')
|
|
|
|
_logger.info('Quant model: {}'.format(quant_model_path))
|
|
|
|
_logger.info('Quant model: {}'.format(quant_model_path))
|
|
|
@ -348,35 +365,38 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
|
|
|
|
_logger.info('Op ids to skip quantization: {}.'.format(','.join(
|
|
|
|
_logger.info('Op ids to skip quantization: {}.'.format(','.join(
|
|
|
|
map(str, self._op_ids_to_skip)) if test_case_args.op_ids_to_skip
|
|
|
|
map(str, self._op_ids_to_skip)) if test_case_args.op_ids_to_skip
|
|
|
|
else 'none'))
|
|
|
|
else 'none'))
|
|
|
|
|
|
|
|
_logger.info('Targets: {}.'.format(','.join(self._targets)))
|
|
|
|
|
|
|
|
|
|
|
|
_logger.info('--- Quant prediction start ---')
|
|
|
|
if 'quant' in self._targets:
|
|
|
|
val_reader = paddle.batch(
|
|
|
|
_logger.info('--- Quant prediction start ---')
|
|
|
|
self._reader_creator(data_path), batch_size=batch_size)
|
|
|
|
val_reader = paddle.batch(
|
|
|
|
quant_output, quant_acc1, quant_acc5, quant_fps, quant_lat = self._predict(
|
|
|
|
self._reader_creator(data_path), batch_size=batch_size)
|
|
|
|
val_reader,
|
|
|
|
quant_output, quant_acc1, quant_acc5, quant_fps, quant_lat = self._predict(
|
|
|
|
quant_model_path,
|
|
|
|
val_reader,
|
|
|
|
batch_size,
|
|
|
|
quant_model_path,
|
|
|
|
batch_num,
|
|
|
|
batch_size,
|
|
|
|
skip_batch_num,
|
|
|
|
batch_num,
|
|
|
|
target='quant')
|
|
|
|
skip_batch_num,
|
|
|
|
self._print_performance('Quant', quant_fps, quant_lat)
|
|
|
|
target='quant')
|
|
|
|
self._print_accuracy('Quant', quant_acc1, quant_acc5)
|
|
|
|
self._print_performance('Quant', quant_fps, quant_lat)
|
|
|
|
|
|
|
|
self._print_accuracy('Quant', quant_acc1, quant_acc5)
|
|
|
|
|
|
|
|
|
|
|
|
_logger.info('--- INT8 prediction start ---')
|
|
|
|
if 'int8' in self._targets:
|
|
|
|
val_reader = paddle.batch(
|
|
|
|
_logger.info('--- INT8 prediction start ---')
|
|
|
|
self._reader_creator(data_path), batch_size=batch_size)
|
|
|
|
val_reader = paddle.batch(
|
|
|
|
int8_output, int8_acc1, int8_acc5, int8_fps, int8_lat = self._predict(
|
|
|
|
self._reader_creator(data_path), batch_size=batch_size)
|
|
|
|
val_reader,
|
|
|
|
int8_output, int8_acc1, int8_acc5, int8_fps, int8_lat = self._predict(
|
|
|
|
quant_model_path,
|
|
|
|
val_reader,
|
|
|
|
batch_size,
|
|
|
|
quant_model_path,
|
|
|
|
batch_num,
|
|
|
|
batch_size,
|
|
|
|
skip_batch_num,
|
|
|
|
batch_num,
|
|
|
|
target='int8')
|
|
|
|
skip_batch_num,
|
|
|
|
self._print_performance('INT8', int8_fps, int8_lat)
|
|
|
|
target='int8')
|
|
|
|
self._print_accuracy('INT8', int8_acc1, int8_acc5)
|
|
|
|
self._print_performance('INT8', int8_fps, int8_lat)
|
|
|
|
|
|
|
|
self._print_accuracy('INT8', int8_acc1, int8_acc5)
|
|
|
|
|
|
|
|
|
|
|
|
fp32_acc1 = fp32_acc5 = fp32_fps = fp32_lat = -1
|
|
|
|
fp32_acc1 = fp32_acc5 = fp32_fps = fp32_lat = -1
|
|
|
|
if fp32_model_path:
|
|
|
|
if 'fp32' in self._targets and fp32_model_path:
|
|
|
|
_logger.info('--- FP32 prediction start ---')
|
|
|
|
_logger.info('--- FP32 prediction start ---')
|
|
|
|
val_reader = paddle.batch(
|
|
|
|
val_reader = paddle.batch(
|
|
|
|
self._reader_creator(data_path), batch_size=batch_size)
|
|
|
|
self._reader_creator(data_path), batch_size=batch_size)
|
|
|
@ -390,10 +410,12 @@ class Quant2Int8ImageClassificationComparisonTest(unittest.TestCase):
|
|
|
|
self._print_performance('FP32', fp32_fps, fp32_lat)
|
|
|
|
self._print_performance('FP32', fp32_fps, fp32_lat)
|
|
|
|
self._print_accuracy('FP32', fp32_acc1, fp32_acc5)
|
|
|
|
self._print_accuracy('FP32', fp32_acc1, fp32_acc5)
|
|
|
|
|
|
|
|
|
|
|
|
self._summarize_performance(int8_fps, int8_lat, fp32_fps, fp32_lat)
|
|
|
|
if {'int8', 'fp32'}.issubset(self._targets):
|
|
|
|
self._summarize_accuracy(quant_acc1, quant_acc5, int8_acc1, int8_acc5,
|
|
|
|
self._summarize_performance(int8_fps, int8_lat, fp32_fps, fp32_lat)
|
|
|
|
fp32_acc1, fp32_acc5)
|
|
|
|
if {'int8', 'quant'}.issubset(self._targets):
|
|
|
|
self._compare_accuracy(acc_diff_threshold, quant_acc1, int8_acc1)
|
|
|
|
self._summarize_accuracy(quant_acc1, quant_acc5, int8_acc1,
|
|
|
|
|
|
|
|
int8_acc5, fp32_acc1, fp32_acc5)
|
|
|
|
|
|
|
|
self._compare_accuracy(acc_diff_threshold, quant_acc1, int8_acc1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
if __name__ == '__main__':
|
|
|
|