diff --git a/mindspore/nn/layer/conv.py b/mindspore/nn/layer/conv.py
index 341f9fa74e..6b164dcab7 100644
--- a/mindspore/nn/layer/conv.py
+++ b/mindspore/nn/layer/conv.py
@@ -16,6 +16,7 @@ import numpy as np
 from mindspore import log as logger
+from mindspore import context
 from mindspore.ops import operations as P
 from mindspore.ops.primitive import constexpr
 from mindspore.common.parameter import Parameter
@@ -27,7 +28,7 @@ from mindspore._checkparam import check_bool, twice, check_int_positive
 from mindspore._extends import cell_attr_register
 from ..cell import Cell
 
-__all__ = ['Conv2d', 'Conv2dTranspose', 'DepthwiseConv2d', 'Conv1d', 'Conv1dTranspose']
+__all__ = ['Conv2d', 'Conv2dTranspose', 'Conv1d', 'Conv1dTranspose']
 
 
 class _Conv(Cell):
@@ -171,7 +172,8 @@
             be greater or equal to 1 and bounded by the height and width of the input. Default: 1.
         group (int): Split filter into groups, `in_channels` and `out_channels` should be
-            divisible by the number of groups. Default: 1.
+            divisible by the number of groups. If `group` is equal to `in_channels` and `out_channels`,
+            this 2D convolution layer can also be called a 2D depthwise convolution layer. Default: 1.
         has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
         weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
             It can be a Tensor, a string, an Initializer or a number. When a string is specified,
@@ -211,6 +213,7 @@ class Conv2d(_Conv):
                  bias_init='zeros'):
         kernel_size = twice(kernel_size)
         stride = twice(stride)
+        self._dilation = dilation
         dilation = twice(dilation)
         super(Conv2d, self).__init__(
             in_channels,
@@ -232,10 +235,23 @@ class Conv2d(_Conv):
                                stride=self.stride,
                                dilation=self.dilation,
                                group=self.group)
+        self._init_depthwise_conv2d()
         self.bias_add = P.BiasAdd()
-        if pad_mode not in ('valid', 'same', 'pad'):
-            raise ValueError('Attr \'pad_mode\' of \'Conv2d\' Op passed '
-                             + str(pad_mode) + ', should be one of values in \'valid\', \'same\', \'pad\'.')
+
+    def _init_depthwise_conv2d(self):
+        """Init depthwise conv2d op"""
+        if context.get_context("device_target") == "Ascend" and self.group > 1:
+            self.dilation = self._dilation
+            validator.check_integer('group', self.group, self.in_channels, Rel.EQ)
+            validator.check_integer('group', self.group, self.out_channels, Rel.EQ)
+            self.conv2d = P.DepthwiseConv2dNative(channel_multiplier=1,
+                                                  kernel_size=self.kernel_size,
+                                                  pad_mode=self.pad_mode,
+                                                  pad=self.padding,
+                                                  stride=self.stride,
+                                                  dilation=self.dilation)
+            weight_shape = [1, self.in_channels, *self.kernel_size]
+            self.weight = Parameter(initializer(self.weight_init, weight_shape), name='weight')
 
     def construct(self, x):
         output = self.conv2d(x, self.weight)
@@ -798,161 +814,3 @@ class Conv1dTranspose(_Conv):
                 self.weight_init,
                 self.bias_init)
         return s
-
-
-class DepthwiseConv2d(Cell):
-    r"""
-    2D depthwise convolution layer.
-
-    Applies a 2D depthwise convolution over an input tensor which is typically of shape
-    :math:`(N, C_{in}, H_{in}, W_{in})`, where :math:`N` is batch size and :math:`C_{in}` is channel number.
-    For each batch of shape :math:`(C_{in}, H_{in}, W_{in})`, the formula is defined as:
-
-    .. math::
-
-        out_j = \sum_{i=0}^{C_{in} - 1} ccor(W_{ij}, X_i) + b_j,
-
-    where :math:`ccor` is the cross correlation operator, :math:`C_{in}` is the input channel number, :math:`j` ranges
-    from :math:`0` to :math:`C_{out} - 1`, :math:`W_{ij}` corresponds to the :math:`i`-th channel of the :math:`j`-th
-    filter and :math:`out_{j}` corresponds to the :math:`j`-th channel of the output. :math:`W_{ij}` is a slice
-    of kernel and it has shape :math:`(\text{ks_h}, \text{ks_w})`, where :math:`\text{ks_h}` and
-    :math:`\text{ks_w}` are the height and width of the convolution kernel. The full kernel has shape
-    :math:`(C_{out}, C_{in} // \text{group}, \text{ks_h}, \text{ks_w})`, where group is the group number
-    to split the input in the channel dimension.
-
-    If the 'pad_mode' is set to be "valid", the output height and width will be
-    :math:`\left \lfloor{1 + \frac{H_{in} + 2 \times \text{padding} - \text{ks_h} -
-    (\text{ks_h} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` and
-    :math:`\left \lfloor{1 + \frac{W_{in} + 2 \times \text{padding} - \text{ks_w} -
-    (\text{ks_w} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` respectively.
-
-    The first introduction can be found in paper `Gradient Based Learning Applied to Document Recognition `_.
-
-    Args:
-        in_channels (int): The number of input channel :math:`C_{in}`.
-        out_channels (int): The number of output channel :math:`C_{out}`.
-        kernel_size (Union[int, tuple[int]]): The data type is int or a tuple of 2 integers. Specifies the height
-            and width of the 2D convolution window. Single int means the value is for both the height and the width of
-            the kernel. A tuple of 2 ints means the first value is for the height and the other is for the
-            width of the kernel.
-        stride (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
-            the height and width of movement are both strides, or a tuple of two int numbers that
-            represent height and width of movement respectively. Default: 1.
-        pad_mode (str): Specifies padding mode. The optional values are
-            "same", "valid", "pad". Default: "same".
-
-            - same: Adopts the way of completion. The height and width of the output will be the same as
-              the input. The total number of padding will be calculated in horizontal and vertical
-              directions and evenly distributed to top and bottom, left and right if possible. Otherwise, the
-              last extra padding will be done from the bottom and the right side. If this mode is set, `padding`
-              must be 0.
-
-            - valid: Adopts the way of discarding. The possible largest height and width of output will be returned
-              without padding. Extra pixels will be discarded. If this mode is set, `padding`
-              must be 0.
-
-            - pad: Implicit paddings on both sides of the input. The number of `padding` will be padded to the input
-              Tensor borders. `padding` should be greater than or equal to 0.
-
-        padding (Union[int, tuple[int]]): Implicit paddings on both sides of the input. If `padding` is one integer,
-            the paddings of top, bottom, left and right are the same, equal to padding. If `padding` is a tuple
-            with four integers, the paddings of top, bottom, left and right will be equal to padding[0],
-            padding[1], padding[2], and padding[3] accordingly. Default: 0.
-        dilation (Union[int, tuple[int]]): The data type is int or a tuple of 2 integers. Specifies the dilation rate
-            to use for dilated convolution. If set to be :math:`k > 1`, there will
-            be :math:`k - 1` pixels skipped for each sampling location. Its value should
-            be greater than or equal to 1 and bounded by the height and width of the
-            input. Default: 1.
-        group (int): Split filter into groups, `in_channels` and `out_channels` should be
-            divisible by the number of groups. If `group` is None, it will be set to the value of `in_channels`.
-        has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
-        weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
-            It can be a Tensor, a string, an Initializer or a number. When a string is specified,
-            values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well
-            as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones'
-            and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of
-            Initializer for more details. Default: 'normal'.
-        bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible
-            Initializer and string are the same as 'weight_init'. Refer to the values of
-            Initializer for more details. Default: 'zeros'.
-
-    Inputs:
-        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
-
-    Outputs:
-        Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
-
-    Examples:
-        >>> net = nn.DepthwiseConv2d(240, 240, 4, group=None, has_bias=False, weight_init='normal')
-        >>> input = Tensor(np.ones([1, 240, 1024, 640]), mindspore.float32)
-        >>> net(input).shape
-        (1, 240, 1024, 640)
-    """
-
-    def __init__(self,
-                 in_channels,
-                 out_channels,
-                 kernel_size,
-                 group,
-                 stride=1,
-                 pad_mode='same',
-                 padding=0,
-                 dilation=1,
-                 has_bias=False,
-                 weight_init='normal',
-                 bias_init='zeros'):
-        super(DepthwiseConv2d, self).__init__()
-        self.kernel_size = twice(kernel_size)
-        self.stride = twice(stride)
-        self.dilation = twice(dilation)
-        self.in_channels = check_int_positive(in_channels)
-        self.out_channels = check_int_positive(out_channels)
-        if group is None:
-            group = in_channels
-        validator.check_integer('group', group, in_channels, Rel.EQ)
-        validator.check_integer('group', group, out_channels, Rel.EQ)
-        validator.check_integer('group', group, 1, Rel.GE)
-        self.pad_mode = pad_mode
-        self.dilation = dilation
-        self.group = group
-        self.has_bias = has_bias
-        self.weight_init = weight_init
-        self.bias_init = bias_init
-        Validator.check_value_type('padding', padding, (int, tuple), self.cls_name)
-        if isinstance(padding, tuple):
-            Validator.check_integer('padding size', len(padding), 4, Rel.EQ, self.cls_name)
-        self.padding = padding
-        self.conv = P.DepthwiseConv2dNative(channel_multiplier=1,
-                                            kernel_size=self.kernel_size,
-                                            pad_mode=self.pad_mode,
-                                            pad=self.padding,
-                                            stride=self.stride,
-                                            dilation=self.dilation)
-        self.bias_add = P.BiasAdd()
-        weight_shape = [1, in_channels, *self.kernel_size]
-        self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
-        if check_bool(has_bias):
-            self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
-        else:
-            if bias_init != 'zeros':
-                logger.warning("value of `has_bias` is False, value of `bias_init` will be ignored.")
-            self.bias = None
-
-    def construct(self, x):
-        out = self.conv(x, self.weight)
-        if self.has_bias:
-            out = self.bias_add(out, self.bias)
-        return out
-
-    def extend_repr(self):
-        s = 'input_channels={}, output_channels={}, kernel_size={}, stride={}, ' \
-            'pad_mode={}, padding={}, dilation={}, group={}, ' \
-            'has_bias={}, weight_init={}, bias_init={}'.format(
-                self.in_channels, self.out_channels, self.kernel_size, self.stride,
-                self.pad_mode, self.padding, self.dilation, self.group,
-                self.has_bias, self.weight_init, self.bias_init)
-
-        if self.has_bias:
-            s += ', bias={}'.format(self.bias)
-        return s
diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py
index a70b9058d4..1c28b02195 100644
--- a/mindspore/nn/layer/quant.py
+++ b/mindspore/nn/layer/quant.py
@@ -122,31 +122,17 @@ class Conv2dBnAct(Cell):
                  after_fake=True):
         super(Conv2dBnAct, self).__init__()
 
-        if context.get_context('device_target') == "Ascend" and group > 1:
-            self.conv = conv.DepthwiseConv2d(in_channels,
-                                             out_channels,
-                                             kernel_size=kernel_size,
-                                             stride=stride,
-                                             pad_mode=pad_mode,
-                                             padding=padding,
-                                             dilation=dilation,
-                                             group=group,
-                                             has_bias=has_bias,
-                                             weight_init=weight_init,
-                                             bias_init=bias_init)
-        else:
-            self.conv = conv.Conv2d(in_channels,
-                                    out_channels,
-                                    kernel_size=kernel_size,
-                                    stride=stride,
-                                    pad_mode=pad_mode,
-                                    padding=padding,
-                                    dilation=dilation,
-                                    group=group,
-                                    has_bias=has_bias,
-                                    weight_init=weight_init,
-                                    bias_init=bias_init)
-
+        self.conv = conv.Conv2d(in_channels,
+                                out_channels,
+                                kernel_size=kernel_size,
+                                stride=stride,
+                                pad_mode=pad_mode,
+                                padding=padding,
+                                dilation=dilation,
+                                group=group,
+                                has_bias=has_bias,
+                                weight_init=weight_init,
+                                bias_init=bias_init)
 self.has_bn = validator.check_bool("has_bn", has_bn)
         self.has_act = activation is not None
         self.after_fake = after_fake
diff --git a/model_zoo/official/cv/mobilenetv2/src/mobilenetV2.py b/model_zoo/official/cv/mobilenetv2/src/mobilenetV2.py
index f0fb36aeb0..b214510350 100644
--- a/model_zoo/official/cv/mobilenetv2/src/mobilenetV2.py
+++ b/model_zoo/official/cv/mobilenetv2/src/mobilenetV2.py
@@ -17,8 +17,7 @@ import numpy as np
 import mindspore.nn as nn
 from mindspore.ops import operations as P
 from mindspore.ops.operations import TensorAdd
-from mindspore import Parameter, Tensor
-from mindspore.common.initializer import initializer
+from mindspore import Tensor
 
 __all__ = ['MobileNetV2', 'MobileNetV2Backbone', 'MobileNetV2Head', 'mobilenet_v2']
 
@@ -55,52 +54,6 @@ class GlobalAvgPooling(nn.Cell):
         return x
 
 
-class DepthwiseConv(nn.Cell):
-    """
-    Depthwise Convolution wrapper definition.
-
-    Args:
-        in_planes (int): Input channel.
-        kernel_size (int): Input kernel size.
-        stride (int): Stride size.
-        pad_mode (str): pad mode in (pad, same, valid).
-        channel_multiplier (int): Output channel multiplier.
-        has_bias (bool): has bias or not.
-
-    Returns:
-        Tensor, output tensor.
-
-    Examples:
-        >>> DepthwiseConv(16, 3, 1, 'pad', 1, channel_multiplier=1)
-    """
-
-    def __init__(self, in_planes, kernel_size, stride, pad_mode, pad, channel_multiplier=1, has_bias=False):
-        super(DepthwiseConv, self).__init__()
-        self.has_bias = has_bias
-        self.in_channels = in_planes
-        self.channel_multiplier = channel_multiplier
-        self.out_channels = in_planes * channel_multiplier
-        self.kernel_size = (kernel_size, kernel_size)
-        self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier,
-                                                      kernel_size=self.kernel_size,
-                                                      stride=stride, pad_mode=pad_mode, pad=pad)
-        self.bias_add = P.BiasAdd()
-        weight_shape = [channel_multiplier, in_planes, *self.kernel_size]
-        self.weight = Parameter(initializer('ones', weight_shape), name='weight')
-
-        if has_bias:
-            bias_shape = [channel_multiplier * in_planes]
-            self.bias = Parameter(initializer('zeros', bias_shape), name='bias')
-        else:
-            self.bias = None
-
-    def construct(self, x):
-        output = self.depthwise_conv(x, self.weight)
-        if self.has_bias:
-            output = self.bias_add(output, self.bias)
-        return output
-
-
 class ConvBNReLU(nn.Cell):
     """
     Convolution/Depthwise fused with Batchnorm and ReLU block definition.
@@ -122,16 +75,14 @@
     def __init__(self, platform, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
         super(ConvBNReLU, self).__init__()
         padding = (kernel_size - 1) // 2
+        in_channels = in_planes
+        out_channels = out_planes
         if groups == 1:
-            conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding)
+            conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_mode='pad', padding=padding)
         else:
-            if platform in ("CPU", "GPU"):
-                conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, group=in_planes, pad_mode='pad', \
-                                 padding=padding)
-            elif platform == "Ascend":
-                conv = DepthwiseConv(in_planes, kernel_size, stride, pad_mode='pad', pad=padding)
-            else:
-                raise ValueError("Unsupported Device, only support CPU, GPU and Ascend.")
+            out_channels = in_planes
+            conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_mode='pad',
+                             padding=padding, group=in_channels)
 
         layers = [conv, nn.BatchNorm2d(out_planes), nn.ReLU6()]
         self.features = nn.SequentialCell(layers)
@@ -188,6 +139,7 @@ class InvertedResidual(nn.Cell):
             return self.add(identity, x)
         return x
 
+
 class MobileNetV2Backbone(nn.Cell):
     """
     MobileNetV2 architecture.
@@ -258,7 +210,7 @@
         """
         self.init_parameters_data()
         for _, m in self.cells_and_names():
-            if isinstance(m, (nn.Conv2d, DepthwiseConv)):
+            if isinstance(m, nn.Conv2d):
                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                 m.weight.set_data(Tensor(np.random.normal(0, np.sqrt(2. / n),
                                                           m.weight.data.shape).astype("float32")))
@@ -275,6 +227,7 @@ class MobileNetV2Backbone(nn.Cell):
     def get_features(self):
         return self.features
 
+
 class MobileNetV2Head(nn.Cell):
     """
     MobileNetV2 architecture.
@@ -325,6 +278,7 @@ class MobileNetV2Head(nn.Cell):
     def get_head(self):
         return self.head
 
+
 class MobileNetV2(nn.Cell):
     """
     MobileNetV2 architecture.
@@ -353,6 +307,7 @@ class MobileNetV2(nn.Cell):
         x = self.head(x)
         return x
 
+
 class MobileNetV2Combine(nn.Cell):
     """
     MobileNetV2 architecture.
@@ -380,5 +335,6 @@
         x = self.head(x)
         return x
 
+
 def mobilenet_v2(backbone, head):
     return MobileNetV2Combine(backbone, head)
diff --git a/model_zoo/official/cv/ssd/src/ssd.py b/model_zoo/official/cv/ssd/src/ssd.py
index cee0d5817a..89d85887d6 100644
--- a/model_zoo/official/cv/ssd/src/ssd.py
+++ b/model_zoo/official/cv/ssd/src/ssd.py
@@ -18,14 +18,13 @@ import mindspore.common.dtype as mstype
 import mindspore as ms
 import mindspore.nn as nn
-from mindspore import Parameter, context, Tensor
+from mindspore import context, Tensor
 from mindspore.context import ParallelMode
 from mindspore.parallel._auto_parallel_context import auto_parallel_context
 from mindspore.communication.management import get_group_size
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore.ops import composite as C
-from mindspore.common.initializer import initializer
 
 
 def _make_divisible(v, divisor, min_value=None):
@@ -50,7 +49,10 @@ def _bn(channel):
 
 
 def _last_conv2d(in_channel, out_channel, kernel_size=3, stride=1, pad_mod='same', pad=0):
-    depthwise_conv = DepthwiseConv(in_channel, kernel_size, stride, pad_mode='same', pad=pad)
+    in_channels = in_channel
+    out_channels = in_channel
+    depthwise_conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_mode='same',
+                               padding=pad, group=in_channels)
     conv = _conv2d(in_channel, out_channel, kernel_size=1)
     return nn.SequentialCell([depthwise_conv, _bn(in_channel), nn.ReLU6(), conv])
 
@@ -75,11 +77,14 @@ class ConvBNReLU(nn.Cell):
     def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
         super(ConvBNReLU, self).__init__()
         padding = 0
+        in_channels = in_planes
+        out_channels = out_planes
         if groups == 1:
-            conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='same',
-                             padding=padding)
+            conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_mode='same', padding=padding)
         else:
-            conv = DepthwiseConv(in_planes, kernel_size, stride, pad_mode='same', pad=padding)
+            out_channels = in_planes
+            conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_mode='same',
+                             padding=padding, group=in_channels)
 
         layers = [conv, _bn(out_planes), nn.ReLU6()]
         self.features = nn.SequentialCell(layers)
 
@@ -88,52 +93,6 @@ class ConvBNReLU(nn.Cell):
         return output
 
 
-class DepthwiseConv(nn.Cell):
-    """
-    Depthwise Convolution wrapper definition.
-
-    Args:
-        in_planes (int): Input channel.
-        kernel_size (int): Input kernel size.
-        stride (int): Stride size.
-        pad_mode (str): pad mode in (pad, same, valid).
-        channel_multiplier (int): Output channel multiplier.
-        has_bias (bool): has bias or not.
-
-    Returns:
-        Tensor, output tensor.
-
-    Examples:
-        >>> DepthwiseConv(16, 3, 1, 'pad', 1, channel_multiplier=1)
-    """
-
-    def __init__(self, in_planes, kernel_size, stride, pad_mode, pad, channel_multiplier=1, has_bias=False):
-        super(DepthwiseConv, self).__init__()
-        self.has_bias = has_bias
-        self.in_channels = in_planes
-        self.channel_multiplier = channel_multiplier
-        self.out_channels = in_planes * channel_multiplier
-        self.kernel_size = (kernel_size, kernel_size)
-        self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier,
-                                                      kernel_size=self.kernel_size,
-                                                      stride=stride, pad_mode=pad_mode, pad=pad)
-        self.bias_add = P.BiasAdd()
-        weight_shape = [channel_multiplier, in_planes, *self.kernel_size]
-        self.weight = Parameter(initializer('ones', weight_shape), name='weight')
-
-        if has_bias:
-            bias_shape = [channel_multiplier * in_planes]
-            self.bias = Parameter(initializer('zeros', bias_shape), name='bias')
-        else:
-            self.bias = None
-
-    def construct(self, x):
-        output = self.depthwise_conv(x, self.weight)
-        if self.has_bias:
-            output = self.bias_add(output, self.bias)
-        return output
-
-
 class InvertedResidual(nn.Cell):
     """
     Residual block definition.
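
Migration note (editor's sketch, not part of the patch): after this change, call sites request a depthwise convolution through nn.Conv2d by setting `group` equal to the channel count, which is exactly the substitution the ssd.py and mobilenetV2.py hunks above perform. A minimal usage sketch, assuming the MindSpore APIs shown in this diff; the channel count and input shape are illustrative:

import numpy as np
import mindspore
import mindspore.nn as nn
from mindspore import Tensor

# Before this patch: net = nn.DepthwiseConv2d(240, 240, 4, group=None)
# After: group == in_channels == out_channels selects the depthwise path.
# On Ascend, Conv2d._init_depthwise_conv2d() swaps the underlying op for
# P.DepthwiseConv2dNative and stores the weight as [1, in_channels, *kernel_size];
# on CPU/GPU the grouped Conv2d kernel is used directly.
net = nn.Conv2d(240, 240, 4, group=240, has_bias=False, weight_init='normal')
x = Tensor(np.ones([1, 240, 32, 32]), mindspore.float32)
print(net(x).shape)  # (1, 240, 32, 32): pad_mode defaults to 'same' with stride 1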
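One design observation: the Ascend-only dispatch, together with its validation that `group` equals both `in_channels` and `out_channels`, now lives inside Conv2d._init_depthwise_conv2d, leaving a single public convolution entry point. The depthwise weight keeps the [1, in_channels, kh, kw] layout that the removed DepthwiseConv2d wrapper used, which should keep existing Ascend checkpoints loadable; that compatibility point is inferred from the weight shapes in this diff rather than stated by the patch.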