|
|
|
@ -22,7 +22,8 @@ from .... import unique_name
|
|
|
|
|
|
|
|
|
|
# Public pass classes exported by this module.
__all__ = [
    'QuantizationTransformPass', 'QuantizationFreezePass', 'ConvertToInt8Pass',
    'TransformForMobilePass', 'ScaleForTrainingPass', 'ScaleForInferencePass',
    'AddQuantDequantPass'
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -994,6 +995,8 @@ class ScaleForTrainingPass(object):
|
|
|
|
|
Args:
|
|
|
|
|
graph(IrGraph): the target graph.
|
|
|
|
|
"""
|
|
|
|
|
assert isinstance(graph,
|
|
|
|
|
IrGraph), 'graph must be the instance of IrGraph.'
|
|
|
|
|
self._is_test = graph.is_test()
|
|
|
|
|
ops = graph.all_op_nodes()
|
|
|
|
|
for op_node in ops:
|
|
|
|
@ -1099,6 +1102,8 @@ class ScaleForInferencePass(object):
|
|
|
|
|
Args:
|
|
|
|
|
graph(IrGraph): the target graph.
|
|
|
|
|
"""
|
|
|
|
|
assert isinstance(graph,
|
|
|
|
|
IrGraph), 'graph must be the instance of IrGraph.'
|
|
|
|
|
ops = graph.all_op_nodes()
|
|
|
|
|
for op_node in ops:
|
|
|
|
|
name = op_node.name()
|
|
|
|
@ -1117,3 +1122,137 @@ class ScaleForInferencePass(object):
|
|
|
|
|
Return the scale name for the var named `var_name`.
|
|
|
|
|
"""
|
|
|
|
|
return "%s@scale" % (var_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AddQuantDequantPass(object):
    """Insert a fake quant-dequant op in front of every input of selected ops.

    The op types that are rewritten are listed in `self._target_ops`
    (`elementwise_add` and `pool2d`).
    """

    def __init__(self, scope=None, place=None, moving_rate=0.9, quant_bits=8):
        """
        This pass is used to add quant_dequant op for some ops, such as the
        `elementwise_add` op.

        Args:
            scope: forwarded unchanged to `_init_var_node` when the scale,
                state and accum variables are initialized.
            place: forwarded unchanged to `_init_var_node` together with
                `scope`.
            moving_rate(float): value written to the `moving_rate` attribute
                of every inserted op.
            quant_bits(int): value written to the `bit_length` attribute of
                every inserted op.
        """
        self._scope = scope
        self._place = place
        self._moving_rate = moving_rate
        self._quant_bits = quant_bits
        # Filled in by `apply` from `graph.is_test()`.
        self._is_test = None
        self._target_ops = ["elementwise_add", "pool2d"]

    def apply(self, graph):
        """
        Add quant_dequant before some ops, such as the `elementwise_add` op. This
        is required by TensorRT.

        An op is rewritten only when none of its inputs is persistable. A
        variable that feeds several target ops gets a single quant_dequant op
        whose output is shared, so no two inserted nodes carry the same name.

        Args:
            graph(IrGraph): the target graph.

        Returns:
            The same `graph` object, after `resolve_hazard()` has been called
            on it.
        """
        assert isinstance(graph,
                          IrGraph), 'graph must be the instance of IrGraph.'
        self._is_test = graph.is_test()

        # Variable name -> quant_dequant output node already created for it.
        # Without this cache a variable consumed by two target ops would get
        # two `<name>.quant_dequant` vars and two `<name>.quant_dequant.scale`
        # persistables with identical names.
        dequantized_vars = {}

        for op_node in graph.all_op_nodes():
            if op_node.name() not in self._target_ops:
                continue

            # Resolve every distinct input once, keeping first-seen order.
            # NOTE(review): a name repeated in the argument list is handled
            # once; this assumes `update_input_link` rewires every occurrence
            # of the old name on the op - confirm against IrGraph.
            seen_names = set()
            named_inputs = []
            for input_name in op_node.input_arg_names():
                if input_name in seen_names:
                    continue
                seen_names.add(input_name)
                named_inputs.append(
                    (input_name,
                     graph._find_node_by_name(op_node.inputs, input_name)))

            # Leave the op untouched as soon as one input is persistable.
            if any(in_node.persistable() for _, in_node in named_inputs):
                continue

            for input_name, in_node in named_inputs:
                quant_var_node = dequantized_vars.get(input_name)
                if quant_var_node is None:
                    quant_var_node, _ = \
                        self._inser_quant_dequant_moving_average_abs_max_op(
                            graph, in_node, self._quant_bits)
                    dequantized_vars[input_name] = quant_var_node
                graph.update_input_link(in_node, quant_var_node, op_node)

        graph.resolve_hazard()
        return graph

    def _inser_quant_dequant_moving_average_abs_max_op(self, graph, var_node,
                                                       quant_bits):
        """Insert fake_quantize_dequantize_moving_average_abs_max op.

        The (misspelled) method name is kept as-is so existing callers and
        subclasses keep working.

        Args:
            graph: graph in which the new nodes are created and linked.
            var_node: variable node used as the `X` input of the new op.
            quant_bits: value of the op's `bit_length` attribute.

        Returns:
            A tuple `(quant_var_node, scale_out_node)`: the `Out` and
            `OutScale` output nodes of the inserted op.
        """
        var_name = var_node.name()
        var_dtype = var_node.dtype()
        # numpy dtype used for the initial values of the persistable inputs.
        data_type = ('float64'
                     if var_dtype == core.VarDesc.VarType.FP64 else 'float32')

        quant_var_node = graph.create_var_node(
            name="{}.quant_dequant".format(var_name),
            var_type=var_node.type(),
            shape=var_node.shape(),
            var_dtype=var_dtype)

        scale_in_node = graph.create_persistable_node(
            name="{}.quant_dequant.scale".format(var_name),
            var_type=core.VarDesc.VarType.LOD_TENSOR,
            shape=[1],
            var_dtype=var_dtype)
        _init_var_node(
            scale_in_node,
            np.array(
                [0.001], dtype=data_type),
            self._scope,
            self._place)
        scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())

        ins = {'X': var_node, 'InScale': scale_in_node}
        outs = {'Out': quant_var_node, 'OutScale': scale_out_node}

        if not self._is_test:
            # Outside test mode the op also receives and emits the state and
            # accum variables, both initialized to one.
            state_in_node = graph.create_persistable_node(
                name=unique_name.generate('quant_dequant.state'),
                var_type=core.VarDesc.VarType.LOD_TENSOR,
                var_dtype=var_dtype,
                shape=[1])
            _init_var_node(
                state_in_node,
                np.ones(
                    [1], dtype=data_type),
                self._scope,
                self._place)

            accum_in_node = graph.create_persistable_node(
                name=unique_name.generate('quant_dequant.accum'),
                var_type=core.VarDesc.VarType.LOD_TENSOR,
                var_dtype=var_dtype,
                shape=[1])
            _init_var_node(
                accum_in_node,
                np.ones(
                    [1], dtype=data_type),
                self._scope,
                self._place)

            state_out_node = graph.create_var_node_from_desc(
                state_in_node.var())
            accum_out_node = graph.create_var_node_from_desc(
                accum_in_node.var())

            ins['InState'] = state_in_node
            ins['InAccum'] = accum_in_node
            outs['OutState'] = state_out_node
            outs['OutAccum'] = accum_out_node

        attrs = {
            'bit_length': quant_bits,
            'moving_rate': self._moving_rate,
            'is_test': self._is_test,
            'op_role': core.op_proto_and_checker_maker.OpRole.Forward
        }

        quant_op_node = graph.create_op_node(
            op_type='fake_quantize_dequantize_moving_average_abs_max',
            attrs=attrs,
            inputs=ins,
            outputs=outs)

        graph.link_to(var_node, quant_op_node)
        graph.link_to(scale_in_node, quant_op_node)
        graph.link_to(quant_op_node, quant_var_node)
        graph.link_to(quant_op_node, scale_out_node)

        if not self._is_test:
            graph.link_to(state_in_node, quant_op_node)
            graph.link_to(accum_in_node, quant_op_node)
            graph.link_to(quant_op_node, state_out_node)
            graph.link_to(quant_op_node, accum_out_node)

        return quant_var_node, scale_out_node
|
|
|
|
|