|
|
|
@ -18,8 +18,9 @@ from mindspore.ops import functional as F
|
|
|
|
|
from mindspore.common.parameter import Parameter
|
|
|
|
|
from mindspore.common.initializer import initializer
|
|
|
|
|
from mindspore.common.tensor import Tensor
|
|
|
|
|
import mindspore.common.dtype as DT
|
|
|
|
|
import mindspore.common.dtype as mstype
|
|
|
|
|
import mindspore.context as context
|
|
|
|
|
from mindspore._checkparam import check_int_positive, check_bool,check_typename
|
|
|
|
|
from mindspore._extends import cell_attr_register
|
|
|
|
|
from ..cell import Cell
|
|
|
|
|
|
|
|
|
@ -58,7 +59,7 @@ class _BatchNorm(Cell):
|
|
|
|
|
|
|
|
|
|
if context.get_context("enable_ge"):
|
|
|
|
|
self.is_ge_backend = True
|
|
|
|
|
self.momentum = Tensor(1.0 - momentum, DT.float32)
|
|
|
|
|
self.momentum = Tensor(1.0 - momentum, mstype.float32)
|
|
|
|
|
self.bn_train = P.BatchNorm(is_training=True,
|
|
|
|
|
epsilon=self.eps)
|
|
|
|
|
else:
|
|
|
|
@ -289,3 +290,71 @@ class LayerNorm(Cell):
|
|
|
|
|
s = 'normalized_shape={}, begin_norm_axis={}, begin_params_axis={}, gamma{}, beta={}'.format(
|
|
|
|
|
self.normalized_shape, self.begin_norm_axis, self.begin_params_axis, self.gamma, self.beta)
|
|
|
|
|
return s
|
|
|
|
|
|
|
|
|
|
class GroupNorm(Cell):
|
|
|
|
|
r"""
|
|
|
|
|
Group Normalization over a mini-batch of inputs.
|
|
|
|
|
|
|
|
|
|
Group normalization is widely used in recurrent neural networks. It applies
|
|
|
|
|
normalization over a mini-batch of inputs for each single training case as described
|
|
|
|
|
in the paper `Group Normalization <https://arxiv.org/pdf/1803.08494.pdf>`_. Group normalization
|
|
|
|
|
divides the channels into groups and computes within each group the mean and variance for normalization,
|
|
|
|
|
and it performs very stable over a wide range of batch size. It can be described using the following formula.
|
|
|
|
|
|
|
|
|
|
.. math::
|
|
|
|
|
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
num_groups (int): The number of groups to be divided along the channel dimension.
|
|
|
|
|
num_channels (int): The number of channels per group.
|
|
|
|
|
eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
|
|
|
|
|
affine (bool): A bool value, this layer will has learnable affine parameters when set to true. Default: True.
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
- **input_x** (Tensor) - The input feature with shape [N, C, H, W].
|
|
|
|
|
|
|
|
|
|
Outputs:
|
|
|
|
|
Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `input_x`.
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
>>> goup_norm_op = nn.GroupNorm(16, 64)
|
|
|
|
|
>>> x = Tensor(np.ones([1, 64, 256, 256], np.float32))
|
|
|
|
|
>>> goup_norm_op(x)
|
|
|
|
|
"""
|
|
|
|
|
def __init__(self, num_groups, num_channels, eps=1e-05, affine=True):
|
|
|
|
|
super(GroupNorm, self).__init__()
|
|
|
|
|
self.num_groups = check_int_positive(num_groups)
|
|
|
|
|
self.num_channels = check_int_positive(num_channels)
|
|
|
|
|
if num_channels % num_groups != 0:
|
|
|
|
|
raise ValueError("num_channels should be divided by num_groups")
|
|
|
|
|
self.eps = Tensor(check_typename('eps', eps, (float,)),mstype.float32)
|
|
|
|
|
self.affine = check_bool(affine)
|
|
|
|
|
|
|
|
|
|
gamma = initializer('ones', [num_channels, 1, 1], mstype.float32)
|
|
|
|
|
beta = initializer('zeros', [num_channels, 1, 1], mstype.float32)
|
|
|
|
|
if self.affine:
|
|
|
|
|
self.gamma = Parameter(gamma, name='gamma')
|
|
|
|
|
self.beta = Parameter(beta, name='beta')
|
|
|
|
|
else:
|
|
|
|
|
self.gamma = gamma
|
|
|
|
|
self.beta = beta
|
|
|
|
|
self.shape = F.shape
|
|
|
|
|
self.reshape = F.reshape
|
|
|
|
|
self.reduce_mean = P.ReduceMean(keep_dims=True)
|
|
|
|
|
self.square = F.square
|
|
|
|
|
self.reduce_sum = P.ReduceSum(keep_dims=True)
|
|
|
|
|
self.sqrt = P.Sqrt()
|
|
|
|
|
|
|
|
|
|
def construct(self, x):
|
|
|
|
|
batch,channel,height,width = self.shape(x)
|
|
|
|
|
x = self.reshape(x,(batch, self.num_groups,channel*height*width/self.num_groups))
|
|
|
|
|
mean = self.reduce_mean(x, 2)
|
|
|
|
|
var = self.reduce_sum(self.square(x - mean), 2) / (channel * height * width / self.num_groups - 1)
|
|
|
|
|
std = self.sqrt(var + self.eps)
|
|
|
|
|
x = (x - mean) / std
|
|
|
|
|
x = self.reshape(x, (batch, channel, height, width))
|
|
|
|
|
output = x * self.gamma + self.beta
|
|
|
|
|
return output
|
|
|
|
|
|
|
|
|
|
def extend_repr(self):
|
|
|
|
|
return 'num_groups={}, num_channels={}'.format(self.num_groups,self.num_channels)
|