@ -18,8 +18,9 @@ import numpy as np
import mindspore . nn as nn
import mindspore . nn as nn
from mindspore . common . tensor import Tensor
from mindspore . common . tensor import Tensor
from mindspore . ops import operations as P
from mindspore . ops import operations as P
from mindspore import context
from src . thor_layer import Conv2d_Thor , Dense_Thor
from src . thor_layer import Conv2d_Thor , Dense_Thor , Conv2d_Thor_GPU , Dense_Thor_GPU
def calculate_gain ( nonlinearity , param = None ) :
def calculate_gain ( nonlinearity , param = None ) :
@ -81,7 +82,7 @@ def kaiming_normal(inputs_shape, a=0, mode='fan_in', nonlinearity='leaky_relu'):
return np . random . normal ( 0 , std , size = inputs_shape ) . astype ( np . float32 )
return np . random . normal ( 0 , std , size = inputs_shape ) . astype ( np . float32 )
def kaiming_uniform ( inputs_shape , a = 0 , mode = ' fan_in ' , nonlinearity = ' leaky_relu ' ) :
def kaiming_uniform ( inputs_shape , a = 0. , mode = ' fan_in ' , nonlinearity = ' leaky_relu ' ) :
fan = _calculate_correct_fan ( inputs_shape , mode )
fan = _calculate_correct_fan ( inputs_shape , mode )
gain = calculate_gain ( nonlinearity , a )
gain = calculate_gain ( nonlinearity , a )
std = gain / math . sqrt ( fan )
std = gain / math . sqrt ( fan )
@ -89,28 +90,51 @@ def kaiming_uniform(inputs_shape, a=0, mode='fan_in', nonlinearity='leaky_relu')
return np . random . uniform ( - bound , bound , size = inputs_shape ) . astype ( np . float32 )
return np . random . uniform ( - bound , bound , size = inputs_shape ) . astype ( np . float32 )
def _conv3x3 ( in_channel , out_channel , stride = 1 , damping = 0.03 , loss_scale = 1 , frequency = 278 ) :
def _weight_variable ( shape , factor = 0.01 ) :
init_value = np . random . randn ( * shape ) . astype ( np . float32 ) * factor
return Tensor ( init_value )
def _conv3x3 ( in_channel , out_channel , stride = 1 , damping = 0.03 , loss_scale = 1 , frequency = 278 , batch_size = 32 ) :
weight_shape = ( out_channel , in_channel , 3 , 3 )
weight_shape = ( out_channel , in_channel , 3 , 3 )
weight = Tensor ( kaiming_normal ( weight_shape , mode = " fan_out " , nonlinearity = ' relu ' ) )
weight = Tensor ( kaiming_normal ( weight_shape , mode = " fan_out " , nonlinearity = ' relu ' ) )
return Conv2d_Thor ( in_channel , out_channel ,
if context . get_context ( ' device_target ' ) == " Ascend " :
kernel_size = 3 , stride = stride , padding = 0 , pad_mode = ' same ' , weight_init = weight ,
layer = Conv2d_Thor ( in_channel , out_channel ,
damping = damping , loss_scale = loss_scale , frequency = frequency )
kernel_size = 3 , stride = stride , padding = 0 , pad_mode = ' same ' , weight_init = weight ,
damping = damping , loss_scale = loss_scale , frequency = frequency , batch_size = batch_size )
else :
layer = Conv2d_Thor_GPU ( in_channel , out_channel ,
kernel_size = 3 , stride = stride , padding = 0 , pad_mode = ' same ' , weight_init = weight ,
damping = damping , loss_scale = loss_scale , frequency = frequency , batch_size = batch_size )
return layer
def _conv1x1 ( in_channel , out_channel , stride = 1 , damping = 0.03 , loss_scale = 1 , frequency = 278 ) :
def _conv1x1 ( in_channel , out_channel , stride = 1 , damping = 0.03 , loss_scale = 1 , frequency = 278 , batch_size = 32 ):
weight_shape = ( out_channel , in_channel , 1 , 1 )
weight_shape = ( out_channel , in_channel , 1 , 1 )
weight = Tensor ( kaiming_normal ( weight_shape , mode = " fan_out " , nonlinearity = ' relu ' ) )
weight = Tensor ( kaiming_normal ( weight_shape , mode = " fan_out " , nonlinearity = ' relu ' ) )
return Conv2d_Thor ( in_channel , out_channel ,
if context . get_context ( ' device_target ' ) == " Ascend " :
kernel_size = 1 , stride = stride , padding = 0 , pad_mode = ' same ' , weight_init = weight ,
layer = Conv2d_Thor ( in_channel , out_channel ,
damping = damping , loss_scale = loss_scale , frequency = frequency )
kernel_size = 1 , stride = stride , padding = 0 , pad_mode = ' same ' , weight_init = weight ,
damping = damping , loss_scale = loss_scale , frequency = frequency , batch_size = batch_size )
else :
layer = Conv2d_Thor_GPU ( in_channel , out_channel ,
kernel_size = 1 , stride = stride , padding = 0 , pad_mode = ' same ' , weight_init = weight ,
damping = damping , loss_scale = loss_scale , frequency = frequency , batch_size = batch_size )
return layer
def _conv7x7 ( in_channel , out_channel , stride = 1 , damping = 0.03 , loss_scale = 1 , frequency = 278 ) :
def _conv7x7 ( in_channel , out_channel , stride = 1 , damping = 0.03 , loss_scale = 1 , frequency = 278 , batch_size = 32 ):
weight_shape = ( out_channel , in_channel , 7 , 7 )
weight_shape = ( out_channel , in_channel , 7 , 7 )
weight = Tensor ( kaiming_normal ( weight_shape , mode = " fan_out " , nonlinearity = ' relu ' ) )
weight = Tensor ( kaiming_normal ( weight_shape , mode = " fan_out " , nonlinearity = ' relu ' ) )
return Conv2d_Thor ( in_channel , out_channel ,
if context . get_context ( ' device_target ' ) == " Ascend " :
kernel_size = 7 , stride = stride , padding = 0 , pad_mode = ' same ' , weight_init = weight ,
layer = Conv2d_Thor ( in_channel , out_channel ,
damping = damping , loss_scale = loss_scale , frequency = frequency )
kernel_size = 7 , stride = stride , padding = 0 , pad_mode = ' same ' , weight_init = weight ,
damping = damping , loss_scale = loss_scale , frequency = frequency , batch_size = batch_size )
else :
layer = Conv2d_Thor_GPU ( in_channel , out_channel ,
kernel_size = 7 , stride = stride , padding = 0 , pad_mode = ' same ' , weight_init = weight ,
damping = damping , loss_scale = loss_scale , frequency = frequency , batch_size = batch_size )
return layer
def _bn ( channel ) :
def _bn ( channel ) :
@ -120,14 +144,21 @@ def _bn(channel):
def _bn_last ( channel ) :
def _bn_last ( channel ) :
return nn . BatchNorm2d ( channel , eps = 1e-4 , momentum = 0.9 ,
return nn . BatchNorm2d ( channel , eps = 1e-4 , momentum = 0.9 ,
gamma_init = 1 , beta_init = 0 , moving_mean_init = 0 , moving_var_init = 1 )
gamma_init = 0 , beta_init = 0 , moving_mean_init = 0 , moving_var_init = 1 )
def _fc ( in_channel , out_channel , damping , loss_scale , frequency ):
def _fc ( in_channel , out_channel , damping , loss_scale , frequency , batch_size = 32 ):
weight_shape = ( out_channel , in_channel )
weight_shape = ( out_channel , in_channel )
weight = Tensor ( kaiming_uniform ( weight_shape , a = math . sqrt ( 5 ) ) )
weight = Tensor ( kaiming_uniform ( weight_shape , a = math . sqrt ( 5 ) ) )
return Dense_Thor ( in_channel , out_channel , has_bias = False , weight_init = weight ,
if context . get_context ( ' device_target ' ) == " Ascend " :
bias_init = 0 , damping = damping , loss_scale = loss_scale , frequency = frequency )
layer = Dense_Thor ( in_channel , out_channel , has_bias = False , weight_init = weight ,
bias_init = 0 , damping = damping , loss_scale = loss_scale , frequency = frequency ,
batch_size = batch_size )
else :
layer = Dense_Thor_GPU ( in_channel , out_channel , has_bias = False , weight_init = weight ,
bias_init = 0 , damping = damping , loss_scale = loss_scale , frequency = frequency ,
batch_size = batch_size )
return layer
class ResidualBlock ( nn . Cell ) :
class ResidualBlock ( nn . Cell ) :
@ -153,20 +184,21 @@ class ResidualBlock(nn.Cell):
stride = 1 ,
stride = 1 ,
damping = 0.03 ,
damping = 0.03 ,
loss_scale = 1 ,
loss_scale = 1 ,
frequency = 278 ) :
frequency = 278 ,
batch_size = 32 ) :
super ( ResidualBlock , self ) . __init__ ( )
super ( ResidualBlock , self ) . __init__ ( )
channel = out_channel / / self . expansion
channel = out_channel / / self . expansion
self . conv1 = _conv1x1 ( in_channel , channel , stride = 1 , damping = damping , loss_scale = loss_scale ,
self . conv1 = _conv1x1 ( in_channel , channel , stride = 1 , damping = damping , loss_scale = loss_scale ,
frequency = frequency )
frequency = frequency , batch_size = batch_size )
self . bn1 = _bn ( channel )
self . bn1 = _bn ( channel )
self . conv2 = _conv3x3 ( channel , channel , stride = stride , damping = damping , loss_scale = loss_scale ,
self . conv2 = _conv3x3 ( channel , channel , stride = stride , damping = damping , loss_scale = loss_scale ,
frequency = frequency )
frequency = frequency , batch_size = batch_size )
self . bn2 = _bn ( channel )
self . bn2 = _bn ( channel )
self . conv3 = _conv1x1 ( channel , out_channel , stride = 1 , damping = damping , loss_scale = loss_scale ,
self . conv3 = _conv1x1 ( channel , out_channel , stride = 1 , damping = damping , loss_scale = loss_scale ,
frequency = frequency )
frequency = frequency , batch_size = batch_size )
self . bn3 = _bn_last ( out_channel )
self . bn3 = _bn_last ( out_channel )
self . relu = nn . ReLU ( )
self . relu = nn . ReLU ( )
@ -180,7 +212,8 @@ class ResidualBlock(nn.Cell):
if self . down_sample :
if self . down_sample :
self . down_sample_layer = nn . SequentialCell ( [ _conv1x1 ( in_channel , out_channel , stride ,
self . down_sample_layer = nn . SequentialCell ( [ _conv1x1 ( in_channel , out_channel , stride ,
damping = damping , loss_scale = loss_scale ,
damping = damping , loss_scale = loss_scale ,
frequency = frequency ) ,
frequency = frequency ,
batch_size = batch_size ) ,
_bn ( out_channel ) ] )
_bn ( out_channel ) ] )
self . add = P . TensorAdd ( )
self . add = P . TensorAdd ( )
@ -239,16 +272,19 @@ class ResNet(nn.Cell):
num_classes ,
num_classes ,
damping ,
damping ,
loss_scale ,
loss_scale ,
frequency ) :
frequency ,
batch_size ) :
super ( ResNet , self ) . __init__ ( )
super ( ResNet , self ) . __init__ ( )
if not len ( layer_nums ) == len ( in_channels ) == len ( out_channels ) == 4 :
if not len ( layer_nums ) == len ( in_channels ) == len ( out_channels ) == 4 :
raise ValueError ( " the length of layer_num, in_channels, out_channels list must be 4! " )
raise ValueError ( " the length of layer_num, in_channels, out_channels list must be 4! " )
self . conv1 = _conv7x7 ( 3 , 64 , stride = 2 , damping = damping , loss_scale = loss_scale , frequency = frequency )
self . conv1 = _conv7x7 ( 3 , 64 , stride = 2 , damping = damping , loss_scale = loss_scale ,
frequency = frequency , batch_size = batch_size )
self . bn1 = _bn ( 64 )
self . bn1 = _bn ( 64 )
self . relu = P . ReLU ( )
self . relu = P . ReLU ( )
self . maxpool = P . MaxPoolWithArgmax ( padding = " same " , ksize = 3 , strides = 2 )
# self.maxpool = P.MaxPoolWithArgmax(padding="same", ksize=3, strides=2)
self . maxpool = nn . MaxPool2d ( kernel_size = 3 , stride = 2 , pad_mode = " same " )
self . layer1 = self . _make_layer ( block ,
self . layer1 = self . _make_layer ( block ,
layer_nums [ 0 ] ,
layer_nums [ 0 ] ,
@ -257,7 +293,8 @@ class ResNet(nn.Cell):
stride = strides [ 0 ] ,
stride = strides [ 0 ] ,
damping = damping ,
damping = damping ,
loss_scale = loss_scale ,
loss_scale = loss_scale ,
frequency = frequency )
frequency = frequency ,
batch_size = batch_size )
self . layer2 = self . _make_layer ( block ,
self . layer2 = self . _make_layer ( block ,
layer_nums [ 1 ] ,
layer_nums [ 1 ] ,
in_channel = in_channels [ 1 ] ,
in_channel = in_channels [ 1 ] ,
@ -265,14 +302,16 @@ class ResNet(nn.Cell):
stride = strides [ 1 ] ,
stride = strides [ 1 ] ,
damping = damping ,
damping = damping ,
loss_scale = loss_scale ,
loss_scale = loss_scale ,
frequency = frequency )
frequency = frequency ,
batch_size = batch_size )
self . layer3 = self . _make_layer ( block ,
self . layer3 = self . _make_layer ( block ,
layer_nums [ 2 ] ,
layer_nums [ 2 ] ,
in_channel = in_channels [ 2 ] ,
in_channel = in_channels [ 2 ] ,
out_channel = out_channels [ 2 ] ,
out_channel = out_channels [ 2 ] ,
stride = strides [ 2 ] , damping = damping ,
stride = strides [ 2 ] , damping = damping ,
loss_scale = loss_scale ,
loss_scale = loss_scale ,
frequency = frequency )
frequency = frequency ,
batch_size = batch_size )
self . layer4 = self . _make_layer ( block ,
self . layer4 = self . _make_layer ( block ,
layer_nums [ 3 ] ,
layer_nums [ 3 ] ,
in_channel = in_channels [ 3 ] ,
in_channel = in_channels [ 3 ] ,
@ -280,14 +319,16 @@ class ResNet(nn.Cell):
stride = strides [ 3 ] ,
stride = strides [ 3 ] ,
damping = damping ,
damping = damping ,
loss_scale = loss_scale ,
loss_scale = loss_scale ,
frequency = frequency )
frequency = frequency ,
batch_size = batch_size )
self . mean = P . ReduceMean ( keep_dims = True )
self . mean = P . ReduceMean ( keep_dims = True )
self . flatten = nn . Flatten ( )
self . flatten = nn . Flatten ( )
self . end_point = _fc ( out_channels [ 3 ] , num_classes , damping = damping , loss_scale = loss_scale , frequency = frequency )
self . end_point = _fc ( out_channels [ 3 ] , num_classes , damping = damping , loss_scale = loss_scale ,
frequency = frequency , batch_size = batch_size )
def _make_layer ( self , block , layer_num , in_channel , out_channel , stride ,
def _make_layer ( self , block , layer_num , in_channel , out_channel , stride ,
damping , loss_scale , frequency ):
damping , loss_scale , frequency , batch_size ):
"""
"""
Make stage network of ResNet .
Make stage network of ResNet .
@ -307,12 +348,14 @@ class ResNet(nn.Cell):
layers = [ ]
layers = [ ]
resnet_block = block ( in_channel , out_channel , stride = stride ,
resnet_block = block ( in_channel , out_channel , stride = stride ,
damping = damping , loss_scale = loss_scale , frequency = frequency )
damping = damping , loss_scale = loss_scale , frequency = frequency ,
batch_size = batch_size )
layers . append ( resnet_block )
layers . append ( resnet_block )
for _ in range ( 1 , layer_num ) :
for _ in range ( 1 , layer_num ) :
resnet_block = block ( out_channel , out_channel , stride = 1 ,
resnet_block = block ( out_channel , out_channel , stride = 1 ,
damping = damping , loss_scale = loss_scale , frequency = frequency )
damping = damping , loss_scale = loss_scale , frequency = frequency ,
batch_size = batch_size )
layers . append ( resnet_block )
layers . append ( resnet_block )
return nn . SequentialCell ( layers )
return nn . SequentialCell ( layers )
@ -321,7 +364,7 @@ class ResNet(nn.Cell):
x = self . conv1 ( x )
x = self . conv1 ( x )
x = self . bn1 ( x )
x = self . bn1 ( x )
x = self . relu ( x )
x = self . relu ( x )
c1 , _ = self . maxpool ( x )
c1 = self . maxpool ( x )
c2 = self . layer1 ( c1 )
c2 = self . layer1 ( c1 )
c3 = self . layer2 ( c2 )
c3 = self . layer2 ( c2 )
@ -335,7 +378,7 @@ class ResNet(nn.Cell):
return out
return out
def resnet50 ( class_num = 10 , damping = 0.03 , loss_scale = 1 , frequency = 278 ):
def resnet50 ( class_num = 10 , damping = 0.03 , loss_scale = 1 , frequency = 278 , batch_size = 32 ):
"""
"""
Get ResNet50 neural network .
Get ResNet50 neural network .
@ -356,4 +399,5 @@ def resnet50(class_num=10, damping=0.03, loss_scale=1, frequency=278):
class_num ,
class_num ,
damping ,
damping ,
loss_scale ,
loss_scale ,
frequency )
frequency ,
batch_size )