Adding AddQuantDequantPass for TensorRT int8 (#17529)

* add quant_dequant_pass, test=develop

* Add quant_dequant before some ops, such as the elementwise_add op. This is required by TensorRT. test=develop
Zhen Wang committed 3398f99608 (parent f9796b1249)
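For context before the diff, here is a minimal usage sketch of the new pass, mirroring the unit-test changes further below. The variable names are illustrative; `main` is assumed to be a fluid `Program` with the network already built inside it.

# Minimal usage sketch (illustrative names), based on the test changes below.
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass

main = fluid.Program()
# ... build the network inside `main` ...
main_graph = IrGraph(core.Graph(main.desc), for_test=False)

scope = fluid.Scope()
place = fluid.CPUPlace()  # or fluid.CUDAPlace(0)
add_quant_dequant_pass = AddQuantDequantPass(scope=scope, place=place)
# Inserts fake quant/dequant ops before elementwise_add and pool2d inputs.
add_quant_dequant_pass.apply(main_graph)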

@@ -22,7 +22,8 @@ from .... import unique_name
 __all__ = [
     'QuantizationTransformPass', 'QuantizationFreezePass', 'ConvertToInt8Pass',
-    'TransformForMobilePass', 'ScaleForTrainingPass', 'ScaleForInferencePass'
+    'TransformForMobilePass', 'ScaleForTrainingPass', 'ScaleForInferencePass',
+    'AddQuantDequantPass'
 ]
@@ -994,6 +995,8 @@ class ScaleForTrainingPass(object):
         Args:
             graph(IrGraph): the target graph.
         """
+        assert isinstance(graph,
+                          IrGraph), 'graph must be an instance of IrGraph.'
         self._is_test = graph.is_test()
         ops = graph.all_op_nodes()
         for op_node in ops:
@@ -1099,6 +1102,8 @@ class ScaleForInferencePass(object):
         Args:
             graph(IrGraph): the target graph.
         """
+        assert isinstance(graph,
+                          IrGraph), 'graph must be an instance of IrGraph.'
         ops = graph.all_op_nodes()
         for op_node in ops:
             name = op_node.name()
@@ -1117,3 +1122,137 @@ class ScaleForInferencePass(object):
         Return the scale name for the var named `var_name`.
         """
         return "%s@scale" % (var_name)
+
+
+class AddQuantDequantPass(object):
+    def __init__(self, scope=None, place=None, moving_rate=0.9, quant_bits=8):
+        """
+        This pass adds a quant_dequant op before some ops, such as the
+        `elementwise_add` op.
+        """
+        self._scope = scope
+        self._place = place
+        self._moving_rate = moving_rate
+        self._quant_bits = quant_bits
+        self._is_test = None
+        self._target_ops = ["elementwise_add", "pool2d"]
+
+    def apply(self, graph):
+        """
+        Add quant_dequant before some ops, such as the `elementwise_add` op.
+        This is required by TensorRT.
+        Args:
+            graph(IrGraph): the target graph.
+        """
+        assert isinstance(graph,
+                          IrGraph), 'graph must be an instance of IrGraph.'
+        self._is_test = graph.is_test()
+        ops = graph.all_op_nodes()
+        for op_node in ops:
+            name = op_node.name()
+            if name in self._target_ops:
+                # Only quantize ops whose inputs are all non-persistable,
+                # i.e. activations rather than parameters.
+                in_nodes_all_not_persistable = True
+                for input_name in op_node.input_arg_names():
+                    in_node = graph._find_node_by_name(op_node.inputs,
+                                                       input_name)
+                    in_nodes_all_not_persistable = (
+                        in_nodes_all_not_persistable and
+                        not in_node.persistable())
+                if not in_nodes_all_not_persistable:
+                    continue
+                input_names = op_node.input_arg_names()
+                for input_name in input_names:
+                    in_node = graph._find_node_by_name(op_node.inputs,
+                                                       input_name)
+                    quant_var_node, scale_var_node = self._insert_quant_dequant_moving_average_abs_max_op(
+                        graph, in_node, self._quant_bits)
+                    graph.update_input_link(in_node, quant_var_node, op_node)
+        graph.resolve_hazard()
+        return graph
+
+    def _insert_quant_dequant_moving_average_abs_max_op(self, graph, var_node,
+                                                        quant_bits):
+        """Insert a fake_quantize_dequantize_moving_average_abs_max op.
+        """
+        quant_var_node = graph.create_var_node(
+            name="{}.quant_dequant".format(var_node.name()),
+            var_type=var_node.type(),
+            shape=var_node.shape(),
+            var_dtype=var_node.dtype())
+        scale_in_node = graph.create_persistable_node(
+            name="{}.quant_dequant.scale".format(var_node.name()),
+            var_type=core.VarDesc.VarType.LOD_TENSOR,
+            shape=[1],
+            var_dtype=var_node.dtype())
+        data_type = 'float64' if var_node.dtype(
+        ) == core.VarDesc.VarType.FP64 else 'float32'
+        _init_var_node(
+            scale_in_node,
+            np.array(
+                [0.001], dtype=data_type),
+            self._scope,
+            self._place)
+        scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
+        ins = {'X': var_node, 'InScale': scale_in_node}
+        outs = {'Out': quant_var_node, 'OutScale': scale_out_node}
+        if not self._is_test:
+            # In training, the op also maintains moving-average state
+            # (`state`, `accum`) used to update the scale.
+            state_in_node = graph.create_persistable_node(
+                name=unique_name.generate('quant_dequant.state'),
+                var_type=core.VarDesc.VarType.LOD_TENSOR,
+                var_dtype=var_node.dtype(),
+                shape=[1])
+            data_type = 'float64' if var_node.dtype(
+            ) == core.VarDesc.VarType.FP64 else 'float32'
+            _init_var_node(
+                state_in_node,
+                np.ones(
+                    [1], dtype=data_type),
+                self._scope,
+                self._place)
+            accum_in_node = graph.create_persistable_node(
+                name=unique_name.generate('quant_dequant.accum'),
+                var_type=core.VarDesc.VarType.LOD_TENSOR,
+                var_dtype=var_node.dtype(),
+                shape=[1])
+            _init_var_node(
+                accum_in_node,
+                np.ones(
+                    [1], dtype=data_type),
+                self._scope,
+                self._place)
+            state_out_node = graph.create_var_node_from_desc(
+                state_in_node.var())
+            accum_out_node = graph.create_var_node_from_desc(
+                accum_in_node.var())
+
+            ins['InState'] = state_in_node
+            ins['InAccum'] = accum_in_node
+            outs['OutState'] = state_out_node
+            outs['OutAccum'] = accum_out_node
+
+        attrs = {
+            'bit_length': quant_bits,
+            'moving_rate': self._moving_rate,
+            'is_test': self._is_test,
+            'op_role': core.op_proto_and_checker_maker.OpRole.Forward
+        }
+        quant_op_node = graph.create_op_node(
+            op_type='fake_quantize_dequantize_moving_average_abs_max',
+            attrs=attrs,
+            inputs=ins,
+            outputs=outs)
+
+        graph.link_to(var_node, quant_op_node)
+        graph.link_to(scale_in_node, quant_op_node)
+        graph.link_to(quant_op_node, quant_var_node)
+        graph.link_to(quant_op_node, scale_out_node)
+
+        if not self._is_test:
+            graph.link_to(state_in_node, quant_op_node)
+            graph.link_to(accum_in_node, quant_op_node)
+            graph.link_to(quant_op_node, state_out_node)
+            graph.link_to(quant_op_node, accum_out_node)
+
+        return quant_var_node, scale_out_node
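The inserted fake_quantize_dequantize_moving_average_abs_max op quantizes an activation to `quant_bits` bits and immediately dequantizes it, so downstream ops (and training) see the quantization error while tensors stay in float. Below is a NumPy sketch of the arithmetic, assuming the standard moving-average abs-max formulation (a hedged illustration, not taken verbatim from the op's kernel); note the initial values match the code above: scale starts at 0.001, state and accum at 1.

import numpy as np

def update_scale(x, state, accum, moving_rate=0.9):
    # Moving-average abs-max: the scale tracks a decayed running
    # average of max(|x|) over the batches seen so far.
    state = moving_rate * state + 1.0
    accum = moving_rate * accum + np.abs(x).max()
    return accum / state, state, accum

def fake_quant_dequant(x, scale, bits=8):
    # Quantize to the symmetric integer range [-127, 127] for 8 bits,
    # then immediately dequantize, keeping the rounding error.
    bnt = (1 << (bits - 1)) - 1
    return np.round(x / scale * bnt) * scale / bnt

x = np.random.randn(4, 8).astype('float32')
scale, state, accum = update_scale(x, state=1.0, accum=1.0)
y = fake_quant_dequant(x, scale)  # same shape as x, with int8 rounding error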

@@ -24,6 +24,7 @@ from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
 from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
 from paddle.fluid.contrib.slim.quantization import ScaleForTrainingPass
 from paddle.fluid.contrib.slim.quantization import ScaleForInferencePass
+from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass
 from paddle.fluid import core
 
 os.environ["CUDA_VISIBLE_DEVICES"] = "0"
@@ -98,6 +99,7 @@ class TestQuantizationScalePass(unittest.TestCase):
         scope = fluid.Scope()
         with fluid.scope_guard(scope):
             exe.run(startup)
+
         transform_pass = QuantizationTransformPass(
             scope=scope,
             place=place,
@ -105,8 +107,14 @@ class TestQuantizationScalePass(unittest.TestCase):
weight_quantize_type=weight_quant_type) weight_quantize_type=weight_quant_type)
transform_pass.apply(main_graph) transform_pass.apply(main_graph)
transform_pass.apply(test_graph) transform_pass.apply(test_graph)
add_quant_dequant_pass = AddQuantDequantPass(scope=scope, place=place)
add_quant_dequant_pass.apply(main_graph)
add_quant_dequant_pass.apply(test_graph)
scale_training_pass = ScaleForTrainingPass(scope=scope, place=place) scale_training_pass = ScaleForTrainingPass(scope=scope, place=place)
scale_training_pass.apply(main_graph) scale_training_pass.apply(main_graph)
dev_name = '_gpu' if use_cuda else '_cpu' dev_name = '_gpu' if use_cuda else '_cpu'
if not for_ci: if not for_ci:
marked_nodes = set() marked_nodes = set()
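Note the pass ordering in the test: AddQuantDequantPass runs after QuantizationTransformPass and before ScaleForTrainingPass, presumably so the newly inserted quant_dequant output variables already exist in the graph when the scale passes collect output scales.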
