@@ -231,14 +231,14 @@ class QuantizationTransformPass(object):
 
         quant_var_node = graph.create_var_node(
             name=self._quantized_var_name(var_node.name()),
-            var_type=var_node.var().type(),
-            shape=var_node.var().shape(),
-            var_dtype=var_node.var().dtype())
+            var_type=var_node.type(),
+            shape=var_node.shape(),
+            var_dtype=var_node.dtype())
         scale_var_node = graph.create_var_node(
             name=self._quantized_scale_name(var_node.name()),
-            var_type=var_node.var().type(),
-            shape=var_node.var().shape(),
-            var_dtype=var_node.var().dtype())
+            var_type=var_node.type(),
+            shape=var_node.shape(),
+            var_dtype=var_node.dtype())
         quant_op_node = graph.create_op_node(
             op_type='fake_quantize_abs_max',
             attrs={
@@ -261,15 +261,15 @@ class QuantizationTransformPass(object):
 
         quant_var_node = graph.create_var_node(
             name=self._quantized_var_name(var_node.name()),
-            var_type=var_node.var().type(),
-            shape=var_node.var().shape(),
-            var_dtype=var_node.var().dtype())
+            var_type=var_node.type(),
+            shape=var_node.shape(),
+            var_dtype=var_node.dtype())
 
         scale_in_node = graph.create_persistable_node(
             name=self._quantized_scale_name(var_node.name()),
             var_type=core.VarDesc.VarType.LOD_TENSOR,
             shape=[1],
-            var_dtype=var_node.var().dtype())
+            var_dtype=var_node.dtype())
         self._need_initialized[scale_in_node.var()] = Constant(value=0.001)
 
         scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
@@ -282,7 +282,7 @@ class QuantizationTransformPass(object):
                 name=unique_name.generate('scales'),
                 var_type=core.VarDesc.VarType.LOD_TENSOR,
                 shape=[self._window_size],
-                var_dtype=var_node.var().dtype())
+                var_dtype=var_node.dtype())
             self._need_initialized[scales_node.var()] = Constant(value=0)
             inputs['Iter'] = self._global_step
             outputs['OutScales'] = scales_node
@@ -317,9 +317,9 @@ class QuantizationTransformPass(object):
 
         dequant_var_node = graph.create_var_node(
             name=self._dequantized_var_name(var_node.name()),
-            var_type=var_node.var().type(),
-            shape=var_node.var().shape(),
-            var_dtype=var_node.var().dtype())
+            var_type=var_node.type(),
+            shape=var_node.shape(),
+            var_dtype=var_node.dtype())
         max_range = (1 << (quant_bits - 1)) - 1
         dequant_op_node = graph.create_op_node(
             op_type='fake_dequantize_max_abs',
@@ -408,17 +408,17 @@ class QuantizationFreezePass(object):
         for op_node in ops:
             op_name = op_node.name()
             if op_name in self._fake_quant_op_names:
-                input_arg_name = op_node.op().input('X')[0]
+                input_arg_name = op_node.input('X')[0]
                 if input_arg_name in persistable_vars:
                     if self._weight_quantize_type == 'abs_max':
                         param = self._load_var(input_arg_name)
                         scale_v = np.max(np.abs(param))
                     else:
-                        scale_v = self._load_var(op_node.op().output('OutScale')
-                                                 [0])[0]
+                        scale_v = self._load_var(
+                            op_node.output('OutScale')[0])[0]
                     self._var_scale_map[input_arg_name] = scale_v
                 else:
-                    scale_v = graph.var_node(op_node.op().output('OutScale')[0])
+                    scale_v = graph.var_node(op_node.output('OutScale')[0])
                     self._var_scale_map[input_arg_name] = scale_v
                 if input_arg_name in persistable_vars:
                     self._remove_fake_quant_and_dequant_op(graph, op_node)
@@ -454,8 +454,8 @@ class QuantizationFreezePass(object):
         return graph
 
     def _remove_fake_quant_and_dequant_op(self, graph, op_node):
-        k = op_node.op().output('Out')[0]
-        v = op_node.op().input('X')[0]
+        k = op_node.output('Out')[0]
+        v = op_node.input('X')[0]
         if v not in self._op_input_rename_map:
             self._op_input_rename_map[k] = v
         else:
@@ -493,9 +493,9 @@ class QuantizationFreezePass(object):
         output_var_node = op_node.outputs[0]
         dequant_var_node = graph.create_var_node(
             name=self._dequantized_var_name(output_var_node.name()),
-            var_type=output_var_node.var().type(),
-            shape=output_var_node.var().shape(),
-            var_dtype=output_var_node.var().dtype())
+            var_type=output_var_node.type(),
+            shape=output_var_node.shape(),
+            var_dtype=output_var_node.dtype())
         dequant_op_node = graph.create_op_node(
             op_type='fake_dequantize_max_abs',
             attrs={
@@ -615,8 +615,8 @@ class ConvertToInt8Pass(object):
         int8_var_node_name = var_node.name() + ".int8"
         int8_var_node = graph.create_persistable_node(
             name=cpt.to_text(int8_var_node_name),
-            var_type=var_node.var().type(),
-            shape=var_node.var().shape(),
+            var_type=var_node.type(),
+            shape=var_node.shape(),
             var_dtype=core.VarDesc.VarType.INT8)
         array = self._load_var(var_node.name())
         self._scope.var(int8_var_node_name)
@@ -672,7 +672,7 @@ class TransformForMobilePass(object):
         for op_node in ops:
             name = op_node.name()
             if name in self._fake_quant_op_names:
-                op_node.op().set_type('quantize')
+                op_node.set_type('quantize')
                 quant_node = graph.create_op_node_from_desc(op_node.op())
                 for input_node in op_node.inputs:
                     graph.link_to(input_node, quant_node)
@@ -680,7 +680,7 @@ class TransformForMobilePass(object):
                     graph.link_to(quant_node, output_node)
                 graph.safe_remove_nodes(op_node)
             if name in self._fake_dequant_op_names:
-                op_node.op().set_type('dequantize')
+                op_node.set_type('dequantize')
                 dequant_node = graph.create_op_node_from_desc(op_node.op())
                 for input_node in op_node.inputs:
                     graph.link_to(input_node, dequant_node)
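The pattern running through these hunks is that the graph nodes handled by these passes (the objects returned by `IrGraph` helpers such as `create_var_node`) now expose `type()`, `shape()`, `dtype()`, `input()`, `output()`, and `set_type()` directly, so callers no longer reach through `.var()` or `.op()` for these queries. A minimal sketch of the new calling convention, assuming `graph` is an `IrGraph` and `var_node` one of its variable nodes; the helper name `clone_var_node` is hypothetical and only illustrates the accessor change shown above:

    def clone_var_node(graph, var_node, new_name):
        # Accessors are now called on the node itself; the pre-change form
        # was var_node.var().type(), var_node.var().shape(), var_node.var().dtype().
        return graph.create_var_node(
            name=new_name,
            var_type=var_node.type(),
            shape=var_node.shape(),
            var_dtype=var_node.dtype())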