# Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ import mindspore.nn as nn import mindspore.ops.operations as P class ShuffleV2Block(nn.Cell): def __init__(self, inp, oup, mid_channels, *, ksize, stride): super(ShuffleV2Block, self).__init__() self.stride = stride ##assert stride in [1, 2] self.mid_channels = mid_channels self.ksize = ksize pad = ksize // 2 self.pad = pad self.inp = inp outputs = oup - inp branch_main = [ # pw nn.Conv2d(in_channels=inp, out_channels=mid_channels, kernel_size=1, stride=1, pad_mode='pad', padding=0, has_bias=False), nn.BatchNorm2d(num_features=mid_channels, momentum=0.9), nn.ReLU(), # dw nn.Conv2d(in_channels=mid_channels, out_channels=mid_channels, kernel_size=ksize, stride=stride, pad_mode='pad', padding=pad, group=mid_channels, has_bias=False), nn.BatchNorm2d(num_features=mid_channels, momentum=0.9), # pw-linear nn.Conv2d(in_channels=mid_channels, out_channels=outputs, kernel_size=1, stride=1, pad_mode='pad', padding=0, has_bias=False), nn.BatchNorm2d(num_features=outputs, momentum=0.9), nn.ReLU(), ] self.branch_main = nn.SequentialCell(branch_main) if stride == 2: branch_proj = [ # dw nn.Conv2d(in_channels=inp, out_channels=inp, kernel_size=ksize, stride=stride, pad_mode='pad', padding=pad, group=inp, has_bias=False), nn.BatchNorm2d(num_features=inp, momentum=0.9), # pw-linear nn.Conv2d(in_channels=inp, out_channels=inp, kernel_size=1, stride=1, pad_mode='pad', padding=0, has_bias=False), nn.BatchNorm2d(num_features=inp, momentum=0.9), nn.ReLU(), ] self.branch_proj = nn.SequentialCell(branch_proj) else: self.branch_proj = None def construct(self, old_x): if self.stride == 1: x_proj, x = self.channel_shuffle(old_x) return P.Concat(1)((x_proj, self.branch_main(x))) if self.stride == 2: x_proj = old_x x = old_x return P.Concat(1)((self.branch_proj(x_proj), self.branch_main(x))) return None def channel_shuffle(self, x): batchsize, num_channels, height, width = P.Shape()(x) ##assert (num_channels % 4 == 0) x = P.Reshape()(x, (batchsize * num_channels // 2, 2, height * width,)) x = P.Transpose()(x, (1, 0, 2,)) x = P.Reshape()(x, (2, -1, num_channels // 2, height, width,)) return x[0], x[1]