|
|
|
@ -531,7 +531,7 @@ def quant_dequant_residual_block(num, quant_skip_pattern=None):
|
|
|
|
|
short = conv_bn_layer(hidden, 16, 1, 1, 0, act=None)
|
|
|
|
|
hidden = fluid.layers.elementwise_add(x=conv, y=short, act='relu')
|
|
|
|
|
|
|
|
|
|
if quant_skip_pattern:
|
|
|
|
|
if isinstance(quant_skip_pattern, str):
|
|
|
|
|
with fluid.name_scope(quant_skip_pattern):
|
|
|
|
|
pool1 = fluid.layers.pool2d(
|
|
|
|
|
input=hidden, pool_size=2, pool_type='avg', pool_stride=2)
|
|
|
|
@ -539,6 +539,18 @@ def quant_dequant_residual_block(num, quant_skip_pattern=None):
|
|
|
|
|
input=hidden, pool_size=2, pool_type='max', pool_stride=2)
|
|
|
|
|
pool_add = fluid.layers.elementwise_add(
|
|
|
|
|
x=pool1, y=pool2, act='relu')
|
|
|
|
|
elif isinstance(quant_skip_pattern, list):
|
|
|
|
|
assert len(
|
|
|
|
|
quant_skip_pattern
|
|
|
|
|
) > 1, 'test config error: the len of quant_skip_pattern list should be greater than 1.'
|
|
|
|
|
with fluid.name_scope(quant_skip_pattern[0]):
|
|
|
|
|
pool1 = fluid.layers.pool2d(
|
|
|
|
|
input=hidden, pool_size=2, pool_type='avg', pool_stride=2)
|
|
|
|
|
pool2 = fluid.layers.pool2d(
|
|
|
|
|
input=hidden, pool_size=2, pool_type='max', pool_stride=2)
|
|
|
|
|
with fluid.name_scope(quant_skip_pattern[1]):
|
|
|
|
|
pool_add = fluid.layers.elementwise_add(
|
|
|
|
|
x=pool1, y=pool2, act='relu')
|
|
|
|
|
else:
|
|
|
|
|
pool1 = fluid.layers.pool2d(
|
|
|
|
|
input=hidden, pool_size=2, pool_type='avg', pool_stride=2)
|
|
|
|
@ -560,8 +572,15 @@ class TestAddQuantDequantPass(unittest.TestCase):
|
|
|
|
|
ops = graph.all_op_nodes()
|
|
|
|
|
for op_node in ops:
|
|
|
|
|
if op_node.name() in self._target_ops:
|
|
|
|
|
if skip_pattern and op_node.op().has_attr("op_namescope") and \
|
|
|
|
|
op_node.op().attr("op_namescope").find(skip_pattern) != -1:
|
|
|
|
|
user_skipped = False
|
|
|
|
|
if isinstance(skip_pattern, list):
|
|
|
|
|
user_skipped = op_node.op().has_attr("op_namescope") and \
|
|
|
|
|
any(pattern in op_node.op().attr("op_namescope") for pattern in skip_pattern)
|
|
|
|
|
elif isinstance(skip_pattern, str):
|
|
|
|
|
user_skipped = op_node.op().has_attr("op_namescope") and \
|
|
|
|
|
op_node.op().attr("op_namescope").find(skip_pattern) != -1
|
|
|
|
|
|
|
|
|
|
if user_skipped:
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
in_nodes_all_not_persistable = True
|
|
|
|
@ -587,7 +606,7 @@ class TestAddQuantDequantPass(unittest.TestCase):
|
|
|
|
|
place = fluid.CPUPlace()
|
|
|
|
|
graph = IrGraph(core.Graph(main.desc), for_test=False)
|
|
|
|
|
add_quant_dequant_pass = AddQuantDequantPass(
|
|
|
|
|
scope=fluid.global_scope(), place=place)
|
|
|
|
|
scope=fluid.global_scope(), place=place, skip_pattern=skip_pattern)
|
|
|
|
|
add_quant_dequant_pass.apply(graph)
|
|
|
|
|
if not for_ci:
|
|
|
|
|
marked_nodes = set()
|
|
|
|
@ -611,6 +630,10 @@ class TestAddQuantDequantPass(unittest.TestCase):
|
|
|
|
|
def test_residual_block_skip_pattern(self):
|
|
|
|
|
self.residual_block_quant(skip_pattern='skip_quant', for_ci=True)
|
|
|
|
|
|
|
|
|
|
def test_residual_block_skip_pattern(self):
|
|
|
|
|
self.residual_block_quant(
|
|
|
|
|
skip_pattern=['skip_quant1', 'skip_quant2'], for_ci=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
unittest.main()
|
|
|
|
|