@@ -123,7 +123,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
                             arg_name.endswith('.quantized.dequantized'))
                         self.assertTrue(arg_name in quantized_ops)
 
-    def linear_fc_quant(self, quant_type, enable_ce=False):
+    def linear_fc_quant(self, quant_type, for_ci=False):
         main = fluid.Program()
         startup = fluid.Program()
         with fluid.program_guard(main, startup):
@@ -138,7 +138,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
             place=place,
             activation_quantize_type=quant_type)
         transform_pass.apply(graph)
-        if not enable_ce:
+        if not for_ci:
             marked_nodes = set()
             for op in graph.all_op_nodes():
                 if op.name().find('quantize') > -1:
@@ -147,7 +147,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
         program = graph.to_program()
         self.check_program(transform_pass, program)
         val_graph = IrGraph(core.Graph(program.desc), for_test=False)
-        if not enable_ce:
+        if not for_ci:
             val_marked_nodes = set()
             for op in val_graph.all_op_nodes():
                 if op.name().find('quantize') > -1:
@@ -155,12 +155,12 @@ class TestQuantizationTransformPass(unittest.TestCase):
             val_graph.draw('.', 'val_fc_' + quant_type, val_marked_nodes)
 
     def test_linear_fc_quant_abs_max(self):
-        self.linear_fc_quant('abs_max', enable_ce=True)
+        self.linear_fc_quant('abs_max', for_ci=True)
 
     def test_linear_fc_quant_range_abs_max(self):
-        self.linear_fc_quant('range_abs_max', enable_ce=True)
+        self.linear_fc_quant('range_abs_max', for_ci=True)
 
-    def residual_block_quant(self, quant_type, enable_ce=False):
+    def residual_block_quant(self, quant_type, for_ci=False):
         main = fluid.Program()
         startup = fluid.Program()
         with fluid.program_guard(main, startup):
@@ -175,7 +175,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
             place=place,
             activation_quantize_type=quant_type)
         transform_pass.apply(graph)
-        if not enable_ce:
+        if not for_ci:
             marked_nodes = set()
             for op in graph.all_op_nodes():
                 if op.name().find('quantize') > -1:
@@ -184,7 +184,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
         program = graph.to_program()
         self.check_program(transform_pass, program)
         val_graph = IrGraph(core.Graph(program.desc), for_test=False)
-        if not enable_ce:
+        if not for_ci:
             val_marked_nodes = set()
             for op in val_graph.all_op_nodes():
                 if op.name().find('quantize') > -1:
@@ -192,14 +192,14 @@ class TestQuantizationTransformPass(unittest.TestCase):
             val_graph.draw('.', 'val_residual_' + quant_type, val_marked_nodes)
 
     def test_residual_block_abs_max(self):
-        self.residual_block_quant('abs_max', enable_ce=True)
+        self.residual_block_quant('abs_max', for_ci=True)
 
     def test_residual_block_range_abs_max(self):
-        self.residual_block_quant('range_abs_max', enable_ce=True)
+        self.residual_block_quant('range_abs_max', for_ci=True)
 
 
 class TestQuantizationFreezePass(unittest.TestCase):
-    def freeze_graph(self, use_cuda, seed, quant_type, enable_ce=False):
+    def freeze_graph(self, use_cuda, seed, quant_type, for_ci=False):
         def build_program(main, startup, is_test):
             main.random_seed = seed
             startup.random_seed = seed
@@ -237,7 +237,7 @@ class TestQuantizationFreezePass(unittest.TestCase):
         transform_pass.apply(main_graph)
         transform_pass.apply(test_graph)
         dev_name = '_gpu_' if use_cuda else '_cpu_'
-        if not enable_ce:
+        if not for_ci:
             marked_nodes = set()
             for op in main_graph.all_op_nodes():
                 if op.name().find('quantize') > -1:
@@ -267,7 +267,7 @@ class TestQuantizationFreezePass(unittest.TestCase):
                 loss_v = exe.run(program=quantized_main_program,
                                  feed=feeder.feed(data),
                                  fetch_list=[loss])
-                if not enable_ce:
+                if not for_ci:
                     print('{}: {}'.format('loss' + dev_name + quant_type,
                                           loss_v))
 
@@ -284,7 +284,7 @@ class TestQuantizationFreezePass(unittest.TestCase):
         # Freeze graph for inference, but the weight of fc/conv is still float type.
         freeze_pass = QuantizationFreezePass(scope=scope, place=place)
         freeze_pass.apply(test_graph)
-        if not enable_ce:
+        if not for_ci:
             marked_nodes = set()
             for op in test_graph.all_op_nodes():
                 if op.name().find('quantize') > -1:
@@ -298,7 +298,7 @@ class TestQuantizationFreezePass(unittest.TestCase):
                                   feed=feeder.feed(test_data),
                                   fetch_list=[loss])
         self.assertAlmostEqual(test_loss1, test_loss2, delta=5e-3)
-        if not enable_ce:
+        if not for_ci:
             print('{}: {}'.format('test_loss1' + dev_name + quant_type,
                                   test_loss1))
             print('{}: {}'.format('test_loss2' + dev_name + quant_type,
@@ -306,7 +306,7 @@ class TestQuantizationFreezePass(unittest.TestCase):
         w_freeze = np.array(scope.find_var('conv2d_1.w_0').get_tensor())
         # Maybe failed, this is due to the calculation precision
         # self.assertAlmostEqual(np.sum(w_freeze), np.sum(w_quant))
-        if not enable_ce:
+        if not for_ci:
             print('{}: {}'.format('w_freeze' + dev_name + quant_type,
                                   np.sum(w_freeze)))
             print('{}: {}'.format('w_quant' + dev_name + quant_type,
@@ -315,7 +315,7 @@ class TestQuantizationFreezePass(unittest.TestCase):
         # Convert parameter to 8-bit.
         convert_int8_pass = ConvertToInt8Pass(scope=scope, place=place)
         convert_int8_pass.apply(test_graph)
-        if not enable_ce:
+        if not for_ci:
             marked_nodes = set()
             for op in test_graph.all_op_nodes():
                 if op.name().find('quantize') > -1:
@@ -335,7 +335,7 @@ class TestQuantizationFreezePass(unittest.TestCase):
         w_8bit = np.array(scope.find_var('conv2d_1.w_0.int8').get_tensor())
         self.assertEqual(w_8bit.dtype, np.int8)
         self.assertEqual(np.sum(w_8bit), np.sum(w_freeze))
-        if not enable_ce:
+        if not for_ci:
             print('{}: {}'.format('w_8bit' + dev_name + quant_type,
                                   np.sum(w_8bit)))
             print('{}: {}'.format('w_freeze' + dev_name + quant_type,
@@ -343,7 +343,7 @@ class TestQuantizationFreezePass(unittest.TestCase):
 
         mobile_pass = TransformForMobilePass()
         mobile_pass.apply(test_graph)
-        if not enable_ce:
+        if not for_ci:
             marked_nodes = set()
             for op in test_graph.all_op_nodes():
                 if op.name().find('quantize') > -1:
@@ -361,23 +361,22 @@ class TestQuantizationFreezePass(unittest.TestCase):
         if fluid.core.is_compiled_with_cuda():
             with fluid.unique_name.guard():
                 self.freeze_graph(
-                    True, seed=1, quant_type='abs_max', enable_ce=True)
+                    True, seed=1, quant_type='abs_max', for_ci=True)
 
     def test_freeze_graph_cpu_dynamic(self):
         with fluid.unique_name.guard():
-            self.freeze_graph(
-                False, seed=2, quant_type='abs_max', enable_ce=True)
+            self.freeze_graph(False, seed=2, quant_type='abs_max', for_ci=True)
 
     def test_freeze_graph_cuda_static(self):
         if fluid.core.is_compiled_with_cuda():
             with fluid.unique_name.guard():
                 self.freeze_graph(
-                    True, seed=1, quant_type='range_abs_max', enable_ce=True)
+                    True, seed=1, quant_type='range_abs_max', for_ci=True)
 
     def test_freeze_graph_cpu_static(self):
         with fluid.unique_name.guard():
             self.freeze_graph(
-                False, seed=2, quant_type='range_abs_max', enable_ce=True)
+                False, seed=2, quant_type='range_abs_max', for_ci=True)
 
 
 if __name__ == '__main__':