|
|
|
@ -34,6 +34,7 @@ from ...ops.operations import _quant_ops as Q
|
|
|
|
|
|
|
|
|
|
__all__ = [
|
|
|
|
|
'FakeQuantWithMinMaxObserver',
|
|
|
|
|
'Conv2dBnFoldQuantOneConv',
|
|
|
|
|
'Conv2dBnFoldQuant',
|
|
|
|
|
'Conv2dBnWithoutFoldQuant',
|
|
|
|
|
'Conv2dQuant',
|
|
|
|
@ -330,6 +331,220 @@ QuantConfig = namedtuple("QuantConfig", ['weight', 'activation'])
|
|
|
|
|
quant_config_default = QuantConfig(weight=FakeQuantWithMinMaxObserver, activation=FakeQuantWithMinMaxObserver)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Conv2dBnFoldQuantOneConv(Cell):
|
|
|
|
|
r"""
|
|
|
|
|
2D convolution with BatchNormal op folded construct.
|
|
|
|
|
|
|
|
|
|
This part is a more detailed overview of Conv2d op.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
in_channels (int): The number of input channel :math:`C_{in}`.
|
|
|
|
|
out_channels (int): The number of output channel :math:`C_{out}`.
|
|
|
|
|
kernel_size (Union[int, tuple]): Specifies the height and width of the 2D convolution window.
|
|
|
|
|
stride (int): Specifies stride for all spatial dimensions with the same value.
|
|
|
|
|
pad_mode (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
|
|
|
|
|
padding (int): Implicit paddings on both sides of the input. Default: 0.
|
|
|
|
|
eps (float): Parameters for BatchNormal. Default: 1e-5.
|
|
|
|
|
momentum (float): Parameters for BatchNormal op. Default: 0.997.
|
|
|
|
|
has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
|
|
|
|
|
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
|
|
|
|
|
convolution kernel. Default: 'normal'.
|
|
|
|
|
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
|
|
|
|
|
bias vector. Default: 'zeros'.
|
|
|
|
|
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
|
|
|
|
|
beta vector. Default: 'zeros'.
|
|
|
|
|
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
|
|
|
|
|
gamma vector. Default: 'ones'.
|
|
|
|
|
mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
|
|
|
|
|
mean vector. Default: 'zeros'.
|
|
|
|
|
var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
|
|
|
|
|
variance vector. Default: 'ones'.
|
|
|
|
|
fake (bool): Whether Conv2dBnFoldQuant Cell adds FakeQuantWithMinMaxObserver. Default: True.
|
|
|
|
|
quant_config (QuantConfig): Configs the oberser types and quant configs of weight and activation. Default:
|
|
|
|
|
both set to default FakeQuantWithMinMaxObserver.
|
|
|
|
|
quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
|
|
|
|
|
|
|
|
|
|
Outputs:
|
|
|
|
|
Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
>>> qconfig = compression.quant.create_quant_config()
|
|
|
|
|
>>> conv2d_bnfold = nn.Conv2dBnFoldQuant(1, 6, kernel_size=(2, 2), stride=(1, 1), pad_mode="valid",
|
|
|
|
|
>>> quant_config=qconfig)
|
|
|
|
|
>>> input = Tensor(np.random.randint(-2, 2, (2, 1, 3, 3)), mindspore.float32)
|
|
|
|
|
>>> result = conv2d_bnfold(input)
|
|
|
|
|
>>> result.shape
|
|
|
|
|
(2, 6, 2, 2)
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self,
|
|
|
|
|
in_channels,
|
|
|
|
|
out_channels,
|
|
|
|
|
kernel_size,
|
|
|
|
|
stride=1,
|
|
|
|
|
pad_mode='same',
|
|
|
|
|
padding=0,
|
|
|
|
|
dilation=1,
|
|
|
|
|
group=1,
|
|
|
|
|
eps=1e-5,
|
|
|
|
|
momentum=0.997,
|
|
|
|
|
has_bias=False,
|
|
|
|
|
weight_init='normal',
|
|
|
|
|
bias_init='zeros',
|
|
|
|
|
beta_init='zeros',
|
|
|
|
|
gamma_init='ones',
|
|
|
|
|
mean_init='zeros',
|
|
|
|
|
var_init='ones',
|
|
|
|
|
fake=True,
|
|
|
|
|
quant_config=quant_config_default,
|
|
|
|
|
quant_dtype=QuantDtype.INT8):
|
|
|
|
|
"""Initialize Conv2dBnFoldQuant layer"""
|
|
|
|
|
super(Conv2dBnFoldQuantOneConv, self).__init__()
|
|
|
|
|
self.in_channels = in_channels
|
|
|
|
|
self.out_channels = out_channels
|
|
|
|
|
self.kernel_size = twice(kernel_size)
|
|
|
|
|
self.stride = twice(stride)
|
|
|
|
|
self.pad_mode = pad_mode
|
|
|
|
|
self.padding = padding
|
|
|
|
|
self.dilation = twice(dilation)
|
|
|
|
|
self.group = group
|
|
|
|
|
self.eps = eps
|
|
|
|
|
self.momentum = momentum
|
|
|
|
|
self.has_bias = has_bias
|
|
|
|
|
self.fake = fake
|
|
|
|
|
self.quant_config = quant_config
|
|
|
|
|
self.quant_dtype = quant_dtype
|
|
|
|
|
self.is_gpu = context.get_context('device_target') == "GPU"
|
|
|
|
|
self.is_Ascend = context.get_context('device_target') == "Ascend"
|
|
|
|
|
if context.get_context("enable_ge"):
|
|
|
|
|
self.is_ge_backend = True
|
|
|
|
|
else:
|
|
|
|
|
self.is_ge_backend = False
|
|
|
|
|
|
|
|
|
|
# initialize convolution op and Parameter
|
|
|
|
|
if context.get_context('device_target') == "Ascend" and group > 1:
|
|
|
|
|
Validator.check_equal_int(group, in_channels, 'group')
|
|
|
|
|
Validator.check_equal_int(group, out_channels, 'group')
|
|
|
|
|
self.conv = P.DepthwiseConv2dNative(channel_multiplier=1,
|
|
|
|
|
kernel_size=self.kernel_size,
|
|
|
|
|
pad_mode=pad_mode,
|
|
|
|
|
pad=padding,
|
|
|
|
|
stride=self.stride,
|
|
|
|
|
dilation=self.dilation)
|
|
|
|
|
weight_shape = [1, in_channels, *self.kernel_size]
|
|
|
|
|
channel_axis = 1
|
|
|
|
|
else:
|
|
|
|
|
self.conv = P.Conv2D(out_channel=out_channels,
|
|
|
|
|
kernel_size=self.kernel_size,
|
|
|
|
|
pad_mode=pad_mode,
|
|
|
|
|
pad=padding,
|
|
|
|
|
stride=self.stride,
|
|
|
|
|
dilation=self.dilation,
|
|
|
|
|
group=group)
|
|
|
|
|
weight_shape = [out_channels, in_channels // group, *self.kernel_size]
|
|
|
|
|
channel_axis = 0
|
|
|
|
|
self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
|
|
|
|
|
self.bias_add = P.BiasAdd()
|
|
|
|
|
if Validator.check_bool(has_bias):
|
|
|
|
|
self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
|
|
|
|
|
else:
|
|
|
|
|
self.bias = None
|
|
|
|
|
|
|
|
|
|
# initialize BatchNorm Parameter
|
|
|
|
|
self.gamma = Parameter(initializer(gamma_init, [out_channels]), name='gamma')
|
|
|
|
|
self.beta = Parameter(initializer(beta_init, [out_channels]), name='beta')
|
|
|
|
|
self.moving_mean = Parameter(initializer(mean_init, [out_channels]), name='moving_mean', requires_grad=False)
|
|
|
|
|
self.moving_variance = Parameter(initializer(var_init, [out_channels]), name='moving_variance',
|
|
|
|
|
requires_grad=False)
|
|
|
|
|
|
|
|
|
|
# initialize fake ops
|
|
|
|
|
self.fake_quant_weight = quant_config.weight(min_init=-6,
|
|
|
|
|
max_init=6,
|
|
|
|
|
ema=False,
|
|
|
|
|
channel_axis=channel_axis,
|
|
|
|
|
num_channels=out_channels,
|
|
|
|
|
quant_dtype=quant_dtype)
|
|
|
|
|
if self.is_graph_mode and (self.is_ge_backend or self.is_ascend):
|
|
|
|
|
self.bn_train = P.BatchNorm(is_training=True,
|
|
|
|
|
epsilon=self.eps)
|
|
|
|
|
elif self.is_gpu:
|
|
|
|
|
self.bn_train = P.FusedBatchNormEx(mode=1,
|
|
|
|
|
epsilon=self.eps,
|
|
|
|
|
momentum=self.momentum,
|
|
|
|
|
data_format=self.format)
|
|
|
|
|
else:
|
|
|
|
|
self.bn_train = P.FusedBatchNorm(mode=1,
|
|
|
|
|
epsilon=self.eps,
|
|
|
|
|
momentum=self.momentum)
|
|
|
|
|
self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps, data_format=self.format)
|
|
|
|
|
data_parallel_strategy = ((1,), (1,))
|
|
|
|
|
data_parallel_strategy_one = ((1,), ())
|
|
|
|
|
self.sub_mean = P.Sub().shard(data_parallel_strategy)
|
|
|
|
|
self.sub_var = P.Sub().shard(data_parallel_strategy)
|
|
|
|
|
self.mul_mean = P.Mul().shard(data_parallel_strategy_one)
|
|
|
|
|
self.mul_var = P.Mul().shard(data_parallel_strategy_one)
|
|
|
|
|
self.assign_sub_mean = P.AssignSub().shard(data_parallel_strategy)
|
|
|
|
|
self.assign_sub_var = P.AssignSub().shard(data_parallel_strategy)
|
|
|
|
|
self.one = Tensor(1, mstype.int32)
|
|
|
|
|
self.reshape = P.Reshape()
|
|
|
|
|
|
|
|
|
|
def extend_repr(self):
|
|
|
|
|
s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \
|
|
|
|
|
'pad_mode={}, padding={}, dilation={}, group={}, ' \
|
|
|
|
|
'fake={}, freeze_bn={}, momentum={}, quant_delay={}'.format(self.in_channels, self.out_channels,
|
|
|
|
|
self.kernel_size, self.stride,
|
|
|
|
|
self.pad_mode, self.padding, self.dilation,
|
|
|
|
|
self.group,
|
|
|
|
|
self.fake, self.freeze_bn, self.momentum,
|
|
|
|
|
self.fake_quant_weight.quant_delay)
|
|
|
|
|
return s
|
|
|
|
|
|
|
|
|
|
def construct(self, x):
|
|
|
|
|
running_std = P.Sqrt()(P.TensorAdd()(self.moving_variance, self.eps))
|
|
|
|
|
scale_factor = self.gamma / running_std
|
|
|
|
|
weight = self.weight * scale_factor
|
|
|
|
|
if self.channel_axis:
|
|
|
|
|
scale_factor = self.reshape(scale_factor, (1, -1, 1, 1))
|
|
|
|
|
else:
|
|
|
|
|
scale_factor = self.reshape(scale_factor, (-1, 1, 1, 1))
|
|
|
|
|
if self.fake:
|
|
|
|
|
weight = self.fake_quant_weight(weight)
|
|
|
|
|
conv = self.conv(x, weight)
|
|
|
|
|
scale_factor = self.reshape(scale_factor, (1, -1, 1, 1))
|
|
|
|
|
conv_orig = conv / scale_factor
|
|
|
|
|
if self.training:
|
|
|
|
|
if not self.is_gpu:
|
|
|
|
|
out, batch_mean, batch_var, _, _ = self.bn_train(conv_orig,
|
|
|
|
|
self.gamma,
|
|
|
|
|
self.beta,
|
|
|
|
|
None,
|
|
|
|
|
None)
|
|
|
|
|
|
|
|
|
|
mean_sub = self.sub_mean(self.moving_mean, batch_mean)
|
|
|
|
|
temp_mean = self.mul_mean(mean_sub, self.momentum)
|
|
|
|
|
mean_sub2 = self.sub_var(self.moving_variance, batch_var)
|
|
|
|
|
temp_variance = self.mul_var(mean_sub2, self.momentum)
|
|
|
|
|
out = F.depend(out, self.assign_sub_mean(self.moving_mean, temp_mean))
|
|
|
|
|
out = F.depend(out, self.assign_sub_var(self.moving_variance, temp_variance))
|
|
|
|
|
else:
|
|
|
|
|
out = self.bn_train(conv_orig,
|
|
|
|
|
self.gamma,
|
|
|
|
|
self.beta,
|
|
|
|
|
self.moving_mean,
|
|
|
|
|
self.moving_variance)[0]
|
|
|
|
|
else:
|
|
|
|
|
out = self.bn_infer(conv_orig,
|
|
|
|
|
self.gamma,
|
|
|
|
|
self.beta,
|
|
|
|
|
self.moving_mean,
|
|
|
|
|
self.moving_variance)[0]
|
|
|
|
|
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Conv2dBnFoldQuant(Cell):
|
|
|
|
|
r"""
|
|
|
|
|
2D convolution with BatchNormal op folded construct.
|
|
|
|
@ -627,7 +842,7 @@ class Conv2dBnWithoutFoldQuant(Cell):
|
|
|
|
|
channel_axis=channel_axis,
|
|
|
|
|
num_channels=out_channels,
|
|
|
|
|
quant_dtype=quant_dtype)
|
|
|
|
|
self.batchnorm = BatchNorm2d(out_channels, eps=eps, momentum=momentum)
|
|
|
|
|
self.batchnorm = BatchNorm2d(out_channels, eps=eps, momentum=1-momentum)
|
|
|
|
|
|
|
|
|
|
def construct(self, x):
|
|
|
|
|
weight = self.fake_quant_weight(self.weight)
|
|
|
|
|