|
|
|
@ -180,9 +180,14 @@ class QuantizationTransformPass(object):
|
|
|
|
|
Constant(value=0, force_cpu=True)
|
|
|
|
|
global_step_out = graph.create_var_node_from_desc(
|
|
|
|
|
global_step_in.var())
|
|
|
|
|
# The attribute of `op_role` is needed by ParallelExecutor.
|
|
|
|
|
increment_op = graph.create_op_node(
|
|
|
|
|
op_type='increment',
|
|
|
|
|
attrs={'step': 1.0},
|
|
|
|
|
attrs={
|
|
|
|
|
'step': 1.0,
|
|
|
|
|
'op_role':
|
|
|
|
|
core.op_proto_and_checker_maker.OpRole.Forward
|
|
|
|
|
},
|
|
|
|
|
inputs={'X': global_step_in},
|
|
|
|
|
outputs={'Out': global_step_out})
|
|
|
|
|
graph.link_to(global_step_in, increment_op)
|
|
|
|
@ -217,7 +222,10 @@ class QuantizationTransformPass(object):
|
|
|
|
|
var_dtype=var_node.var().dtype())
|
|
|
|
|
quant_op_node = graph.create_op_node(
|
|
|
|
|
op_type='fake_quantize_abs_max',
|
|
|
|
|
attrs={'bit_length': quant_bits},
|
|
|
|
|
attrs={
|
|
|
|
|
'bit_length': quant_bits,
|
|
|
|
|
'op_role': core.op_proto_and_checker_maker.OpRole.Forward
|
|
|
|
|
},
|
|
|
|
|
inputs={'X': var_node},
|
|
|
|
|
outputs={'Out': quant_var_node,
|
|
|
|
|
'OutScale': scale_var_node})
|
|
|
|
@ -262,7 +270,8 @@ class QuantizationTransformPass(object):
|
|
|
|
|
attrs = {
|
|
|
|
|
'window_size': self._window_size,
|
|
|
|
|
'bit_length': quant_bits,
|
|
|
|
|
'is_test': self._is_test
|
|
|
|
|
'is_test': self._is_test,
|
|
|
|
|
'op_role': core.op_proto_and_checker_maker.OpRole.Forward
|
|
|
|
|
}
|
|
|
|
|
quant_op_node = graph.create_op_node(
|
|
|
|
|
op_type='fake_quantize_range_abs_max',
|
|
|
|
@ -295,7 +304,10 @@ class QuantizationTransformPass(object):
|
|
|
|
|
max_range = (1 << (quant_bits - 1)) - 1
|
|
|
|
|
dequant_op_node = graph.create_op_node(
|
|
|
|
|
op_type='fake_dequantize_max_abs',
|
|
|
|
|
attrs={'max_range': float(max_range)},
|
|
|
|
|
attrs={
|
|
|
|
|
'max_range': float(max_range),
|
|
|
|
|
'op_role': core.op_proto_and_checker_maker.OpRole.Forward
|
|
|
|
|
},
|
|
|
|
|
inputs={'X': var_node,
|
|
|
|
|
'Scale': scale_var_node},
|
|
|
|
|
outputs={'Out': dequant_var_node})
|
|
|
|
@ -444,7 +456,10 @@ class QuantizationFreezePass(object):
|
|
|
|
|
var_dtype=output_var_node.var().dtype())
|
|
|
|
|
dequant_op_node = graph.create_op_node(
|
|
|
|
|
op_type='fake_dequantize_max_abs',
|
|
|
|
|
attrs={'max_range': float(max_range)},
|
|
|
|
|
attrs={
|
|
|
|
|
'max_range': float(max_range),
|
|
|
|
|
'op_role': core.op_proto_and_checker_maker.OpRole.Forward
|
|
|
|
|
},
|
|
|
|
|
inputs={'X': output_var_node,
|
|
|
|
|
'Scale': scale_var_node},
|
|
|
|
|
outputs={'Out': dequant_var_node})
|
|
|
|
|