|
|
|
@ -758,6 +758,7 @@ class QuantizationTransformPass(object):
|
|
|
|
|
attrs={
|
|
|
|
|
'bit_length': quant_bits,
|
|
|
|
|
'quant_axis': quant_axis,
|
|
|
|
|
'is_test': self._is_test,
|
|
|
|
|
'op_role': core.op_proto_and_checker_maker.OpRole.Forward
|
|
|
|
|
},
|
|
|
|
|
inputs={'X': var_node},
|
|
|
|
@ -1125,7 +1126,7 @@ class QuantizationFreezePass(object):
|
|
|
|
|
self._restore_var(input_arg_name, quantized_param_v)
|
|
|
|
|
self._remove_fake_quant_and_dequant_op(graph, op_node)
|
|
|
|
|
|
|
|
|
|
# Remove all fake dequant op
|
|
|
|
|
# Remove all fake dequant op
|
|
|
|
|
ops = graph.all_op_nodes()
|
|
|
|
|
for op_node in ops:
|
|
|
|
|
op_name = op_node.name()
|
|
|
|
@ -1331,16 +1332,25 @@ class QuantizationFreezePass(object):
|
|
|
|
|
|
|
|
|
|
def _quant(self, x, scale, num_bits, quant_axis):
|
|
|
|
|
assert quant_axis in [0, 1], 'quant_axis should be 0 or 1 for now.'
|
|
|
|
|
bnt = (1 << (num_bits - 1)) - 1
|
|
|
|
|
|
|
|
|
|
def _clip(x, scale):
|
|
|
|
|
x[x > scale] = scale
|
|
|
|
|
x[x < -scale] = -scale
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
if isinstance(scale, list):
|
|
|
|
|
for i, s in enumerate(scale):
|
|
|
|
|
if quant_axis == 0:
|
|
|
|
|
x[i] = np.round(x[i] / s * ((1 << (num_bits - 1)) - 1))
|
|
|
|
|
x[i] = _clip(x[i], s)
|
|
|
|
|
x[i] = np.round(x[i] / s * bnt)
|
|
|
|
|
else:
|
|
|
|
|
x[:, i] = np.round(x[:, i] / s * (
|
|
|
|
|
(1 << (num_bits - 1)) - 1))
|
|
|
|
|
return x
|
|
|
|
|
x[:, i] = _clip(x[:, i], s)
|
|
|
|
|
x[:, i] = np.round(x[:, i] / s * bnt)
|
|
|
|
|
else:
|
|
|
|
|
return np.round(x / scale * ((1 << (num_bits - 1)) - 1))
|
|
|
|
|
x = _clip(x, scale)
|
|
|
|
|
x = np.round(x / scale * bnt)
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ConvertToInt8Pass(object):
|
|
|
|
|