@@ -22,7 +22,7 @@ from mindspore.common.parameter import Parameter
 from mindspore.common.initializer import initializer
 from mindspore.common.tensor import Tensor
 from mindspore._checkparam import check_int_positive, check_bool, twice
-from mindspore._checkparam import Validator as validator
+from mindspore._checkparam import Validator as validator, Rel
 from mindspore.nn.cell import Cell
 from mindspore.nn.layer.activation import get_activation
 import mindspore.context as context
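
Editor's note: the import now also pulls in Rel, which supplies the relation constants used by the new check_integer calls later in this diff. A minimal sketch of that call pattern, with a hypothetical helper name, is:

# Sketch only: the helper name check_depthwise_group is hypothetical; the
# check_integer calls mirror the ones added further down in this patch.
from mindspore._checkparam import Validator as validator, Rel

def check_depthwise_group(group, in_channels, out_channels):
    # For a depthwise convolution, group must equal both in_channels and out_channels.
    validator.check_integer('group', group, in_channels, Rel.EQ, 'Conv2dBatchNormQuant')
    validator.check_integer('group', group, out_channels, Rel.EQ, 'Conv2dBatchNormQuant')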
@@ -207,7 +207,7 @@ class FakeQuantWithMinMaxD(Cell):
 
 class FakeQuantWithMinMax(Cell):
     r"""
-    Aware Quantization training op. This OP provide Fake quantization observer function on data with min and max.
+    Aware Quantization op. This OP provides a fake quantization observer function on data with min and max.
 
     Args:
         min_init (int, list): The dimension of channel or 1(layer). Default: -6.
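
Editor's note: to make "fake quantization observer" concrete, here is a rough NumPy sketch of what such an op computes conceptually (quantize into the [min, max] range with 2**num_bits levels, then dequantize). This is only an illustration of the idea, not MindSpore's kernel.

import numpy as np

def fake_quant(x, min_val, max_val, num_bits=8, narrow_range=False):
    """Conceptual fake quantization: quantize x into [min_val, max_val] with
    2**num_bits levels, then dequantize. A sketch, not the actual operator."""
    quant_min = 1 if narrow_range else 0
    quant_max = 2 ** num_bits - 1
    scale = (max_val - min_val) / (quant_max - quant_min)
    zero_point = quant_min - int(np.round(min_val / scale))
    q = np.clip(np.round(x / scale) + zero_point, quant_min, quant_max)
    return (q - zero_point) * scale

x = np.array([-8.0, -1.0, 0.0, 0.5, 7.0], dtype=np.float32)
print(fake_quant(x, min_val=-6.0, max_val=6.0))  # values outside [-6, 6] saturate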
@@ -243,8 +243,7 @@ class FakeQuantWithMinMax(Cell):
                 out_channels=1,
                 quant_delay=0,
                 symmetric=False,
-                 narrow_range=False,
-                 training=True):
+                 narrow_range=False):
        """init FakeQuantWithMinMax layer"""
        super(FakeQuantWithMinMax, self).__init__()
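
Editor's note: this hunk drops the `training` parameter from the constructor. A minimal construction sketch using only the keyword arguments visible in this diff (other parameters keep their defaults):

# Sketch only: assumes FakeQuantWithMinMax from this module is in scope.
fake_quant_layer = FakeQuantWithMinMax(out_channels=1,
                                       quant_delay=0,
                                       symmetric=False,
                                       narrow_range=False)
# Before this change a training=True argument was also accepted; after it,
# train/eval behaviour is presumably governed by the Cell's own mode handling.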
@@ -258,7 +257,6 @@ class FakeQuantWithMinMax(Cell):
         self.quant_delay = quant_delay
         self.symmetric = symmetric
         self.narrow_range = narrow_range
-        self.training = training
 
         if per_channel:
             min_array = np.array([self.min_init for i in range(0, self.out_channels)]).astype(np.float32)
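
Editor's note: in the per-channel branch, the context line builds one min value per output channel. An equivalent, slightly more direct construction (an editorial suggestion, not part of the patch) would be np.full:

import numpy as np

out_channels, min_init = 4, -6
# List-comprehension form used in the patch:
min_array = np.array([min_init for _ in range(out_channels)]).astype(np.float32)
# Equivalent and a bit more direct:
min_array_alt = np.full(out_channels, min_init, dtype=np.float32)
assert np.array_equal(min_array, min_array_alt)  # both are [-6., -6., -6., -6.]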
@@ -422,11 +420,13 @@ class Conv2dBatchNormQuant(Cell):
         self.per_channel = per_channel
         self.symmetric = symmetric
         self.narrow_range = narrow_range
+        self.channel_axis = int(group > 1)
+        self.is_gpu = context.get_context('device_target') == "GPU"
 
         # initialize convolution op and Parameter
         if context.get_context('device_target') == "Ascend" and group > 1:
-            validator.check_integer('group', group, 'in_channels', in_channels, 'Conv2dBatchNormQuant')
-            validator.check_integer('group', group, 'in_channels', out_channels, 'Conv2dBatchNormQuant')
+            validator.check_integer('group', group, in_channels, Rel.EQ, 'Conv2dBatchNormQuant')
+            validator.check_integer('group', group, out_channels, Rel.EQ, 'Conv2dBatchNormQuant')
             self.conv = P.DepthwiseConv2dNative(channel_multiplier=1,
                                                 kernel_size=self.kernel_size,
                                                 pad_mode=pad_mode,
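
Editor's note: the added line channel_axis = int(group > 1) selects axis 0 of the weight for a standard convolution and axis 1 for the grouped/depthwise case. The sketch below illustrates why, assuming the usual layouts of a standard Conv2D weight (out_channels, in_channels, kh, kw) and a DepthwiseConv2dNative weight (channel_multiplier, in_channels, kh, kw); these layouts are an assumption drawn from the change, not stated in the patch.

import numpy as np

kh = kw = 3
in_channels = out_channels = 8

conv_weight = np.zeros((out_channels, in_channels, kh, kw), dtype=np.float32)  # channel axis 0
dw_weight = np.zeros((1, in_channels, kh, kw), dtype=np.float32)               # channel axis 1

for group, weight in ((1, conv_weight), (in_channels, dw_weight)):
    channel_axis = int(group > 1)                  # mirrors the line added in this hunk
    print(group, channel_axis, weight.shape[channel_axis])  # 8 channels either way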
@@ -472,7 +472,7 @@ class Conv2dBatchNormQuant(Cell):
                                              symmetric=symmetric,
                                              narrow_range=narrow_range)
         self.batchnorm_fold = BatchNormFoldCell(epsilon=eps, momentum=momentum, freeze_bn=freeze_bn)
-        self.correct_mul = P.CorrectionMul()
+        self.correct_mul = P.CorrectionMul(self.channel_axis)
         if context.get_context('device_target') == "Ascend":
             self.batchnorm_fold2_train = P.BatchNormFold2_D(freeze_bn=freeze_bn)
             self.batchnorm_fold2_infer = P.BatchNormFold2_D(freeze_bn=0)
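
Editor's note: CorrectionMul now receives the channel axis computed above, so the batch-norm-fold correction is broadcast along the right dimension of the weight. The sketch below shows what a correction-mul style broadcast does, under the assumption that the op scales the weight per channel by a batch_std / running_std factor; the real P.CorrectionMul semantics are not shown in this patch.

import numpy as np

def correction_mul_sketch(weight, batch_std, running_std, channel_axis):
    """Rough sketch of a BN-fold correction: scale the weight per channel by
    batch_std / running_std along channel_axis. Illustration only."""
    factor = batch_std / running_std
    shape = [1] * weight.ndim
    shape[channel_axis] = -1                   # broadcast the factor along channel_axis
    return weight * factor.reshape(shape)

w = np.ones((4, 8, 3, 3), dtype=np.float32)    # (out_channels, in_channels, kh, kw)
out = correction_mul_sketch(w, np.full(4, 2.0), np.full(4, 1.0), channel_axis=0)
print(out.shape, out[0, 0, 0, 0], out[1, 0, 0, 0])  # (4, 8, 3, 3) 2.0 2.0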