@@ -20,6 +20,13 @@ from mindspore.ops import operations as P
 from mindspore.common.tensor import Tensor
 from scipy.stats import truncnorm
 
+format_ = "NHWC"
+
+# transpose shape to NCHW, default init is NHWC.
+def _trans_shape(shape, shape_format):
+    if shape_format == "NCHW":
+        return (shape[0], shape[3], shape[1], shape[2])
+    return shape
 
 def _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size):
     fan_in = in_channel * kernel_size * kernel_size
     scale = 1.0
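
A quick illustration of the new helper (the values here are chosen for illustration only): a weight shape written in the default NHWC-style order, (out_channel, k, k, in_channel), is reordered to (out_channel, in_channel, k, k) when the target format is NCHW, and passed through unchanged otherwise.

    >>> _trans_shape((64, 7, 7, 3), "NCHW")
    (64, 3, 7, 7)
    >>> _trans_shape((64, 7, 7, 3), "NHWC")
    (64, 7, 7, 3)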
@@ -37,30 +44,33 @@ def _weight_variable(shape, factor=0.01):
 
 def _conv3x3(in_channel, out_channel, stride=1):
     weight_shape = (out_channel, 3, 3, in_channel)
+    weight_shape = _trans_shape(weight_shape, format_)
     weight = _weight_variable(weight_shape)
     return nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride,
-                     padding=1, pad_mode='pad', weight_init=weight, data_format="NHWC")
+                     padding=1, pad_mode='pad', weight_init=weight, data_format=format_)
 
 
 def _conv1x1(in_channel, out_channel, stride=1):
     weight_shape = (out_channel, 1, 1, in_channel)
+    weight_shape = _trans_shape(weight_shape, format_)
     weight = _weight_variable(weight_shape)
     return nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride,
-                     padding=0, pad_mode='pad', weight_init=weight, data_format="NHWC")
+                     padding=0, pad_mode='pad', weight_init=weight, data_format=format_)
 
 
 def _conv7x7(in_channel, out_channel, stride=1):
     weight_shape = (out_channel, 7, 7, in_channel)
+    weight_shape = _trans_shape(weight_shape, format_)
     weight = _weight_variable(weight_shape)
     return nn.Conv2d(in_channel, out_channel, kernel_size=7, stride=stride,
-                     padding=3, pad_mode='pad', weight_init=weight, data_format="NHWC")
+                     padding=3, pad_mode='pad', weight_init=weight, data_format=format_)
 
 
 def _bn(channel):
     return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9, gamma_init=1, beta_init=0,
-                          moving_mean_init=0, moving_var_init=1, data_format="NHWC")
+                          moving_mean_init=0, moving_var_init=1, data_format=format_)
 
 
 def _bn_last(channel):
     return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9, gamma_init=0, beta_init=0,
-                          moving_mean_init=0, moving_var_init=1, data_format="NHWC")
+                          moving_mean_init=0, moving_var_init=1, data_format=format_)
 
 
 def _fc(in_channel, out_channel):
     weight_shape = (out_channel, in_channel)
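
With this change each conv helper builds its weight_init in whichever layout matches the data_format it passes to nn.Conv2d. A sketch with illustrative channel counts (not taken from this diff):

    conv = _conv3x3(64, 128)
    # format_ == "NHWC": weight_init has shape (128, 3, 3, 64) and the layer runs channels-last
    # format_ == "NCHW": weight_init has shape (128, 64, 3, 3) and the layer runs channels-first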
@@ -165,10 +175,13 @@ class ResNet(nn.Cell):
         if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
             raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!")
 
-        self.conv1 = _conv7x7(4, 64, stride=2)
+        input_data_channel = 4
+        if format_ == "NCHW":
+            input_data_channel = 3
+        self.conv1 = _conv7x7(input_data_channel, 64, stride=2)
         self.bn1 = _bn(64)
         self.relu = P.ReLU()
-        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same", data_format="NHWC")
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same", data_format=format_)
         self.layer1 = self._make_layer(block,
                                        layer_nums[0],
                                        in_channel=in_channels[0],
@@ -190,7 +203,7 @@ class ResNet(nn.Cell):
                                        out_channel=out_channels[3],
                                        stride=strides[3])
 
-        self.avg_pool = P.AvgPool(7, 1, data_format="NHWC")
+        self.avg_pool = P.AvgPool(7, 1, data_format=format_)
         self.flatten = nn.Flatten()
         self.end_point = _fc(out_channels[3], num_classes)
 
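
A layout-only sketch of what reaches this pooling stage, assuming the usual 224x224 ResNet50 input (numpy is used purely to show the axis order; the real graph works on MindSpore tensors):

    import numpy as np

    feat_nhwc = np.zeros((1, 7, 7, 2048), np.float16)   # NHWC: channels last
    feat_nchw = feat_nhwc.transpose(0, 3, 1, 2)          # NCHW: channels first, (1, 2048, 7, 7)
    # AvgPool(7, 1) collapses the 7x7 spatial axes to 1x1 in either layout,
    # and Flatten then hands a (1, 2048) tensor to the final fully connected layer.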
@@ -237,7 +250,7 @@
         return out
 
 
-def resnet50(class_num=1001):
+def resnet50(class_num=1001, dtype="fp16"):
     """
     Get ResNet50 neural network.
 
@@ -250,6 +263,9 @@ def resnet50(class_num=1001):
     Examples:
         >>> net = resnet50(1001)
     """
+    global format_
+    if dtype == "fp32":
+        format_ = "NCHW"
     return ResNet(ResidualBlock,
                   [3, 4, 6, 3],
                   [64, 256, 512, 1024],
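
A usage sketch of the new dtype switch, following the docstring's example style (the default keeps the module-level format_ at "NHWC"; passing dtype="fp32" flips it to "NCHW" before the network is built):

    >>> net_fp16 = resnet50(1001)                  # dtype defaults to "fp16", layout stays NHWC
    >>> net_fp32 = resnet50(1001, dtype="fp32")    # switches the layout to NCHW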