@@ -74,8 +74,12 @@ class _BatchNorm(Cell):
                    management.create_group('group' + str(i), self.rank_list[i])
                    self.all_reduce = P.AllReduce(P.ReduceOp.SUM, 'group' + str(i)).add_prim_attr('fusion', 1)
        self.shape = P.Shape()
        self.reduce_mean = P.ReduceMean()
        self.reduce_mean = P.ReduceMean(keep_dims=True)
        self.square = P.Square()
        self.sqrt = P.Sqrt()
        self.cast = P.Cast()
        self.dtype = P.DType()
        self.reshape = P.Reshape()

        if context.get_context("enable_ge"):
            self.is_ge_backend = True
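For reference, the arithmetic this hunk sets up is plain group averaging: AllReduce with ReduceOp.SUM leaves every device in a group holding the sum of the group's tensors, and dividing by self.group turns that sum into the group mean. A minimal NumPy sketch of the same computation, with devices simulated as rows of an array (all names below are illustrative, not from the source):

import numpy as np

group = 4                                    # devices in one communication group
per_device_mean = np.random.rand(group, 8)   # one per-channel mean per device

# AllReduce(SUM) over the group, then divide by the group size: this is
# exactly what `self.all_reduce(x_mean) / self.group` computes later on.
group_mean = per_device_mean.sum(axis=0) / group
assert np.allclose(group_mean, per_device_mean.mean(axis=0))

# `keep_dims=True` on ReduceMean keeps reduced axes as size-1 dimensions,
# so the statistic broadcasts back against the original tensor:
x = np.random.rand(32, 8)
m = x.mean(axis=0, keepdims=True)            # shape (1, 8)
assert (x - m).shape == x.shape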
@@ -112,29 +116,34 @@ class _BatchNorm(Cell):
            group_list = [list(i) for i in world_rank_list]
        return group_list

    def _global_sync(self, x, axes, re_shape):
        """calculate global batch normalization output"""
        x_mean = self.reduce_mean(x, axes)
        x_mean_square = self.reduce_mean(self.square(x), axes)
        global_batch_mean = self.all_reduce(x_mean) / self.group
        global_batch_mean_square = self.all_reduce(x_mean_square) / self.group
        global_mean = global_batch_mean
        global_var = global_batch_mean_square - self.square(global_mean)
        var_sqrt = self.sqrt(global_var + self.eps)
        mean_first = (x - global_mean) / var_sqrt
        y = mean_first * self.reshape(self.gamma, re_shape) + self.reshape(self.beta, re_shape)

        mean_sub = self.sub_mean(self.reshape(self.moving_mean, re_shape), global_mean)
        tmp_mean = self.mul_mean(mean_sub, self.cast(self.momentum, self.dtype(mean_sub)))
        mean_sub2 = self.sub_var(self.reshape(self.moving_variance, re_shape), global_var)
        tmp_variance = self.mul_var(mean_sub2, self.cast(self.momentum, self.dtype(mean_sub2)))
        y = F.depend(y, self.assign_sub_mean(self.reshape(self.moving_mean, re_shape), tmp_mean))
        y = F.depend(y, self.assign_sub_var(self.reshape(self.moving_variance, re_shape), tmp_variance))
        return y

    def construct(self, x):
        if self.training and self.use_batch_statistics:
            if self.is_ge_backend:
                if self.is_global:
                    x_mean = self.reduce_mean(x)
                    x_mean_square = self.reduce_mean(self.square(x))
                    global_batch_mean = self.all_reduce(x_mean) / self.group
                    global_batch_mean_square = self.all_reduce(x_mean_square) / self.group
                    global_mean = global_batch_mean
                    global_var = global_batch_mean_square - self.square(global_batch_mean)
                    y, batch_mean, batch_var, _, _ = \
                        self.bn_train(x,
                                      self.gamma,
                                      self.beta,
                                      None,
                                      None)

                    mean_sub = self.sub_mean(self.moving_mean, global_mean)
                    temp_mean = self.mul_mean(mean_sub, self.momentum)
                    mean_sub2 = self.sub_var(self.moving_variance, global_var)
                    temp_variance = self.mul_var(mean_sub2, self.momentum)
                    y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean))
                    y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance))
                    axes, re_shape = _shape_infer(F.shape(x), self.num_features)
                    y = self._global_sync(x, axes, re_shape)
                else:
                    y, batch_mean, batch_var, _, _ = \
                        self.bn_train(x,
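_global_sync combines per-device moments instead of raw activations: each device reduces E[x] and E[x^2] locally, all-reduces both across the group, and recovers the global variance from the identity Var(x) = E[x^2] - E[x]^2. A small NumPy check of that identity, under the implicit assumption that every device holds an equal-sized shard (names below are illustrative):

import numpy as np

np.random.seed(0)
shards = [np.random.randn(16, 8) for _ in range(4)]       # 4 equal-sized shards

# Per-device moments, then the "all-reduce / group" average of each.
mean = np.mean([s.mean(axis=0) for s in shards], axis=0)            # E[x]
mean_sq = np.mean([(s ** 2).mean(axis=0) for s in shards], axis=0)  # E[x^2]
var = mean_sq - mean ** 2                                           # Var(x)

full = np.concatenate(shards, axis=0)
assert np.allclose(mean, full.mean(axis=0))
assert np.allclose(var, full.var(axis=0))

# The running-statistics update in the same hunk is the usual exponential
# moving average written as an in-place subtraction:
#   moving <- moving - momentum * (moving - batch)
#          == (1 - momentum) * moving + momentum * batch
moving, batch, momentum = 1.0, 0.0, 0.9
moving -= momentum * (moving - batch)
assert abs(moving - 0.1) < 1e-12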
@@ -172,6 +181,17 @@ def _channel_check(channel, num_channel):
    if channel != num_channel:
        raise ValueError("the input channel is not equal to num_channel")


@constexpr
def _shape_infer(x_shape, num_feature):
    """global batch normalization shape and axes infer"""
    if len(x_shape) == 4:
        axes = (0, 2, 3)
        re_shape = (1, num_feature, 1, 1)
    else:
        axes = (0,)
        re_shape = (1, num_feature)
    return axes, re_shape


class BatchNorm1d(_BatchNorm):
    r"""
    Batch normalization layer over a 2D input.
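_shape_infer maps an input shape to the axes that the statistics are reduced over and to the broadcast shape for gamma/beta. The same logic as plain Python, shown standalone because the real function is a @constexpr graph primitive (the demo name is illustrative):

def shape_infer_demo(x_shape, num_feature):
    """Plain-Python mirror of `_shape_infer` above, for illustration only."""
    if len(x_shape) == 4:
        return (0, 2, 3), (1, num_feature, 1, 1)   # NCHW: reduce batch + spatial dims
    return (0,), (1, num_feature)                  # NC: reduce the batch dim only

assert shape_infer_demo((32, 64, 14, 14), 64) == ((0, 2, 3), (1, 64, 1, 1))
assert shape_infer_demo((32, 64), 64) == ((0,), (1, 64))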
@@ -474,6 +494,12 @@ class GroupNorm(Cell):
        num_channels (int): The number of input channels.
        eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
        affine (bool): A bool value. When set to True, this layer has learnable affine parameters. Default: True.
        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'ones'.
        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'zeros'.

    Inputs:
        - **input_x** (Tensor) - The input feature with shape [N, C, H, W].
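The GroupNorm documented here splits the C input channels into num_groups groups and normalizes each group over its own channels and spatial positions, per sample; with affine=True, gamma and beta then apply a per-channel scale and shift. A NumPy reference sketch of that computation (an illustration of the math, not the MindSpore kernel):

import numpy as np

def group_norm_ref(x, num_groups, gamma, beta, eps=1e-5):
    """Group normalization for a [N, C, H, W] input with per-channel affine."""
    n, c, h, w = x.shape
    g = x.reshape(n, num_groups, c // num_groups, h, w)
    mean = g.mean(axis=(2, 3, 4), keepdims=True)     # stats per (sample, group)
    var = g.var(axis=(2, 3, 4), keepdims=True)
    g = (g - mean) / np.sqrt(var + eps)
    out = g.reshape(n, c, h, w)
    # Per-channel gamma/beta match the defaults gamma_init='ones', beta_init='zeros'.
    return out * gamma.reshape(1, c, 1, 1) + beta.reshape(1, c, 1, 1)

x = np.random.randn(2, 8, 4, 4)
y = group_norm_ref(x, num_groups=4, gamma=np.ones(8), beta=np.zeros(8))
assert y.shape == x.shape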