!5916 unify Conv2d and DepthwiseConv2d

Merge pull request !5916 from caozhou/unify_conv2d_depthwise2d
Branch: pull/5916/MERGE
Committed by mindspore-ci-bot via Gitee
Commit c455c344a9

@@ -16,6 +16,7 @@
 import numpy as np
 from mindspore import log as logger
+from mindspore import context
 from mindspore.ops import operations as P
 from mindspore.ops.primitive import constexpr
 from mindspore.common.parameter import Parameter
@@ -27,7 +28,7 @@ from mindspore._checkparam import check_bool, twice, check_int_positive
 from mindspore._extends import cell_attr_register
 from ..cell import Cell
-__all__ = ['Conv2d', 'Conv2dTranspose', 'DepthwiseConv2d', 'Conv1d', 'Conv1dTranspose']
+__all__ = ['Conv2d', 'Conv2dTranspose', 'Conv1d', 'Conv1dTranspose']

 class _Conv(Cell):
@@ -171,7 +172,8 @@ class Conv2d(_Conv):
             be greater or equal to 1 and bounded by the height and width of the
             input. Default: 1.
         group (int): Split filter into groups, `in_channels` and `out_channels` should be
-            divisible by the number of groups. Default: 1.
+            divisible by the number of groups. If the group is equal to `in_channels` and `out_channels`,
+            this 2D convolution layer also can be called 2D depthwise convolution layer. Default: 1.
         has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
         weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
             It can be a Tensor, a string, an Initializer or a number. When a string is specified,
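With this change a depthwise convolution is expressed through the ordinary Conv2d interface by setting `group` equal to both channel counts. A minimal usage sketch of the unified API (the 240-channel figures are taken from the removed DepthwiseConv2d example further down; the input size is illustrative):

import numpy as np
import mindspore
import mindspore.nn as nn
from mindspore import Tensor

# group == in_channels == out_channels turns Conv2d into a depthwise convolution
net = nn.Conv2d(240, 240, kernel_size=4, group=240, has_bias=False, weight_init='normal')
x = Tensor(np.ones([1, 240, 64, 64]), mindspore.float32)
print(net(x).shape)  # (1, 240, 64, 64) with the default pad_mode='same'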
@@ -211,6 +213,7 @@ class Conv2d(_Conv):
                  bias_init='zeros'):
         kernel_size = twice(kernel_size)
         stride = twice(stride)
+        self._dilation = dilation
         dilation = twice(dilation)
         super(Conv2d, self).__init__(
             in_channels,
@@ -232,10 +235,23 @@
                               stride=self.stride,
                               dilation=self.dilation,
                               group=self.group)
+        self._init_depthwise_conv2d()
         self.bias_add = P.BiasAdd()
-        if pad_mode not in ('valid', 'same', 'pad'):
-            raise ValueError('Attr \'pad_mode\' of \'Conv2d\' Op passed '
-                             + str(pad_mode) + ', should be one of values in \'valid\', \'same\', \'pad\'.')
+
+    def _init_depthwise_conv2d(self):
+        """Init depthwise conv2d op"""
+        if context.get_context("device_target") == "Ascend" and self.group > 1:
+            self.dilation = self._dilation
+            validator.check_integer('group', self.group, self.in_channels, Rel.EQ)
+            validator.check_integer('group', self.group, self.out_channels, Rel.EQ)
+            self.conv2d = P.DepthwiseConv2dNative(channel_multiplier=1,
+                                                  kernel_size=self.kernel_size,
+                                                  pad_mode=self.pad_mode,
+                                                  pad=self.padding,
+                                                  stride=self.stride,
+                                                  dilation=self.dilation)
+            weight_shape = [1, self.in_channels, *self.kernel_size]
+            self.weight = Parameter(initializer(self.weight_init, weight_shape), name='weight')

     def construct(self, x):
         output = self.conv2d(x, self.weight)
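One subtlety in `_init_depthwise_conv2d` above: when it swaps in `P.DepthwiseConv2dNative` it also re-creates the kernel, because the depthwise primitive expects a different weight layout than the grouped convolution. A small sketch of the two layouts (the grouped Conv2D layout is stated here as an assumption; only the depthwise shape appears in the diff):

in_channels, out_channels, group = 240, 240, 240
kernel_size = (4, 4)

# regular/grouped conv kernel layout (assumed): (out_channels, in_channels // group, kh, kw)
grouped_weight_shape = [out_channels, in_channels // group, *kernel_size]   # [240, 1, 4, 4]

# DepthwiseConv2dNative kernel layout, as built in _init_depthwise_conv2d:
depthwise_weight_shape = [1, in_channels, *kernel_size]                     # [1, 240, 4, 4]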
@@ -798,161 +814,3 @@ class Conv1dTranspose(_Conv):
                     self.weight_init,
                     self.bias_init)
         return s
-
-
-class DepthwiseConv2d(Cell):
-    r"""
-    2D depthwise convolution layer.
-
-    Applies a 2D depthwise convolution over an input tensor which is typically of shape
-    :math:`(N, C_{in}, H_{in}, W_{in})`, where :math:`N` is batch size and :math:`C_{in}` is channel number.
-    For each batch of shape :math:`(C_{in}, H_{in}, W_{in})`, the formula is defined as:
-
-    .. math::
-
-        out_j = \sum_{i=0}^{C_{in} - 1} ccor(W_{ij}, X_i) + b_j,
-
-    where :math:`ccor` is the cross correlation operator, :math:`C_{in}` is the input channel number, :math:`j` ranges
-    from :math:`0` to :math:`C_{out} - 1`, :math:`W_{ij}` corresponds to the :math:`i`-th channel of the :math:`j`-th
-    filter and :math:`out_{j}` corresponds to the :math:`j`-th channel of the output. :math:`W_{ij}` is a slice
-    of kernel and it has shape :math:`(\text{ks_h}, \text{ks_w})`, where :math:`\text{ks_h}` and
-    :math:`\text{ks_w}` are the height and width of the convolution kernel. The full kernel has shape
-    :math:`(C_{out}, C_{in} // \text{group}, \text{ks_h}, \text{ks_w})`, where group is the group number
-    to split the input in the channel dimension.
-
-    If the 'pad_mode' is set to be "valid", the output height and width will be
-    :math:`\left \lfloor{1 + \frac{H_{in} + 2 \times \text{padding} - \text{ks_h} -
-    (\text{ks_h} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` and
-    :math:`\left \lfloor{1 + \frac{W_{in} + 2 \times \text{padding} - \text{ks_w} -
-    (\text{ks_w} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` respectively.
-
-    The first introduction can be found in paper `Gradient Based Learning Applied to Document Recognition
-    <http://vision.stanford.edu/cs598_spring07/papers/Lecun98.pdf>`_.
-
-    Args:
-        in_channels (int): The number of input channel :math:`C_{in}`.
-        out_channels (int): The number of output channel :math:`C_{out}`.
-        kernel_size (Union[int, tuple[int]]): The data type is int or a tuple of 2 integers. Specifies the height
-            and width of the 2D convolution window. Single int means the value is for both the height and the width of
-            the kernel. A tuple of 2 ints means the first value is for the height and the other is for the
-            width of the kernel.
-        stride (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
-            the height and width of movement are both strides, or a tuple of two int numbers that
-            represent height and width of movement respectively. Default: 1.
-        pad_mode (str): Specifies padding mode. The optional values are
-            "same", "valid", "pad". Default: "same".
-
-            - same: Adopts the way of completion. The height and width of the output will be the same as
-              the input. The total number of padding will be calculated in horizontal and vertical
-              directions and evenly distributed to top and bottom, left and right if possible. Otherwise, the
-              last extra padding will be done from the bottom and the right side. If this mode is set, `padding`
-              must be 0.
-            - valid: Adopts the way of discarding. The possible largest height and width of output will be returned
-              without padding. Extra pixels will be discarded. If this mode is set, `padding`
-              must be 0.
-            - pad: Implicit paddings on both sides of the input. The number of `padding` will be padded to the input
-              Tensor borders. `padding` should be greater than or equal to 0.
-
-        padding (Union[int, tuple[int]]): Implicit paddings on both sides of the input. If `padding` is one integer,
-            the paddings of top, bottom, left and right are the same, equal to padding. If `padding` is a tuple
-            with four integers, the paddings of top, bottom, left and right will be equal to padding[0],
-            padding[1], padding[2], and padding[3] accordingly. Default: 0.
-        dilation (Union[int, tuple[int]]): The data type is int or a tuple of 2 integers. Specifies the dilation rate
-            to use for dilated convolution. If set to be :math:`k > 1`, there will
-            be :math:`k - 1` pixels skipped for each sampling location. Its value should
-            be greater than or equal to 1 and bounded by the height and width of the
-            input. Default: 1.
-        group (int): Split filter into groups, `in_channels` and `out_channels` should be
-            divisible by the number of groups. If `group` is None, it will be set as the value of `in_channels`.
-        has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
-        weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
-            It can be a Tensor, a string, an Initializer or a number. When a string is specified,
-            values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well
-            as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones'
-            and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of
-            Initializer for more details. Default: 'normal'.
-        bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible
-            Initializer and string are the same as 'weight_init'. Refer to the values of
-            Initializer for more details. Default: 'zeros'.
-
-    Inputs:
-        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
-
-    Outputs:
-        Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
-
-    Examples:
-        >>> net = nn.DepthwiseConv2d(240, 240, 4, group=None, has_bias=False, weight_init='normal')
-        >>> input = Tensor(np.ones([1, 240, 1024, 640]), mindspore.float32)
-        >>> net(input).shape
-        (1, 240, 1024, 640)
-    """
-
-    def __init__(self,
-                 in_channels,
-                 out_channels,
-                 kernel_size,
-                 group,
-                 stride=1,
-                 pad_mode='same',
-                 padding=0,
-                 dilation=1,
-                 has_bias=False,
-                 weight_init='normal',
-                 bias_init='zeros'):
-        super(DepthwiseConv2d, self).__init__()
-        self.kernel_size = twice(kernel_size)
-        self.stride = twice(stride)
-        self.dilation = twice(dilation)
-        self.in_channels = check_int_positive(in_channels)
-        self.out_channels = check_int_positive(out_channels)
-        if group is None:
-            group = in_channels
-        validator.check_integer('group', group, in_channels, Rel.EQ)
-        validator.check_integer('group', group, out_channels, Rel.EQ)
-        validator.check_integer('group', group, 1, Rel.GE)
-        self.pad_mode = pad_mode
-        self.dilation = dilation
-        self.group = group
-        self.has_bias = has_bias
-        self.weight_init = weight_init
-        self.bias_init = bias_init
-        Validator.check_value_type('padding', padding, (int, tuple), self.cls_name)
-        if isinstance(padding, tuple):
-            Validator.check_integer('padding size', len(padding), 4, Rel.EQ, self.cls_name)
-        self.padding = padding
-        self.conv = P.DepthwiseConv2dNative(channel_multiplier=1,
-                                            kernel_size=self.kernel_size,
-                                            pad_mode=self.pad_mode,
-                                            pad=self.padding,
-                                            stride=self.stride,
-                                            dilation=self.dilation)
-        self.bias_add = P.BiasAdd()
-        weight_shape = [1, in_channels, *self.kernel_size]
-        self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
-        if check_bool(has_bias):
-            self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
-        else:
-            if bias_init != 'zeros':
-                logger.warning("value of `has_bias` is False, value of `bias_init` will be ignored.")
-            self.bias = None
-
-    def construct(self, x):
-        out = self.conv(x, self.weight)
-        if self.has_bias:
-            out = self.bias_add(out, self.bias)
-        return out
-
-    def extend_repr(self):
-        s = 'input_channels={}, output_channels={}, kernel_size={}, stride={}, ' \
-            'pad_mode={}, padding={}, dilation={}, group={}, ' \
-            'has_bias={}, weight_init={}, bias_init={}'.format(
-                self.in_channels, self.out_channels, self.kernel_size, self.stride,
-                self.pad_mode, self.padding, self.dilation, self.group,
-                self.has_bias, self.weight_init, self.bias_init)
-        if self.has_bias:
-            s += ', bias={}'.format(self.bias)
-        return s
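For code that constructed the removed class directly, migration is a one-line change. A hedged sketch based on the removed docstring example above:

import mindspore.nn as nn

# before this PR: nn.DepthwiseConv2d(240, 240, 4, group=None, has_bias=False, weight_init='normal')
# after this PR:  the same layer through the unified Conv2d, with group spelled out explicitly
net = nn.Conv2d(240, 240, 4, group=240, has_bias=False, weight_init='normal')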

@@ -122,31 +122,17 @@ class Conv2dBnAct(Cell):
                  after_fake=True):
         super(Conv2dBnAct, self).__init__()
-        if context.get_context('device_target') == "Ascend" and group > 1:
-            self.conv = conv.DepthwiseConv2d(in_channels,
-                                             out_channels,
-                                             kernel_size=kernel_size,
-                                             stride=stride,
-                                             pad_mode=pad_mode,
-                                             padding=padding,
-                                             dilation=dilation,
-                                             group=group,
-                                             has_bias=has_bias,
-                                             weight_init=weight_init,
-                                             bias_init=bias_init)
-        else:
-            self.conv = conv.Conv2d(in_channels,
-                                    out_channels,
-                                    kernel_size=kernel_size,
-                                    stride=stride,
-                                    pad_mode=pad_mode,
-                                    padding=padding,
-                                    dilation=dilation,
-                                    group=group,
-                                    has_bias=has_bias,
-                                    weight_init=weight_init,
-                                    bias_init=bias_init)
+        self.conv = conv.Conv2d(in_channels,
+                                out_channels,
+                                kernel_size=kernel_size,
+                                stride=stride,
+                                pad_mode=pad_mode,
+                                padding=padding,
+                                dilation=dilation,
+                                group=group,
+                                has_bias=has_bias,
+                                weight_init=weight_init,
+                                bias_init=bias_init)
         self.has_bn = validator.check_bool("has_bn", has_bn)
         self.has_act = activation is not None
         self.after_fake = after_fake
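Because Conv2d now handles the Ascend depthwise case internally, Conv2dBnAct no longer needs to branch on the device target. A minimal, hedged usage sketch (the argument values and the 'relu' activation string are assumptions, not taken from this PR):

import mindspore.nn as nn

# A depthwise 3x3 + BN + ReLU block, written identically for CPU, GPU and Ascend
block = nn.Conv2dBnAct(32, 32, kernel_size=3, group=32, has_bn=True, activation='relu')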

@@ -17,8 +17,7 @@
 import numpy as np
 import mindspore.nn as nn
 from mindspore.ops import operations as P
 from mindspore.ops.operations import TensorAdd
-from mindspore import Parameter, Tensor
-from mindspore.common.initializer import initializer
+from mindspore import Tensor

 __all__ = ['MobileNetV2', 'MobileNetV2Backbone', 'MobileNetV2Head', 'mobilenet_v2']
@@ -55,52 +54,6 @@ class GlobalAvgPooling(nn.Cell):
         return x

-
-class DepthwiseConv(nn.Cell):
-    """
-    Depthwise Convolution wrapper definition.
-
-    Args:
-        in_planes (int): Input channel.
-        kernel_size (int): Input kernel size.
-        stride (int): Stride size.
-        pad_mode (str): pad mode in (pad, same, valid)
-        channel_multiplier (int): Output channel multiplier
-        has_bias (bool): has bias or not
-
-    Returns:
-        Tensor, output tensor.
-
-    Examples:
-        >>> DepthwiseConv(16, 3, 1, 'pad', 1, channel_multiplier=1)
-    """
-
-    def __init__(self, in_planes, kernel_size, stride, pad_mode, pad, channel_multiplier=1, has_bias=False):
-        super(DepthwiseConv, self).__init__()
-        self.has_bias = has_bias
-        self.in_channels = in_planes
-        self.channel_multiplier = channel_multiplier
-        self.out_channels = in_planes * channel_multiplier
-        self.kernel_size = (kernel_size, kernel_size)
-        self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier,
-                                                      kernel_size=self.kernel_size,
-                                                      stride=stride, pad_mode=pad_mode, pad=pad)
-        self.bias_add = P.BiasAdd()
-        weight_shape = [channel_multiplier, in_planes, *self.kernel_size]
-        self.weight = Parameter(initializer('ones', weight_shape), name='weight')
-        if has_bias:
-            bias_shape = [channel_multiplier * in_planes]
-            self.bias = Parameter(initializer('zeros', bias_shape), name='bias')
-        else:
-            self.bias = None
-
-    def construct(self, x):
-        output = self.depthwise_conv(x, self.weight)
-        if self.has_bias:
-            output = self.bias_add(output, self.bias)
-        return output
-

 class ConvBNReLU(nn.Cell):
     """
     Convolution/Depthwise fused with Batchnorm and ReLU block definition.
@@ -122,16 +75,14 @@ class ConvBNReLU(nn.Cell):
     def __init__(self, platform, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
         super(ConvBNReLU, self).__init__()
         padding = (kernel_size - 1) // 2
+        in_channels = in_planes
+        out_channels = out_planes
         if groups == 1:
-            conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding)
+            conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_mode='pad', padding=padding)
         else:
-            if platform in ("CPU", "GPU"):
-                conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, group=in_planes, pad_mode='pad', \
-                                 padding=padding)
-            elif platform == "Ascend":
-                conv = DepthwiseConv(in_planes, kernel_size, stride, pad_mode='pad', pad=padding)
-            else:
-                raise ValueError("Unsupported Device, only support CPU, GPU and Ascend.")
+            out_channels = in_planes
+            conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_mode='pad',
+                             padding=padding, group=in_channels)

         layers = [conv, nn.BatchNorm2d(out_planes), nn.ReLU6()]
         self.features = nn.SequentialCell(layers)
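The grouped branch above is MobileNetV2's standard depthwise 3x3 block; a hedged sketch of the equivalent standalone layer it now builds (the channel and stride values are illustrative):

import mindspore.nn as nn

in_planes, kernel_size, stride = 32, 3, 1
padding = (kernel_size - 1) // 2
# depthwise 3x3: out_channels == in_channels and group == in_channels
depthwise = nn.Conv2d(in_planes, in_planes, kernel_size, stride,
                      pad_mode='pad', padding=padding, group=in_planes)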
@@ -188,6 +139,7 @@ class InvertedResidual(nn.Cell):
             return self.add(identity, x)
         return x

+
 class MobileNetV2Backbone(nn.Cell):
     """
     MobileNetV2 architecture.
@@ -258,7 +210,7 @@
         """
         self.init_parameters_data()
         for _, m in self.cells_and_names():
-            if isinstance(m, (nn.Conv2d, DepthwiseConv)):
+            if isinstance(m, nn.Conv2d):
                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                 m.weight.set_data(Tensor(np.random.normal(0, np.sqrt(2. / n),
                                                           m.weight.data.shape).astype("float32")))
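The former DepthwiseConv cells are now plain nn.Conv2d instances, so the single isinstance(m, nn.Conv2d) check still reaches them. The statement itself is a fan-out He/Kaiming normal initialization; a small numeric sketch (shape values are illustrative):

import numpy as np

kh, kw, out_channels = 3, 3, 32
n = kh * kw * out_channels
std = np.sqrt(2. / n)  # He/Kaiming fan-out standard deviation
weight = np.random.normal(0, std, (out_channels, 1, kh, kw)).astype("float32")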
@@ -275,6 +227,7 @@ class MobileNetV2Backbone(nn.Cell):
     def get_features(self):
         return self.features

+
 class MobileNetV2Head(nn.Cell):
     """
     MobileNetV2 architecture.
@@ -325,6 +278,7 @@ class MobileNetV2Head(nn.Cell):
     def get_head(self):
         return self.head

+
 class MobileNetV2(nn.Cell):
     """
     MobileNetV2 architecture.
@@ -353,6 +307,7 @@ class MobileNetV2(nn.Cell):
         x = self.head(x)
         return x

+
 class MobileNetV2Combine(nn.Cell):
     """
     MobileNetV2 architecture.
@@ -380,5 +335,6 @@ class MobileNetV2Combine(nn.Cell):
         x = self.head(x)
         return x

+
 def mobilenet_v2(backbone, head):
     return MobileNetV2Combine(backbone, head)

@@ -18,14 +18,13 @@
 import mindspore.common.dtype as mstype
 import mindspore as ms
 import mindspore.nn as nn
-from mindspore import Parameter, context, Tensor
+from mindspore import context, Tensor
 from mindspore.context import ParallelMode
 from mindspore.parallel._auto_parallel_context import auto_parallel_context
 from mindspore.communication.management import get_group_size
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore.ops import composite as C
-from mindspore.common.initializer import initializer

 def _make_divisible(v, divisor, min_value=None):
@@ -50,7 +49,10 @@ def _bn(channel):
 def _last_conv2d(in_channel, out_channel, kernel_size=3, stride=1, pad_mod='same', pad=0):
-    depthwise_conv = DepthwiseConv(in_channel, kernel_size, stride, pad_mode='same', pad=pad)
+    in_channels = in_channel
+    out_channels = in_channel
+    depthwise_conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_mode='same',
+                               padding=pad, group=in_channels)
     conv = _conv2d(in_channel, out_channel, kernel_size=1)
     return nn.SequentialCell([depthwise_conv, _bn(in_channel), nn.ReLU6(), conv])
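_last_conv2d now builds a depthwise 3x3 followed by a pointwise 1x1, i.e. a depthwise-separable convolution. A hedged sketch of the same pattern using only public nn layers (`_conv2d` and `_bn` are this file's private helpers; the channel values are illustrative):

import mindspore.nn as nn

in_channel, out_channel = 256, 24
depthwise = nn.Conv2d(in_channel, in_channel, 3, 1, pad_mode='same', group=in_channel)
pointwise = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, pad_mode='same')
separable = nn.SequentialCell([depthwise, nn.BatchNorm2d(in_channel), nn.ReLU6(), pointwise])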
@@ -75,11 +77,14 @@ class ConvBNReLU(nn.Cell):
     def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
         super(ConvBNReLU, self).__init__()
         padding = 0
+        in_channels = in_planes
+        out_channels = out_planes
         if groups == 1:
-            conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='same',
-                             padding=padding)
+            conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_mode='same', padding=padding)
         else:
-            conv = DepthwiseConv(in_planes, kernel_size, stride, pad_mode='same', pad=padding)
+            out_channels = in_planes
+            conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_mode='same',
+                             padding=padding, group=in_channels)

         layers = [conv, _bn(out_planes), nn.ReLU6()]
         self.features = nn.SequentialCell(layers)
@@ -88,52 +93,6 @@ class ConvBNReLU(nn.Cell):
         return output

-
-class DepthwiseConv(nn.Cell):
-    """
-    Depthwise Convolution wrapper definition.
-
-    Args:
-        in_planes (int): Input channel.
-        kernel_size (int): Input kernel size.
-        stride (int): Stride size.
-        pad_mode (str): pad mode in (pad, same, valid)
-        channel_multiplier (int): Output channel multiplier
-        has_bias (bool): has bias or not
-
-    Returns:
-        Tensor, output tensor.
-
-    Examples:
-        >>> DepthwiseConv(16, 3, 1, 'pad', 1, channel_multiplier=1)
-    """
-
-    def __init__(self, in_planes, kernel_size, stride, pad_mode, pad, channel_multiplier=1, has_bias=False):
-        super(DepthwiseConv, self).__init__()
-        self.has_bias = has_bias
-        self.in_channels = in_planes
-        self.channel_multiplier = channel_multiplier
-        self.out_channels = in_planes * channel_multiplier
-        self.kernel_size = (kernel_size, kernel_size)
-        self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier,
-                                                      kernel_size=self.kernel_size,
-                                                      stride=stride, pad_mode=pad_mode, pad=pad)
-        self.bias_add = P.BiasAdd()
-        weight_shape = [channel_multiplier, in_planes, *self.kernel_size]
-        self.weight = Parameter(initializer('ones', weight_shape), name='weight')
-        if has_bias:
-            bias_shape = [channel_multiplier * in_planes]
-            self.bias = Parameter(initializer('zeros', bias_shape), name='bias')
-        else:
-            self.bias = None
-
-    def construct(self, x):
-        output = self.depthwise_conv(x, self.weight)
-        if self.has_bias:
-            output = self.bias_add(output, self.bias)
-        return output
-

 class InvertedResidual(nn.Cell):
     """
     Residual block definition.
