@@ -32,6 +32,7 @@ from ...ops.operations import _inner_ops as inner
 from ...train import serialization
 from . import quant_utils
 
 _ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant,
                    nn.ReLU6: quant.ReLU6Quant,
                    nn.HSigmoid: quant.HSigmoidQuant,
@@ -61,14 +62,17 @@ class _AddFakeQuantAfterSubCell(nn.Cell):
     Add FakeQuant after the sub Cell.
     """
 
-    def __init__(self, subcell, quant_delay=0, num_bits=8):
+    def __init__(self, subcell, **kwargs):
         super(_AddFakeQuantAfterSubCell, self).__init__(auto_prefix=False)
         self.subcell = subcell
         self.fake_quant_act = quant.FakeQuantWithMinMax(min_init=-6,
                                                         max_init=6,
-                                                        num_bits=num_bits,
-                                                        quant_delay=quant_delay,
-                                                        ema=True)
+                                                        ema=True,
+                                                        num_bits=kwargs["num_bits"],
+                                                        quant_delay=kwargs["quant_delay"],
+                                                        per_channel=kwargs["per_channel"],
+                                                        symmetric=kwargs["symmetric"],
+                                                        narrow_range=kwargs["narrow_range"])
 
     def construct(self, *data):
         output = self.subcell(*data)
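Note on the rewritten `__init__` above: the settings are now read with `kwargs["..."]` and carry no defaults, so every call site has to pass all five keys or a KeyError is raised at construction time. A minimal in-module sketch of a call after this change, wrapping `F.identity` as later hunks in this diff do, with hypothetical values:

    fq_identity = _AddFakeQuantAfterSubCell(F.identity,
                                            num_bits=8,
                                            quant_delay=0,
                                            per_channel=False,
                                            symmetric=False,
                                            narrow_range=False)
    # construct() runs the wrapped sub-cell first, then applies fake_quant_act to its output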
@@ -82,30 +86,20 @@ class ConvertToQuantNetwork:
     """
     __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]
 
-    def __init__(self,
-                 network,
-                 quant_delay=0,
-                 bn_fold=False,
-                 freeze_bn=0,
-                 weight_bits=8,
-                 act_bits=8,
-                 per_channel=False,
-                 symmetric=False,
-                 narrow_range=False):
-        self.network = validator.check_isinstance(
-            'network', network, (nn.Cell,))
-        self.quant_delay = validator.check_integer(
-            "quant delay", quant_delay, 0, Rel.GE)
-        self.freeze_bn = validator.check_integer(
-            "freeze bn", freeze_bn, 0, Rel.GE)
-        self.weight_bits = validator.check_integer(
-            "weights bit", weight_bits, 0, Rel.GE)
-        self.act_bits = validator.check_integer(
-            "activations bit", act_bits, 0, Rel.GE)
-        self.bn_fold = validator.check_bool("bn fold", bn_fold)
-        self.per_channel = validator.check_bool("per channel", per_channel)
-        self.symmetric = validator.check_bool("symmetric", symmetric)
-        self.narrow_range = validator.check_bool("narrow range", narrow_range)
+    def __init__(self, **kwargs):
+        self.network = validator.check_isinstance('network', kwargs["network"], (nn.Cell,))
+        self.weight_qdelay = validator.check_integer("quant delay", kwargs["quant_delay"][0], 0, Rel.GE)
+        self.act_qdelay = validator.check_integer("quant delay", kwargs["quant_delay"][-1], 0, Rel.GE)
+        self.bn_fold = validator.check_bool("bn fold", kwargs["bn_fold"])
+        self.freeze_bn = validator.check_integer("freeze bn", kwargs["freeze_bn"], 0, Rel.GE)
+        self.weight_bits = validator.check_integer("weights bit", kwargs["num_bits"][0], 0, Rel.GE)
+        self.act_bits = validator.check_integer("activations bit", kwargs["num_bits"][-1], 0, Rel.GE)
+        self.weight_channel = validator.check_bool("per channel", kwargs["per_channel"][0])
+        self.act_channel = validator.check_bool("per channel", kwargs["per_channel"][-1])
+        self.weight_symmetric = validator.check_bool("symmetric", kwargs["symmetric"][0])
+        self.act_symmetric = validator.check_bool("symmetric", kwargs["symmetric"][-1])
+        self.weight_range = validator.check_bool("narrow range", kwargs["narrow_range"][0])
+        self.act_range = validator.check_bool("narrow range", kwargs["narrow_range"][-1])
         self._convert_method_map = {quant.Conv2dBnAct: self._convert_conv,
                                     quant.DenseBnAct: self._convert_dense}
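The pairing convention in the constructor above is worth spelling out: index `[0]` selects the weight-side setting and index `[-1]` the activation (data-flow) side, so a single-element list applies one value to both. A small standalone illustration with hypothetical values:

    num_bits = [8]                                       # one value shared by weights and activations
    weight_bits, act_bits = num_bits[0], num_bits[-1]    # 8, 8

    num_bits = (8, 4)                                    # weights first, data flow second
    weight_bits, act_bits = num_bits[0], num_bits[-1]    # 8, 4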
@@ -153,7 +147,12 @@ class ConvertToQuantNetwork:
                 add_list.append((name, attr))
         for name, prim_op in add_list:
             prefix = name
-            add_quant = _AddFakeQuantAfterSubCell(prim_op)  # quant.TensorAddQuant()
+            add_quant = _AddFakeQuantAfterSubCell(prim_op,
+                                                  num_bits=self.act_bits,
+                                                  quant_delay=self.act_qdelay,
+                                                  per_channel=self.act_channel,
+                                                  symmetric=self.act_symmetric,
+                                                  narrow_range=self.act_range)
             prefix = '.'.join([network.param_prefix, self._convert_op_name(prim_op.name)])
             add_quant.update_parameters_name(prefix + '.')
             del network.__dict__[name]
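This hunk covers raw primitives rather than cells: any attribute of the network whose op name is in `__quant_op_name__` ("TensorAdd", "Sub", "Mul", "RealDiv") is re-wrapped so its output passes through a FakeQuant node configured with the activation-side settings. A sketch of the resulting wrapper, assuming `P` is the usual alias for `mindspore.ops.operations` (the alias is an assumption, the import is not visible in this diff), with hypothetical values:

    add_op = P.TensorAdd()
    add_quant = _AddFakeQuantAfterSubCell(add_op,
                                          num_bits=8,
                                          quant_delay=0,
                                          per_channel=False,
                                          symmetric=False,
                                          narrow_range=False)
    # add_quant(x, y) behaves like FakeQuant(TensorAdd(x, y)) and replaces the bare primitive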
@@ -177,13 +176,13 @@ class ConvertToQuantNetwork:
                                                     group=conv_inner.group,
                                                     eps=bn_inner.eps,
                                                     momentum=bn_inner.momentum,
-                                                    quant_delay=self.quant_delay,
+                                                    quant_delay=self.weight_qdelay,
                                                     freeze_bn=self.freeze_bn,
-                                                    per_channel=self.per_channel,
+                                                    per_channel=self.weight_channel,
                                                     num_bits=self.weight_bits,
                                                     fake=True,
-                                                    symmetric=self.symmetric,
-                                                    narrow_range=self.narrow_range)
+                                                    symmetric=self.weight_symmetric,
+                                                    narrow_range=self.weight_range)
-            del subcell.batchnorm
+            subcell.batchnorm = None
             subcell.has_bn = False
@@ -197,18 +196,22 @@ class ConvertToQuantNetwork:
                                            dilation=conv_inner.dilation,
                                            group=conv_inner.group,
                                            has_bias=conv_inner.has_bias,
-                                           quant_delay=self.quant_delay,
-                                           per_channel=self.per_channel,
+                                           quant_delay=self.weight_qdelay,
+                                           per_channel=self.weight_channel,
                                            num_bits=self.weight_bits,
-                                           symmetric=self.symmetric,
-                                           narrow_range=self.narrow_range)
+                                           symmetric=self.weight_symmetric,
+                                           narrow_range=self.weight_range)
         subcell.conv = conv_inner
         if subcell.has_act and subcell.activation is not None:
             subcell.activation = self._convert_activation(subcell.activation)
         else:
             subcell.has_act = True
-            subcell.activation = _AddFakeQuantAfterSubCell(F.identity, num_bits=self.act_bits,
-                                                           quant_delay=self.quant_delay)
+            subcell.activation = _AddFakeQuantAfterSubCell(F.identity,
+                                                           num_bits=self.act_bits,
+                                                           quant_delay=self.act_qdelay,
+                                                           per_channel=self.act_channel,
+                                                           symmetric=self.act_symmetric,
+                                                           narrow_range=self.act_range)
         return subcell
 
     def _convert_dense(self, subcell):
@@ -219,16 +222,22 @@ class ConvertToQuantNetwork:
         dense_inner = quant.DenseQuant(dense_inner.in_channels,
                                        dense_inner.out_channels,
                                        has_bias=dense_inner.has_bias,
-                                       quant_delay=self.quant_delay,
-                                       per_channel=self.per_channel,
-                                       num_bits=self.weight_bits)
+                                       num_bits=self.weight_bits,
+                                       quant_delay=self.weight_qdelay,
+                                       per_channel=self.weight_channel,
+                                       symmetric=self.weight_symmetric,
+                                       narrow_range=self.weight_range)
         subcell.dense = dense_inner
         if subcell.has_act and subcell.activation is not None:
             subcell.activation = self._convert_activation(subcell.activation)
         else:
             subcell.has_act = True
-            subcell.activation = _AddFakeQuantAfterSubCell(F.identity, num_bits=self.act_bits,
-                                                           quant_delay=self.quant_delay)
+            subcell.activation = _AddFakeQuantAfterSubCell(F.identity,
+                                                           num_bits=self.act_bits,
+                                                           quant_delay=self.act_qdelay,
+                                                           per_channel=self.act_channel,
+                                                           symmetric=self.act_symmetric,
+                                                           narrow_range=self.act_range)
         return subcell
 
     def _convert_activation(self, activation):
@@ -236,7 +245,11 @@ class ConvertToQuantNetwork:
         if act_class not in _ACTIVATION_MAP:
             raise ValueError(
                 "Unsupported activation in auto Quant: ", act_class)
-        return _ACTIVATION_MAP[act_class](num_bits=self.act_bits, quant_delay=self.quant_delay)
+        return _ACTIVATION_MAP[act_class](num_bits=self.act_bits,
+                                          quant_delay=self.act_qdelay,
+                                          per_channel=self.act_channel,
+                                          symmetric=self.weight_symmetric,
+                                          narrow_range=self.weight_range)
 
 
 class ExportQuantNetworkDeploy:
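`_convert_activation` now forwards the full activation-side configuration to the quantized activation cell looked up in `_ACTIVATION_MAP` (as written, the `symmetric` and `narrow_range` arguments still read the weight-side attributes). The lookup is equivalent to the following sketch with hypothetical values:

    relu_quant = _ACTIVATION_MAP[nn.ReLU](num_bits=8,
                                          quant_delay=0,
                                          per_channel=False,
                                          symmetric=False,
                                          narrow_range=False)
    # nn.ReLU maps to quant.ReLUQuant; classes outside _ACTIVATION_MAP still raise ValueError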
@@ -381,32 +394,57 @@ def export_geir(network, *inputs, file_name):
 
 
 def convert_quant_network(network,
-                          quant_delay=0,
                           bn_fold=False,
                           freeze_bn=0,
-                          weight_bits=8,
-                          act_bits=8,
-                          per_channel=False,
-                          symmetric=False,
-                          narrow_range=False
+                          quant_delay=(0, 0),
+                          num_bits=(8, 8),
+                          per_channel=(False, False),
+                          symmetric=(False, False),
+                          narrow_range=(False, False)
                           ):
     r"""
     Create quantization aware training network.
 
     Args:
        network (Cell): Network to be converted to a quantization aware training network.
-        quant_delay (int): Number of steps after which weights and activations are quantized during eval. Default: 0.
+        quant_delay (list of int): Number of steps after which weights and activations are quantized during
+            eval. The first element represents weights and the second element represents data flow. Default: [0, 0]
        bn_fold (bool): Flag to use bn fold ops to simulate the inference operation. Default: False.
-        freeze_bn (int): Number of steps after which BN parameters used total mean and variance. Default: 0.
-        weight_bits (int): Number of bits to use for quantizing weights. Default: 8.
-        act_bits (int): Number of bits to use for quantizing activations. Default: 8.
-        per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
-        symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
-        narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
+        freeze_bn (int): Number of steps after which BatchNorm OP parameters use the total mean and variance. Default: 0.
+        num_bits (list of int): Number of bits to use for quantizing weights and activations. The first
+            element represents weights and the second element represents data flow. Default: [8, 8]
+        per_channel (list of bool): Quantization granularity based on layer or on channel. If `True`
+            then base on per channel, otherwise base on per layer. The first element represents weights
+            and the second element represents data flow. Default: [False, False]
+        symmetric (list of bool): Quantization algorithm use symmetric or not. If `True` then base on
+            symmetric, otherwise base on asymmetric. The first element represents weights and the second
+            element represents data flow. Default: [False, False]
+        narrow_range (list of bool): Quantization algorithm use narrow range or not. If `True` then base
+            on narrow range, otherwise base on the full range. The first element represents weights and
+            the second element represents data flow. Default: [False, False]
 
     Returns:
-        Cell, Network which has change to aware quantization training network.
+        Cell, Network which has been changed to a quantization aware training network cell.
     """
-    net = ConvertToQuantNetwork(
-        network, quant_delay, bn_fold, freeze_bn, weight_bits, act_bits, per_channel, symmetric, narrow_range)
+    def convert2list(name, value):
+        if not isinstance(value, list) and not isinstance(value, tuple):
+            value = [value]
+        elif len(value) > 2:
+            raise ValueError("input `{}` len should not be greater than 2".format(name))
+        return value
+
+    quant_delay = convert2list("quant delay", quant_delay)
+    num_bits = convert2list("num bits", num_bits)
+    per_channel = convert2list("per channel", per_channel)
+    symmetric = convert2list("symmetric", symmetric)
+    narrow_range = convert2list("narrow range", narrow_range)
+
+    net = ConvertToQuantNetwork(network=network,
+                                quant_delay=quant_delay,
+                                bn_fold=bn_fold,
+                                freeze_bn=freeze_bn,
+                                num_bits=num_bits,
+                                per_channel=per_channel,
+                                symmetric=symmetric,
+                                narrow_range=narrow_range)
     return net.run()
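With this change the public entry point accepts either a scalar or a two-element list/tuple for each quantization setting; `convert2list` normalizes scalars into one-element lists and rejects anything longer than two. A usage sketch, where `network` and the chosen values are hypothetical:

    # scalar settings are applied to both weights and activations
    quant_net = convert_quant_network(network, bn_fold=True, freeze_bn=100,
                                      quant_delay=900, num_bits=8)

    # paired settings: first entry for weights, second entry for the data flow
    quant_net = convert_quant_network(network,
                                      bn_fold=True,
                                      freeze_bn=100,
                                      quant_delay=(900, 900),
                                      num_bits=(8, 8),
                                      per_channel=(True, False),
                                      symmetric=(True, False),
                                      narrow_range=(False, False))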