@@ -20,8 +20,11 @@ from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
import mindspore.context as context
from mindspore._checkparam import check_int_positive, check_bool, check_typename
from mindspore._checkparam import check_bool, check_typename
from mindspore._extends import cell_attr_register
from mindspore.communication.management import get_group_size, get_rank
from mindspore.communication import management
from mindspore._checkparam import check_int_positive
from ..cell import Cell
@@ -30,6 +33,7 @@ class _BatchNorm(Cell):
    @cell_attr_register
    def __init__(self,
                 num_features,
                 group=1,
                 eps=1e-5,
                 momentum=0.9,
                 affine=True,
@@ -56,6 +60,21 @@ class _BatchNorm(Cell):
            gamma_init, num_features), name="gamma", requires_grad=affine)
        self.beta = Parameter(initializer(
            beta_init, num_features), name="beta", requires_grad=affine)
        self.group = check_int_positive(group)
        if self.group != 1:
            self.rank_id = get_rank()
            self.rank_size = get_group_size()
            self.device_list = [i for i in range(0, self.rank_size)]
            self.rank_list = self.list_group(self.device_list, self.group)
            self.rank_list_idx = len(self.rank_list)
            for i in range(self.rank_list_idx):
                if self.rank_id in self.rank_list[i] and self.group != 1:
                    self.is_global = True
                    management.create_group('group' + str(i), self.rank_list[i])
                    self.all_reduce = P.AllReduce(P.ReduceOp.SUM, 'group' + str(i)).add_prim_attr('fusion', 1)
        self.shape = P.Shape()
        self.reduce_mean = P.ReduceMean()
        self.square = P.Square()

        if context.get_context("enable_ge"):
            self.is_ge_backend = True
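For orientation: the grouping logic above partitions the visible ranks into consecutive groups of `group` devices, creates a communication group for the one containing the current rank, and binds the summed AllReduce to that group only. A standalone pure-Python sketch with assumed values (8 devices, groups of 4, current rank 5; none of these numbers come from the patch):

# Sketch of the effect of the setup above; all values are illustrative assumptions.
rank_size = 8     # what get_group_size() would report
group = 4         # devices per synchronization group
rank_id = 5       # what get_rank() would report

group_index = rank_id // group                 # consecutive partitioning puts rank 5 in the second group
group_name = 'group' + str(group_index)        # the name handed to management.create_group
members = list(range(group_index * group, (group_index + 1) * group))
print(group_name, members)                     # group1 [4, 5, 6, 7]
# The AllReduce registered on 'group1' therefore sums statistics over ranks 4-7 only.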
@@ -82,22 +101,53 @@ class _BatchNorm(Cell):
    def _check_data_dim(self, x):
        raise NotImplementedError

    def list_group(self, world_rank, group_size):
        if group_size > get_group_size():
            raise ValueError("group size can not be greater than local rank size, group size is {}, "
                             "local_rank_size is {}".format(group_size, get_group_size()))
        if len(world_rank) % group_size != 0:
            raise ValueError("please make your group size correct.")
        world_rank_list = zip(*(iter(world_rank),) * group_size)
        group_list = [list(i) for i in world_rank_list]
        return group_list
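The `zip(*(iter(world_rank),) * group_size)` expression is a chunking idiom: the same iterator is repeated `group_size` times, so each tuple produced by `zip` consumes `group_size` consecutive ranks. A quick illustration with made-up ranks:

world_rank = [0, 1, 2, 3, 4, 5, 6, 7]
group_size = 4
world_rank_list = zip(*(iter(world_rank),) * group_size)
print([list(i) for i in world_rank_list])   # [[0, 1, 2, 3], [4, 5, 6, 7]]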
    def construct(self, x):
        if self.training and self.use_batch_statistics:
            if self.is_ge_backend:
                y, batch_mean, batch_var, _, _ = \
                    self.bn_train(x,
                                  self.gamma,
                                  self.beta,
                                  None,
                                  None)

                mean_sub = self.sub_mean(self.moving_mean, batch_mean)
                temp_mean = self.mul_mean(mean_sub, self.momentum)
                mean_sub2 = self.sub_var(self.moving_variance, batch_var)
                temp_variance = self.mul_var(mean_sub2, self.momentum)
                y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean))
                y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance))
                if self.is_global:
                    x_mean = self.reduce_mean(x)
                    x_mean_square = self.reduce_mean(self.square(x))
                    global_batch_mean = self.all_reduce(x_mean) / self.group
                    global_batch_mean_square = self.all_reduce(x_mean_square) / self.group
                    global_mean = global_batch_mean
                    global_var = global_batch_mean_square - self.square(global_batch_mean)
                    y, batch_mean, batch_var, _, _ = \
                        self.bn_train(x,
                                      self.gamma,
                                      self.beta,
                                      None,
                                      None)

                    mean_sub = self.sub_mean(self.moving_mean, global_mean)
                    temp_mean = self.mul_mean(mean_sub, self.momentum)
                    mean_sub2 = self.sub_var(self.moving_variance, global_var)
                    temp_variance = self.mul_var(mean_sub2, self.momentum)
                    y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean))
                    y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance))
                else:
                    y, batch_mean, batch_var, _, _ = \
                        self.bn_train(x,
                                      self.gamma,
                                      self.beta,
                                      None,
                                      None)

                    mean_sub = self.sub_mean(self.moving_mean, batch_mean)
                    temp_mean = self.mul_mean(mean_sub, self.momentum)
                    mean_sub2 = self.sub_var(self.moving_variance, batch_var)
                    temp_variance = self.mul_var(mean_sub2, self.momentum)
                    y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean))
                    y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance))
            else:
                y = self.bn_train(x,
                                  self.gamma,
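Two properties the hunk above relies on can be checked with a small NumPy sketch (synthetic per-device arrays standing in for the AllReduce inputs; the channel dimension is collapsed for brevity): averaging per-device E[x] and E[x^2] across the group recovers the full-group mean and variance via Var[x] = E[x^2] - E[x]^2, and each assign_sub update is an exponential moving average of the form moving = (1 - momentum) * moving + momentum * new_stat.

import numpy as np

np.random.seed(0)
devices = [np.random.randn(4, 3) for _ in range(2)]    # equal-sized per-device mini-batches
group = len(devices)

# Each device reduces locally; AllReduce(SUM) / group then averages across the group.
x_mean = [d.mean() for d in devices]
x_mean_square = [(d ** 2).mean() for d in devices]
global_mean = sum(x_mean) / group
global_var = sum(x_mean_square) / group - global_mean ** 2   # E[x^2] - E[x]^2

full_batch = np.concatenate(devices)
assert np.allclose(global_mean, full_batch.mean())
assert np.allclose(global_var, full_batch.var())

# moving -= momentum * (moving - new_stat) is the same exponential moving average.
momentum, moving_mean = 0.9, 0.0
moving_mean -= momentum * (moving_mean - global_mean)
assert np.isclose(moving_mean, (1 - momentum) * 0.0 + momentum * global_mean)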
@@ -221,6 +271,55 @@ class BatchNorm2d(_BatchNorm):
        pass


class GlobalBatchNorm(_BatchNorm):
    r"""
    Global normalization layer over an N-dimensional input.

    Global normalization is cross-device synchronized batch normalization. A plain batch normalization
    implementation only normalizes the data within each device, while global normalization normalizes
    the input within the whole device group. It has been described in the paper `Batch Normalization:
    Accelerating Deep Network Training by Reducing Internal Covariate Shift
    <https://arxiv.org/abs/1502.03167>`_. It rescales and recenters the feature using a mini-batch of
    data and the learned parameters, which can be described by the following formula.

    .. math::
        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    Args:
        num_features (int): `C` from an expected input of size (N, C, H, W).
        group (int): The number of devices in each group.
        eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
        momentum (float): A floating hyperparameter of the momentum for the
            running_mean and running_var computation. Default: 0.9.
        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'ones'.
        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'zeros'.
        moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'zeros'.
        moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'ones'.
        use_batch_statistics (bool): If true, use the mean and variance of the current batch data;
            otherwise, use the stored moving mean and moving variance. Default: True.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

    Outputs:
        Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, H_{out}, W_{out})`.

    Examples:
        >>> global_bn_op = nn.GlobalBatchNorm(num_features=3, group=4)
        >>> input = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)
        >>> global_bn_op(input)
    """
    def _check_data_dim(self, x):
        if x.dim == 0:
            pass
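A minimal usage sketch under an assumed distributed launch (the context settings, `init()` call, and group size below are illustrative assumptions, not part of this patch); each group of two devices then shares its batch statistics:

import numpy as np
import mindspore
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.communication.management import init

# Assumes the script is started by a distributed launcher that sets the device/rank environment.
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
init()                                                     # initialize the communication backend

global_bn = nn.GlobalBatchNorm(num_features=3, group=2)    # synchronize statistics across pairs of devices
x = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)
out = global_bn(x)
print(out.shape)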
class LayerNorm(Cell):
    r"""
    Applies Layer Normalization over a mini-batch of inputs.