From 3e44bf797fbb71e323c211955c057a2f83d860b3 Mon Sep 17 00:00:00 2001 From: liangchenghui Date: Wed, 16 Sep 2020 16:52:11 +0800 Subject: [PATCH] Adjust GroupNorm interface --- mindspore/nn/layer/normalization.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py index 6ca0c0ccdd..7ffbf273cb 100644 --- a/mindspore/nn/layer/normalization.py +++ b/mindspore/nn/layer/normalization.py @@ -573,10 +573,10 @@ class GroupNorm(Cell): affine (bool): A bool value, this layer will have learnable affine parameters when set to true. Default: True. gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight. The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform', - 'he_uniform', etc. Default: 'ones'. + 'he_uniform', etc. Default: 'ones'. If gamma_init is a Tensor, the shape must be [num_channels]. beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight. The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform', - 'he_uniform', etc. Default: 'zeros'. + 'he_uniform', etc. Default: 'zeros'. If beta_init is a Tensor, the shape must be [num_channels]. Inputs: - **input_x** (Tensor) - The input feature with shape [N, C, H, W]. @@ -608,8 +608,8 @@ class GroupNorm(Cell): self.eps = check_typename('eps', eps, (float,)) self.affine = check_bool(affine) - gamma = initializer(gamma_init, [num_channels, 1, 1]) - beta = initializer(beta_init, [num_channels, 1, 1]) + gamma = initializer(gamma_init, num_channels) + beta = initializer(beta_init, num_channels) if self.affine: self.gamma = Parameter(gamma, name='gamma') self.beta = Parameter(beta, name='beta') @@ -633,7 +633,7 @@ class GroupNorm(Cell): std = self.sqrt(var + self.eps) x = (x - mean) / std x = self.reshape(x, (batch, channel, height, width)) - output = x * self.gamma + self.beta + output = x * self.reshape(self.gamma, (-1, 1, 1)) + self.reshape(self.beta, (-1, 1, 1)) return output def construct(self, x):