!8638 Add Conv2dBnFoldQuantOneConv function

From: @xiaoyisd
Reviewed-by: @sanjaychan,@chenfei52
Signed-off-by: @sanjaychan
pull/8638/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 27f274f566

@ -97,7 +97,7 @@ class Conv2dBnAct(Cell):
weight_init='normal',
bias_init='zeros',
has_bn=False,
momentum=0.9,
momentum=0.997,
eps=1e-5,
activation=None,
alpha=0.2,

@ -34,6 +34,7 @@ from ...ops.operations import _quant_ops as Q
__all__ = [
'FakeQuantWithMinMaxObserver',
'Conv2dBnFoldQuantOneConv',
'Conv2dBnFoldQuant',
'Conv2dBnWithoutFoldQuant',
'Conv2dQuant',
@ -330,6 +331,220 @@ QuantConfig = namedtuple("QuantConfig", ['weight', 'activation'])
quant_config_default = QuantConfig(weight=FakeQuantWithMinMaxObserver, activation=FakeQuantWithMinMaxObserver)
class Conv2dBnFoldQuantOneConv(Cell):
r"""
2D convolution with BatchNormal op folded construct.
This part is a more detailed overview of Conv2d op.
Args:
in_channels (int): The number of input channel :math:`C_{in}`.
out_channels (int): The number of output channel :math:`C_{out}`.
kernel_size (Union[int, tuple]): Specifies the height and width of the 2D convolution window.
stride (int): Specifies stride for all spatial dimensions with the same value.
pad_mode (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
padding (int): Implicit paddings on both sides of the input. Default: 0.
eps (float): Parameters for BatchNormal. Default: 1e-5.
momentum (float): Parameters for BatchNormal op. Default: 0.997.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
convolution kernel. Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
bias vector. Default: 'zeros'.
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
beta vector. Default: 'zeros'.
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
gamma vector. Default: 'ones'.
mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
mean vector. Default: 'zeros'.
var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
variance vector. Default: 'ones'.
fake (bool): Whether Conv2dBnFoldQuant Cell adds FakeQuantWithMinMaxObserver. Default: True.
quant_config (QuantConfig): Configs the oberser types and quant configs of weight and activation. Default:
both set to default FakeQuantWithMinMaxObserver.
quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.
Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
Outputs:
Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
Examples:
>>> qconfig = compression.quant.create_quant_config()
>>> conv2d_bnfold = nn.Conv2dBnFoldQuant(1, 6, kernel_size=(2, 2), stride=(1, 1), pad_mode="valid",
>>> quant_config=qconfig)
>>> input = Tensor(np.random.randint(-2, 2, (2, 1, 3, 3)), mindspore.float32)
>>> result = conv2d_bnfold(input)
>>> result.shape
(2, 6, 2, 2)
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
pad_mode='same',
padding=0,
dilation=1,
group=1,
eps=1e-5,
momentum=0.997,
has_bias=False,
weight_init='normal',
bias_init='zeros',
beta_init='zeros',
gamma_init='ones',
mean_init='zeros',
var_init='ones',
fake=True,
quant_config=quant_config_default,
quant_dtype=QuantDtype.INT8):
"""Initialize Conv2dBnFoldQuant layer"""
super(Conv2dBnFoldQuantOneConv, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = twice(kernel_size)
self.stride = twice(stride)
self.pad_mode = pad_mode
self.padding = padding
self.dilation = twice(dilation)
self.group = group
self.eps = eps
self.momentum = momentum
self.has_bias = has_bias
self.fake = fake
self.quant_config = quant_config
self.quant_dtype = quant_dtype
self.is_gpu = context.get_context('device_target') == "GPU"
self.is_Ascend = context.get_context('device_target') == "Ascend"
if context.get_context("enable_ge"):
self.is_ge_backend = True
else:
self.is_ge_backend = False
# initialize convolution op and Parameter
if context.get_context('device_target') == "Ascend" and group > 1:
Validator.check_equal_int(group, in_channels, 'group')
Validator.check_equal_int(group, out_channels, 'group')
self.conv = P.DepthwiseConv2dNative(channel_multiplier=1,
kernel_size=self.kernel_size,
pad_mode=pad_mode,
pad=padding,
stride=self.stride,
dilation=self.dilation)
weight_shape = [1, in_channels, *self.kernel_size]
channel_axis = 1
else:
self.conv = P.Conv2D(out_channel=out_channels,
kernel_size=self.kernel_size,
pad_mode=pad_mode,
pad=padding,
stride=self.stride,
dilation=self.dilation,
group=group)
weight_shape = [out_channels, in_channels // group, *self.kernel_size]
channel_axis = 0
self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
self.bias_add = P.BiasAdd()
if Validator.check_bool(has_bias):
self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
else:
self.bias = None
# initialize BatchNorm Parameter
self.gamma = Parameter(initializer(gamma_init, [out_channels]), name='gamma')
self.beta = Parameter(initializer(beta_init, [out_channels]), name='beta')
self.moving_mean = Parameter(initializer(mean_init, [out_channels]), name='moving_mean', requires_grad=False)
self.moving_variance = Parameter(initializer(var_init, [out_channels]), name='moving_variance',
requires_grad=False)
# initialize fake ops
self.fake_quant_weight = quant_config.weight(min_init=-6,
max_init=6,
ema=False,
channel_axis=channel_axis,
num_channels=out_channels,
quant_dtype=quant_dtype)
if self.is_graph_mode and (self.is_ge_backend or self.is_ascend):
self.bn_train = P.BatchNorm(is_training=True,
epsilon=self.eps)
elif self.is_gpu:
self.bn_train = P.FusedBatchNormEx(mode=1,
epsilon=self.eps,
momentum=self.momentum,
data_format=self.format)
else:
self.bn_train = P.FusedBatchNorm(mode=1,
epsilon=self.eps,
momentum=self.momentum)
self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps, data_format=self.format)
data_parallel_strategy = ((1,), (1,))
data_parallel_strategy_one = ((1,), ())
self.sub_mean = P.Sub().shard(data_parallel_strategy)
self.sub_var = P.Sub().shard(data_parallel_strategy)
self.mul_mean = P.Mul().shard(data_parallel_strategy_one)
self.mul_var = P.Mul().shard(data_parallel_strategy_one)
self.assign_sub_mean = P.AssignSub().shard(data_parallel_strategy)
self.assign_sub_var = P.AssignSub().shard(data_parallel_strategy)
self.one = Tensor(1, mstype.int32)
self.reshape = P.Reshape()
def extend_repr(self):
s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \
'pad_mode={}, padding={}, dilation={}, group={}, ' \
'fake={}, freeze_bn={}, momentum={}, quant_delay={}'.format(self.in_channels, self.out_channels,
self.kernel_size, self.stride,
self.pad_mode, self.padding, self.dilation,
self.group,
self.fake, self.freeze_bn, self.momentum,
self.fake_quant_weight.quant_delay)
return s
def construct(self, x):
running_std = P.Sqrt()(P.TensorAdd()(self.moving_variance, self.eps))
scale_factor = self.gamma / running_std
weight = self.weight * scale_factor
if self.channel_axis:
scale_factor = self.reshape(scale_factor, (1, -1, 1, 1))
else:
scale_factor = self.reshape(scale_factor, (-1, 1, 1, 1))
if self.fake:
weight = self.fake_quant_weight(weight)
conv = self.conv(x, weight)
scale_factor = self.reshape(scale_factor, (1, -1, 1, 1))
conv_orig = conv / scale_factor
if self.training:
if not self.is_gpu:
out, batch_mean, batch_var, _, _ = self.bn_train(conv_orig,
self.gamma,
self.beta,
None,
None)
mean_sub = self.sub_mean(self.moving_mean, batch_mean)
temp_mean = self.mul_mean(mean_sub, self.momentum)
mean_sub2 = self.sub_var(self.moving_variance, batch_var)
temp_variance = self.mul_var(mean_sub2, self.momentum)
out = F.depend(out, self.assign_sub_mean(self.moving_mean, temp_mean))
out = F.depend(out, self.assign_sub_var(self.moving_variance, temp_variance))
else:
out = self.bn_train(conv_orig,
self.gamma,
self.beta,
self.moving_mean,
self.moving_variance)[0]
else:
out = self.bn_infer(conv_orig,
self.gamma,
self.beta,
self.moving_mean,
self.moving_variance)[0]
return out
class Conv2dBnFoldQuant(Cell):
r"""
2D convolution with BatchNormal op folded construct.
@ -627,7 +842,7 @@ class Conv2dBnWithoutFoldQuant(Cell):
channel_axis=channel_axis,
num_channels=out_channels,
quant_dtype=quant_dtype)
self.batchnorm = BatchNorm2d(out_channels, eps=eps, momentum=momentum)
self.batchnorm = BatchNorm2d(out_channels, eps=eps, momentum=1-momentum)
def construct(self, x):
weight = self.fake_quant_weight(self.weight)

Loading…
Cancel
Save