|
|
|
@ -125,34 +125,6 @@ class _AddFakeQuantAfterSubCell(nn.Cell):
|
|
|
|
|
return output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ConvertToQuantNetwork:
|
|
|
|
|
"""
|
|
|
|
|
Convert network to quantization aware network
|
|
|
|
|
"""
|
|
|
|
|
__quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]
|
|
|
|
|
|
|
|
|
|
def __init__(self, **kwargs):
|
|
|
|
|
self.network = Validator.check_isinstance('network', kwargs["network"], (nn.Cell,))
|
|
|
|
|
self.weight_qdelay = Validator.check_non_negative_int(kwargs["quant_delay"][0], "quant delay")
|
|
|
|
|
self.act_qdelay = Validator.check_int(kwargs["quant_delay"][-1], 0, Rel.GE, "quant delay")
|
|
|
|
|
self.bn_fold = Validator.check_bool(kwargs["bn_fold"], "bn fold")
|
|
|
|
|
self.freeze_bn = Validator.check_non_negative_int(kwargs["freeze_bn"], "freeze bn")
|
|
|
|
|
self.weight_dtype = Validator.check_isinstance("weights dtype", kwargs["quant_dtype"][0], QuantDtype)
|
|
|
|
|
self.act_dtype = Validator.check_isinstance("activations dtype", kwargs["quant_dtype"][-1], QuantDtype)
|
|
|
|
|
self.weight_channel = Validator.check_bool(kwargs["per_channel"][0], "per channel")
|
|
|
|
|
self.act_channel = Validator.check_bool(kwargs["per_channel"][-1], "per channel")
|
|
|
|
|
self.weight_symmetric = Validator.check_bool(kwargs["symmetric"][0], "symmetric")
|
|
|
|
|
self.act_symmetric = Validator.check_bool(kwargs["symmetric"][-1], "symmetric")
|
|
|
|
|
self.weight_range = Validator.check_bool(kwargs["narrow_range"][0], "narrow range")
|
|
|
|
|
self.act_range = Validator.check_bool(kwargs["narrow_range"][-1], "narrow range")
|
|
|
|
|
self._convert_method_map = {quant.Conv2dBnAct: self._convert_conv,
|
|
|
|
|
quant.DenseBnAct: self._convert_dense}
|
|
|
|
|
self.quant_config = get_quant_config(quant_delay=kwargs["quant_delay"],
|
|
|
|
|
quant_dtype=kwargs["quant_dtype"],
|
|
|
|
|
per_channel=kwargs["per_channel"],
|
|
|
|
|
symmetric=kwargs["symmetric"],
|
|
|
|
|
narrow_range=kwargs["narrow_range"])
|
|
|
|
|
|
|
|
|
|
class QuantizationAwareTraining(Quantizer):
|
|
|
|
|
r"""
|
|
|
|
|
Quantizer for quantization aware training.
|
|
|
|
@ -175,6 +147,39 @@ class QuantizationAwareTraining(Quantizer):
|
|
|
|
|
The first element represents weights and the second element represents data flow. Default: (False, False)
|
|
|
|
|
optimize_option (OptimizeOption, list or tuple): Specifies the quant algorithm and options, currently only
|
|
|
|
|
support QAT. Default: OptimizeOption.QAT
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
>>> class Net(nn.Cell):
|
|
|
|
|
>>> def __init__(self, num_class=10, channel=1):
|
|
|
|
|
>>> super(LeNet5, self).__init__()
|
|
|
|
|
>>> self.type = "fusion"
|
|
|
|
|
>>> self.num_class = num_class
|
|
|
|
|
>>>
|
|
|
|
|
>>> # change `nn.Conv2d` to `nn.Conv2dBnAct`
|
|
|
|
|
>>> self.conv1 = nn.Conv2dBnAct(channel, 6, 5, pad_mode='valid', activation='relu')
|
|
|
|
|
>>> self.conv2 = nn.Conv2dBnAct(6, 16, 5, pad_mode='valid', activation='relu')
|
|
|
|
|
>>> # change `nn.Dense` to `nn.DenseBnAct`
|
|
|
|
|
>>> self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu')
|
|
|
|
|
>>> self.fc2 = nn.DenseBnAct(120, 84, activation='relu')
|
|
|
|
|
>>> self.fc3 = nn.DenseBnAct(84, self.num_class)
|
|
|
|
|
>>>
|
|
|
|
|
>>> self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
|
|
|
|
|
>>> self.flatten = nn.Flatten()
|
|
|
|
|
>>>
|
|
|
|
|
>>> def construct(self, x):
|
|
|
|
|
>>> x = self.conv1(x)
|
|
|
|
|
>>> x = self.max_pool2d(x)
|
|
|
|
|
>>> x = self.conv2(x)
|
|
|
|
|
>>> x = self.max_pool2d(x)
|
|
|
|
|
>>> x = self.flatten(x)
|
|
|
|
|
>>> x = self.fc1(x)
|
|
|
|
|
>>> x = self.fc2(x)
|
|
|
|
|
>>> x = self.fc3(x)
|
|
|
|
|
>>> return x
|
|
|
|
|
>>>
|
|
|
|
|
>>> net = Net()
|
|
|
|
|
>>> quantizer = QuantizationAwareTraining(bn_fold=False, per_channel=[True, False], symmetric=[True, False])
|
|
|
|
|
>>> net_qat = quantizer.quantize(net)
|
|
|
|
|
"""
|
|
|
|
|
__quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]
|
|
|
|
|
|
|
|
|
@ -230,6 +235,17 @@ class QuantizationAwareTraining(Quantizer):
|
|
|
|
|
return name_new
|
|
|
|
|
|
|
|
|
|
def quantize(self, network):
|
|
|
|
|
"""
|
|
|
|
|
Quant API to convert input network to a quantization aware training network
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
network (Cell): network to be quantized.
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
>>> net = Net()
|
|
|
|
|
>>> quantizer = QuantizationAwareTraining()
|
|
|
|
|
>>> net_qat = quantizer.quantize(net)
|
|
|
|
|
"""
|
|
|
|
|
support_device = ["Ascend", "GPU"]
|
|
|
|
|
if context.get_context('device_target') not in support_device:
|
|
|
|
|
raise KeyError("Unsupported {} device target.".format(context.get_context('device_target')))
|
|
|
|
|