|
|
|
@@ -138,6 +138,7 @@ class QuantizationAwareTraining(Quantizer):
|
|
|
|
|
The first element represents weights and the second element represents data flow. Default: (False, False)
|
|
|
|
|
optimize_option (OptimizeOption, list or tuple): Specifies the quant algorithm and options, currently only
|
|
|
|
|
support QAT. Default: OptimizeOption.QAT
|
|
|
|
|
one_conv_fold (bool): Flag to used one conv bn fold ops for simulation inference operation. Default: True.
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
>>> class LeNet5(nn.Cell):
|
|
|
|
@@ -182,7 +183,8 @@ class QuantizationAwareTraining(Quantizer):
|
|
|
|
|
per_channel=(False, False),
|
|
|
|
|
symmetric=(False, False),
|
|
|
|
|
narrow_range=(False, False),
|
|
|
|
|
optimize_option=OptimizeOption.QAT):
|
|
|
|
|
optimize_option=OptimizeOption.QAT,
|
|
|
|
|
one_conv_fold=True):
|
|
|
|
|
"""Init for QuantizationAwareTraining quantizer"""
|
|
|
|
|
super(QuantizationAwareTraining, self).__init__(optimize_option=optimize_option)
|
|
|
|
|
def convert2list(name, value):
|
|
|
|
@@ -210,6 +212,7 @@ class QuantizationAwareTraining(Quantizer):
|
|
|
|
|
self.act_symmetric = Validator.check_bool(symmetric[-1], "symmetric")
|
|
|
|
|
self.weight_range = Validator.check_bool(narrow_range[0], "narrow range")
|
|
|
|
|
self.act_range = Validator.check_bool(narrow_range[-1], "narrow range")
|
|
|
|
|
self.one_conv_fold = Validator.check_bool(one_conv_fold, "one conv fold")
|
|
|
|
|
self._convert_method_map = {nn.Conv2dBnAct: self._convert_conv,
|
|
|
|
|
nn.DenseBnAct: self._convert_dense}
|
|
|
|
|
self.quant_config = create_quant_config(quant_delay=quant_delay,
|
|
|
|
@@ -300,22 +303,39 @@ class QuantizationAwareTraining(Quantizer):
|
|
|
|
|
if subcell.has_bn:
|
|
|
|
|
if self.bn_fold:
|
|
|
|
|
bn_inner = subcell.batchnorm
|
|
|
|
|
conv_inner = quant.Conv2dBnFoldQuant(conv_inner.in_channels,
|
|
|
|
|
conv_inner.out_channels,
|
|
|
|
|
kernel_size=conv_inner.kernel_size,
|
|
|
|
|
stride=conv_inner.stride,
|
|
|
|
|
pad_mode=conv_inner.pad_mode,
|
|
|
|
|
padding=conv_inner.padding,
|
|
|
|
|
dilation=conv_inner.dilation,
|
|
|
|
|
group=conv_inner.group,
|
|
|
|
|
eps=bn_inner.eps,
|
|
|
|
|
momentum=bn_inner.momentum,
|
|
|
|
|
has_bias=conv_inner.has_bias,
|
|
|
|
|
bias_init=conv_inner.bias_init,
|
|
|
|
|
freeze_bn=self.freeze_bn,
|
|
|
|
|
quant_config=self.quant_config,
|
|
|
|
|
quant_dtype=self.weight_dtype,
|
|
|
|
|
fake=True)
|
|
|
|
|
if self.one_conv_fold:
|
|
|
|
|
conv_inner = quant.Conv2dBnFoldQuantOneConv(conv_inner.in_channels,
|
|
|
|
|
conv_inner.out_channels,
|
|
|
|
|
kernel_size=conv_inner.kernel_size,
|
|
|
|
|
stride=conv_inner.stride,
|
|
|
|
|
pad_mode=conv_inner.pad_mode,
|
|
|
|
|
padding=conv_inner.padding,
|
|
|
|
|
dilation=conv_inner.dilation,
|
|
|
|
|
group=conv_inner.group,
|
|
|
|
|
eps=bn_inner.eps,
|
|
|
|
|
momentum=bn_inner.momentum,
|
|
|
|
|
has_bias=conv_inner.has_bias,
|
|
|
|
|
bias_init=conv_inner.bias_init,
|
|
|
|
|
quant_config=self.quant_config,
|
|
|
|
|
quant_dtype=self.weight_dtype,
|
|
|
|
|
fake=True)
|
|
|
|
|
else:
|
|
|
|
|
conv_inner = quant.Conv2dBnFoldQuant(conv_inner.in_channels,
|
|
|
|
|
conv_inner.out_channels,
|
|
|
|
|
kernel_size=conv_inner.kernel_size,
|
|
|
|
|
stride=conv_inner.stride,
|
|
|
|
|
pad_mode=conv_inner.pad_mode,
|
|
|
|
|
padding=conv_inner.padding,
|
|
|
|
|
dilation=conv_inner.dilation,
|
|
|
|
|
group=conv_inner.group,
|
|
|
|
|
eps=bn_inner.eps,
|
|
|
|
|
momentum=bn_inner.momentum,
|
|
|
|
|
has_bias=conv_inner.has_bias,
|
|
|
|
|
bias_init=conv_inner.bias_init,
|
|
|
|
|
freeze_bn=self.freeze_bn,
|
|
|
|
|
quant_config=self.quant_config,
|
|
|
|
|
quant_dtype=self.weight_dtype,
|
|
|
|
|
fake=True)
|
|
|
|
|
# change original network BatchNormal OP parameters to quant network
|
|
|
|
|
conv_inner.gamma = subcell.batchnorm.gamma
|
|
|
|
|
conv_inner.beta = subcell.batchnorm.beta
|
|
|
|
|