|
|
@ -758,6 +758,7 @@ class QuantizationTransformPass(object):
|
|
|
|
attrs={
|
|
|
|
attrs={
|
|
|
|
'bit_length': quant_bits,
|
|
|
|
'bit_length': quant_bits,
|
|
|
|
'quant_axis': quant_axis,
|
|
|
|
'quant_axis': quant_axis,
|
|
|
|
|
|
|
|
'is_test': self._is_test,
|
|
|
|
'op_role': core.op_proto_and_checker_maker.OpRole.Forward
|
|
|
|
'op_role': core.op_proto_and_checker_maker.OpRole.Forward
|
|
|
|
},
|
|
|
|
},
|
|
|
|
inputs={'X': var_node},
|
|
|
|
inputs={'X': var_node},
|
|
|
@ -1331,16 +1332,25 @@ class QuantizationFreezePass(object):
|
|
|
|
|
|
|
|
|
|
|
|
def _quant(self, x, scale, num_bits, quant_axis):
|
|
|
|
def _quant(self, x, scale, num_bits, quant_axis):
|
|
|
|
assert quant_axis in [0, 1], 'quant_axis should be 0 or 1 for now.'
|
|
|
|
assert quant_axis in [0, 1], 'quant_axis should be 0 or 1 for now.'
|
|
|
|
|
|
|
|
bnt = (1 << (num_bits - 1)) - 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _clip(x, scale):
|
|
|
|
|
|
|
|
x[x > scale] = scale
|
|
|
|
|
|
|
|
x[x < -scale] = -scale
|
|
|
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
if isinstance(scale, list):
|
|
|
|
if isinstance(scale, list):
|
|
|
|
for i, s in enumerate(scale):
|
|
|
|
for i, s in enumerate(scale):
|
|
|
|
if quant_axis == 0:
|
|
|
|
if quant_axis == 0:
|
|
|
|
x[i] = np.round(x[i] / s * ((1 << (num_bits - 1)) - 1))
|
|
|
|
x[i] = _clip(x[i], s)
|
|
|
|
|
|
|
|
x[i] = np.round(x[i] / s * bnt)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
x[:, i] = np.round(x[:, i] / s * (
|
|
|
|
x[:, i] = _clip(x[:, i], s)
|
|
|
|
(1 << (num_bits - 1)) - 1))
|
|
|
|
x[:, i] = np.round(x[:, i] / s * bnt)
|
|
|
|
return x
|
|
|
|
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
return np.round(x / scale * ((1 << (num_bits - 1)) - 1))
|
|
|
|
x = _clip(x, scale)
|
|
|
|
|
|
|
|
x = np.round(x / scale * bnt)
|
|
|
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ConvertToInt8Pass(object):
|
|
|
|
class ConvertToInt8Pass(object):
|
|
|
|