@@ -33,7 +33,6 @@ class _BatchNorm(Cell):
     @cell_attr_register
     def __init__(self,
                  num_features,
-                 group=1,
                  eps=1e-5,
                  momentum=0.9,
                  affine=True,
@@ -41,7 +40,8 @@ class _BatchNorm(Cell):
                  beta_init='zeros',
                  moving_mean_init='zeros',
                  moving_var_init='ones',
-                 use_batch_statistics=True):
+                 use_batch_statistics=True,
+                 group=1):
         super(_BatchNorm, self).__init__()
         if num_features < 1:
             raise ValueError("num_features must be at least 1")
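The net effect of the two hunks above is that `group` moves from the second positional parameter of `_BatchNorm.__init__` to the last, defaulted one, so plain subclasses can keep forwarding their arguments positionally and never mention it. A minimal sketch of the resulting call pattern under the signature defined in this patch (argument values are illustrative only):

    import mindspore.nn as nn

    # Plain subclasses ignore `group`; it simply defaults to 1 in the base class.
    bn = nn.BatchNorm2d(num_features=64)

    # Only the grouped variant passes `group`, and it now goes last.
    global_bn = nn.GlobalBatchNorm(num_features=64, group=2)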
@@ -214,6 +214,25 @@ class BatchNorm1d(_BatchNorm):
         >>> input = Tensor(np.random.randint(0, 255, [3, 16]), mindspore.float32)
         >>> net(input)
     """
+    def __init__(self,
+                 num_features,
+                 eps=1e-5,
+                 momentum=0.9,
+                 affine=True,
+                 gamma_init='ones',
+                 beta_init='zeros',
+                 moving_mean_init='zeros',
+                 moving_var_init='ones',
+                 use_batch_statistics=True):
+        super(BatchNorm1d, self).__init__(num_features,
+                                          eps,
+                                          momentum,
+                                          affine,
+                                          gamma_init,
+                                          beta_init,
+                                          moving_mean_init,
+                                          moving_var_init,
+                                          use_batch_statistics)
     def _check_data_dim(self, x):
         if x.dim() != 2:
             pass
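The `BatchNorm1d.__init__` added above only forwards its arguments to `_BatchNorm`, so usage matches the docstring fragment shown in the hunk. A self-contained sketch, assuming the standard MindSpore imports:

    import numpy as np
    import mindspore
    import mindspore.nn as nn
    from mindspore import Tensor

    # 2-D input of shape (batch, features); num_features must match the last axis.
    net = nn.BatchNorm1d(num_features=16)
    x = Tensor(np.random.randint(0, 255, [3, 16]), mindspore.float32)
    y = net(x)  # same shape as x, normalized per feature

BatchNorm2d in the next hunk follows the same forwarding pattern for 4-D inputs.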
@@ -266,6 +285,25 @@ class BatchNorm2d(_BatchNorm):
         >>> input = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)
         >>> net(input)
     """
+    def __init__(self,
+                 num_features,
+                 eps=1e-5,
+                 momentum=0.9,
+                 affine=True,
+                 gamma_init='ones',
+                 beta_init='zeros',
+                 moving_mean_init='zeros',
+                 moving_var_init='ones',
+                 use_batch_statistics=True):
+        super(BatchNorm2d, self).__init__(num_features,
+                                          eps,
+                                          momentum,
+                                          affine,
+                                          gamma_init,
+                                          beta_init,
+                                          moving_mean_init,
+                                          moving_var_init,
+                                          use_batch_statistics)
     def _check_data_dim(self, x):
         if x.dim() != 4:
             pass
@@ -316,6 +354,30 @@ class GlobalBatchNorm(_BatchNorm):
         >>> input = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)
         >>> global_bn_op(input)
     """
+    def __init__(self,
+                 num_features,
+                 eps=1e-5,
+                 momentum=0.9,
+                 affine=True,
+                 gamma_init='ones',
+                 beta_init='zeros',
+                 moving_mean_init='zeros',
+                 moving_var_init='ones',
+                 use_batch_statistics=True,
+                 group=1):
+        super(GlobalBatchNorm, self).__init__(num_features,
+                                              eps,
+                                              momentum,
+                                              affine,
+                                              gamma_init,
+                                              beta_init,
+                                              moving_mean_init,
+                                              moving_var_init,
+                                              use_batch_statistics,
+                                              group)
+        self.group = check_int_positive(group)
+        if self.group <= 1:
+            raise ValueError("the number of groups must be greater than 1.")
     def _check_data_dim(self, x):
         if x.dim() == 0:
             pass
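One consequence of the added check is that `GlobalBatchNorm` cannot be constructed with its own default: `check_int_positive` accepts `group=1`, but the `<= 1` test then raises, so callers must request at least two groups. A hedged sketch of a valid construction under the signature in this patch (the distributed/device-group setup that grouped synchronization actually needs is outside this diff):

    import mindspore.nn as nn

    # group must be an integer greater than 1; the default group=1 would raise
    # ValueError("the number of groups must be greater than 1.").
    global_bn = nn.GlobalBatchNorm(num_features=3, group=2)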