From 90e2f7555d9d3778b6d02ef91e08ec160595d6a4 Mon Sep 17 00:00:00 2001
From: zhaojichen
Date: Tue, 28 Apr 2020 09:27:57 -0400
Subject: [PATCH 1/7] fix globalbatchnorm bug

---
 mindspore/nn/layer/normalization.py | 71 +++++++++++++++++++++--------
 1 file changed, 52 insertions(+), 19 deletions(-)

diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py
index a4062a7a54..f9316c7c11 100644
--- a/mindspore/nn/layer/normalization.py
+++ b/mindspore/nn/layer/normalization.py
@@ -76,6 +76,10 @@ class _BatchNorm(Cell):
         self.shape = P.Shape()
         self.reduce_mean = P.ReduceMean()
         self.square = P.Square()
+        self.sqrt = P.Sqrt()
+        self.cast = P.Cast()
+        self.dtype = P.DType()
+        self.reshape = P.Reshape()
 
         if context.get_context("enable_ge"):
             self.is_ge_backend = True
@@ -112,29 +116,52 @@ class _BatchNorm(Cell):
         group_list = [list(i) for i in world_rank_list]
         return group_list
 
+    def _global_sync(self, x):
+        if len(self.shape(x)) == 4:
+            axes = (0, 2, 3)
+            re_shape = (1, self.num_features, 1, 1)
+            x_mean = self.reduce_mean(x, axes)
+            x_mean_square = self.reduce_mean(self.square(x), axes)
+            global_batch_mean = self.all_reduce(x_mean) / self.group
+            global_batch_mean_square = self.all_reduce(x_mean_square) / self.group
+            global_mean = global_batch_mean
+            global_var = global_batch_mean_square - self.square(global_mean)
+            var_sqrt = self.sqrt(global_var + self.eps)
+            mean_first = (x - global_mean) / var_sqrt
+            y = mean_first * self.reshape(self.gamma, re_shape) + self.reshape(self.beta, re_shape)
+
+            mean_sub = self.sub_mean(self.reshape(self.moving_mean, re_shape), global_mean)
+            tmp_mean = self.mul_mean(mean_sub, self.cast(self.momentum, self.dtype(mean_sub)))
+            mean_sub2 = self.sub_var(self.reshape(self.moving_mean, re_shape), global_var)
+            tmp_variance = self.mul_var(mean_sub2, self.cast(self.momentum, self.dtype(mean_sub2)))
+            y = F.depend(y, self.assign_sub_mean(self.reshape(self.moving_mean, re_shape), tmp_mean))
+            y = F.depend(y, self.assign_sub_var(self.reshape(self.moving_variance, re_shape), tmp_variance))
+        else:
+            axes = (0,)
+            re_shape = (1, self.num_features)
+            x_mean = self.reduce_mean(x, axes)
+            x_mean_square = self.reduce_mean(self.square(x), axes)
+            global_batch_mean = self.all_reduce(x_mean) / self.group
+            global_batch_mean_square = self.all_reduce(x_mean_square) / self.group
+            global_mean = global_batch_mean
+            global_var = global_batch_mean_square - self.square(global_mean)
+            var_sqrt = self.sqrt(global_var + self.eps)
+            mean_first = (x - global_mean) / var_sqrt
+            y = mean_first * self.gamma + self.beta
+
+            mean_sub = self.sub_mean(self.moving_mean, global_mean)
+            temp_mean = self.mul_mean(mean_sub, self.cast(self.momentum, self.dtype(mean_sub)))
+            mean_sub2 = self.sub_var(self.moving_variance, global_var)
+            temp_variance = self.mul_var(mean_sub2, self.cast(self.momentum, self.dtype(mean_sub2)))
+            y = F.depend(y, self.assign_sub_mean(self.reshape(self.moving_mean, re_shape), temp_mean))
+            y = F.depend(y, self.assign_sub_var(self.reshape(self.moving_variance, re_shape), temp_variance))
+        return y
+
     def construct(self, x):
         if self.training and self.use_batch_statistics:
             if self.is_ge_backend:
                 if self.is_global:
-                    x_mean = self.reduce_mean(x)
-                    x_mean_square = self.reduce_mean(self.square(x))
-                    global_batch_mean = self.all_reduce(x_mean) / self.group
-                    global_batch_mean_square = self.all_reduce(x_mean_square) / self.group
-                    global_mean = global_batch_mean
-                    global_var = global_batch_mean_square - self.square(global_batch_mean)
-                    y, batch_mean, batch_var, _, _ = \
-                        self.bn_train(x,
-                                      self.gamma,
-                                      self.beta,
-                                      None,
-                                      None)
-
-                    mean_sub = self.sub_mean(self.moving_mean, global_mean)
-                    temp_mean = self.mul_mean(mean_sub, self.momentum)
-                    mean_sub2 = self.sub_var(self.moving_variance, global_var)
-                    temp_variance = self.mul_var(mean_sub2, self.momentum)
-                    y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean))
-                    y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance))
+                    y = self._global_sync(x)
                 else:
                     y, batch_mean, batch_var, _, _ = \
                         self.bn_train(x,
@@ -474,6 +501,12 @@ class GroupNorm(Cell):
         num_channels (int): The number of channels per group.
         eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
         affine (bool): A bool value, this layer will has learnable affine parameters when set to true. Default: True.
+        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
+            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
+            'he_uniform', etc. Default: 'ones'.
+        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
+            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
+            'he_uniform', etc. Default: 'zeros'.
 
     Inputs:
         - **input_x** (Tensor) - The input feature with shape [N, C, H, W].

From eb46dd9198b358a8fac4fbceff260f6a363f3b8a Mon Sep 17 00:00:00 2001
From: zhaojichen
Date: Tue, 28 Apr 2020 09:56:28 -0400
Subject: [PATCH 2/7] fix globalbatchnorm bug

---
 mindspore/nn/layer/normalization.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py
index f9316c7c11..9d623bc6fd 100644
--- a/mindspore/nn/layer/normalization.py
+++ b/mindspore/nn/layer/normalization.py
@@ -74,7 +74,7 @@ class _BatchNorm(Cell):
                 management.create_group('group' + str(i), self.rank_list[i])
                 self.all_reduce = P.AllReduce(P.ReduceOp.SUM, 'group' + str(i)).add_prim_attr('fusion', 1)
         self.shape = P.Shape()
-        self.reduce_mean = P.ReduceMean()
+        self.reduce_mean = P.ReduceMean(keep_dims=True)
         self.square = P.Square()
         self.sqrt = P.Sqrt()
         self.cast = P.Cast()

From 0ba35eaec3c73867056e05259e36b44520355535 Mon Sep 17 00:00:00 2001
From: zhaojichen
Date: Tue, 28 Apr 2020 21:09:15 -0400
Subject: [PATCH 3/7] fix globalbatchnorm bug

---
 mindspore/nn/layer/normalization.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py
index 9d623bc6fd..66f17e3f38 100644
--- a/mindspore/nn/layer/normalization.py
+++ b/mindspore/nn/layer/normalization.py
@@ -117,6 +117,7 @@ class _BatchNorm(Cell):
         return group_list
 
     def _global_sync(self, x):
+        """calculate global batch normalization output"""
         if len(self.shape(x)) == 4:
             axes = (0, 2, 3)
             re_shape = (1, self.num_features, 1, 1)

From 8ca1f87a49567d90f83850e268b095a843cf69d0 Mon Sep 17 00:00:00 2001
From: zhaojichen
Date: Wed, 29 Apr 2020 02:56:42 -0400
Subject: [PATCH 4/7] fix globalbatchnorm bug

---
 mindspore/nn/layer/normalization.py | 63 +++++++++++++----------
 1 file changed, 27 insertions(+), 36 deletions(-)

diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py
index 66f17e3f38..dd4ac67273 100644
--- a/mindspore/nn/layer/normalization.py
+++ b/mindspore/nn/layer/normalization.py
@@ -116,53 +116,44 @@ class _BatchNorm(Cell):
         group_list = [list(i) for i in world_rank_list]
         return group_list
 
-    def _global_sync(self, x):
-        """calculate global batch normalization output"""
+    def _shape_infer(self, x):
+        """global batch normalization shape and axes infer"""
         if len(self.shape(x)) == 4:
-            axes = (0, 2, 3)
+            axes = (0,2,3)
             re_shape = (1, self.num_features, 1, 1)
-            x_mean = self.reduce_mean(x, axes)
-            x_mean_square = self.reduce_mean(self.square(x), axes)
-            global_batch_mean = self.all_reduce(x_mean) / self.group
-            global_batch_mean_square = self.all_reduce(x_mean_square) / self.group
-            global_mean = global_batch_mean
-            global_var = global_batch_mean_square - self.square(global_mean)
-            var_sqrt = self.sqrt(global_var + self.eps)
-            mean_first = (x - global_mean) / var_sqrt
-            y = mean_first * self.reshape(self.gamma, re_shape) + self.reshape(self.beta, re_shape)
-
-            mean_sub = self.sub_mean(self.reshape(self.moving_mean, re_shape), global_mean)
-            tmp_mean = self.mul_mean(mean_sub, self.cast(self.momentum, self.dtype(mean_sub)))
-            mean_sub2 = self.sub_var(self.reshape(self.moving_mean, re_shape), global_var)
-            tmp_variance = self.mul_var(mean_sub2, self.cast(self.momentum, self.dtype(mean_sub2)))
-            y = F.depend(y, self.assign_sub_mean(self.reshape(self.moving_mean, re_shape), tmp_mean))
-            y = F.depend(y, self.assign_sub_var(self.reshape(self.moving_variance, re_shape), tmp_variance))
         else:
             axes = (0,)
             re_shape = (1, self.num_features)
-            x_mean = self.reduce_mean(x, axes)
-            x_mean_square = self.reduce_mean(self.square(x), axes)
-            global_batch_mean = self.all_reduce(x_mean) / self.group
-            global_batch_mean_square = self.all_reduce(x_mean_square) / self.group
-            global_mean = global_batch_mean
-            global_var = global_batch_mean_square - self.square(global_mean)
-            var_sqrt = self.sqrt(global_var + self.eps)
-            mean_first = (x - global_mean) / var_sqrt
-            y = mean_first * self.gamma + self.beta
-
-            mean_sub = self.sub_mean(self.moving_mean, global_mean)
-            temp_mean = self.mul_mean(mean_sub, self.cast(self.momentum, self.dtype(mean_sub)))
-            mean_sub2 = self.sub_var(self.moving_variance, global_var)
-            temp_variance = self.mul_var(mean_sub2, self.cast(self.momentum, self.dtype(mean_sub2)))
-            y = F.depend(y, self.assign_sub_mean(self.reshape(self.moving_mean, re_shape), temp_mean))
-            y = F.depend(y, self.assign_sub_var(self.reshape(self.moving_variance, re_shape), temp_variance))
+        return axes, re_shape
+
+    def _global_sync(self, x, axes, re_shape):
+        """calculate global batch normalization output"""
+        axes = (0, 2, 3)
+        re_shape = (1, self.num_features, 1, 1)
+        x_mean = self.reduce_mean(x, axes)
+        x_mean_square = self.reduce_mean(self.square(x), axes)
+        global_batch_mean = self.all_reduce(x_mean) / self.group
+        global_batch_mean_square = self.all_reduce(x_mean_square) / self.group
+        global_mean = global_batch_mean
+        global_var = global_batch_mean_square - self.square(global_mean)
+        var_sqrt = self.sqrt(global_var + self.eps)
+        mean_first = (x - global_mean) / var_sqrt
+        y = mean_first * self.reshape(self.gamma, re_shape) + self.reshape(self.beta, re_shape)
+
+        mean_sub = self.sub_mean(self.reshape(self.moving_mean, re_shape), global_mean)
+        tmp_mean = self.mul_mean(mean_sub, self.cast(self.momentum, self.dtype(mean_sub)))
+        mean_sub2 = self.sub_var(self.reshape(self.moving_mean, re_shape), global_var)
+        tmp_variance = self.mul_var(mean_sub2, self.cast(self.momentum, self.dtype(mean_sub2)))
+        y = F.depend(y, self.assign_sub_mean(self.reshape(self.moving_mean, re_shape), tmp_mean))
+        y = F.depend(y, self.assign_sub_var(self.reshape(self.moving_variance, re_shape), tmp_variance))
         return y
 
     def construct(self, x):
         if self.training and self.use_batch_statistics:
             if self.is_ge_backend:
                 if self.is_global:
-                    y = self._global_sync(x)
+                    axes, re_shape = self._shape_infer(x)
+                    y = self._global_sync(x, axes, re_shape)
                 else:
                     y, batch_mean, batch_var, _, _ = \
                         self.bn_train(x,

From 8261cfd01902be5a0f4f14a16ca51d0938dd6d3f Mon Sep 17 00:00:00 2001
From: zhaojichen
Date: Wed, 29 Apr 2020 03:00:53 -0400
Subject: [PATCH 5/7] fix globalbatchnorm bug

---
 mindspore/nn/layer/normalization.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py
index dd4ac67273..6e92369550 100644
--- a/mindspore/nn/layer/normalization.py
+++ b/mindspore/nn/layer/normalization.py
@@ -128,8 +128,6 @@ class _BatchNorm(Cell):
 
     def _global_sync(self, x, axes, re_shape):
         """calculate global batch normalization output"""
-        axes = (0, 2, 3)
-        re_shape = (1, self.num_features, 1, 1)
         x_mean = self.reduce_mean(x, axes)
         x_mean_square = self.reduce_mean(self.square(x), axes)
         global_batch_mean = self.all_reduce(x_mean) / self.group

From 6c9a54afa12ecc722bd29d3a728a3923205f0c03 Mon Sep 17 00:00:00 2001
From: zhaojichen
Date: Wed, 29 Apr 2020 03:34:58 -0400
Subject: [PATCH 6/7] fix globalbatchnorm bug

---
 mindspore/nn/layer/normalization.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py
index 6e92369550..2a1ca28ed4 100644
--- a/mindspore/nn/layer/normalization.py
+++ b/mindspore/nn/layer/normalization.py
@@ -119,7 +119,7 @@ class _BatchNorm(Cell):
     def _shape_infer(self, x):
         """global batch normalization shape and axes infer"""
         if len(self.shape(x)) == 4:
-            axes = (0,2,3)
+            axes = (0, 2, 3)
             re_shape = (1, self.num_features, 1, 1)
         else:
             axes = (0,)

From 7b81ca68dc8ee17a6daf624e3eb215ec3cd48f92 Mon Sep 17 00:00:00 2001
From: zhaojichen
Date: Wed, 29 Apr 2020 04:53:23 -0400
Subject: [PATCH 7/7] fix globalbatchnorm bug

---
 mindspore/nn/layer/normalization.py | 23 +++++++++++++----------
 1 file changed, 13 insertions(+), 10 deletions(-)

diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py
index 2a1ca28ed4..7a102b0bbe 100644
--- a/mindspore/nn/layer/normalization.py
+++ b/mindspore/nn/layer/normalization.py
@@ -116,15 +116,7 @@ class _BatchNorm(Cell):
         group_list = [list(i) for i in world_rank_list]
         return group_list
 
-    def _shape_infer(self, x):
-        """global batch normalization shape and axes infer"""
-        if len(self.shape(x)) == 4:
-            axes = (0, 2, 3)
-            re_shape = (1, self.num_features, 1, 1)
-        else:
-            axes = (0,)
-            re_shape = (1, self.num_features)
-        return axes, re_shape
+
 
     def _global_sync(self, x, axes, re_shape):
         """calculate global batch normalization output"""
@@ -150,7 +142,7 @@ class _BatchNorm(Cell):
         if self.training and self.use_batch_statistics:
             if self.is_ge_backend:
                 if self.is_global:
-                    axes, re_shape = self._shape_infer(x)
+                    axes, re_shape = _shape_infer(F.shape(x), self.num_features)
                     y = self._global_sync(x, axes, re_shape)
                 else:
                     y, batch_mean, batch_var, _, _ = \
@@ -189,6 +181,17 @@ def _channel_check(channel, num_channel):
     if channel != num_channel:
         raise ValueError("the input channel is not equal with num_channel")
 
+@constexpr
+def _shape_infer(x_shape, num_feature):
+    """global batch normalization shape and axes infer"""
+    if len(x_shape) == 4:
+        axes = (0, 2, 3)
+        re_shape = (1, num_feature, 1, 1)
+    else:
+        axes = (0,)
+        re_shape = (1, num_feature)
+    return axes, re_shape
+
 class BatchNorm1d(_BatchNorm):
     r"""
     Batch normalization layer over a 2D input.
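
Background on the math these patches restructure: _global_sync computes per-device channel statistics, all-reduces them across the device group, recovers the global variance as E[x^2] - E[x]^2, normalizes every device's slice with the shared statistics, and updates the moving statistics via AssignSub, i.e. moving -= (moving - global) * momentum, which equals moving = (1 - momentum) * moving + momentum * global. Below is a minimal standalone NumPy sketch of that logic for the 4D (NCHW) path; it is illustrative only, and the names all_reduce_sum and global_batch_norm are hypothetical, not MindSpore APIs.

import numpy as np

def all_reduce_sum(per_device_values):
    # Stand-in for P.AllReduce(P.ReduceOp.SUM): every device receives the sum.
    return sum(per_device_values)

def global_batch_norm(xs, gamma, beta, moving_mean, moving_var,
                      momentum=0.9, eps=1e-5):
    # xs: one NCHW batch per device in the synchronization group
    # (equal per-device batch sizes assumed, as in the patches).
    group = len(xs)
    re_shape = (1, -1, 1, 1)
    # Per-device mean and mean-of-squares over (N, H, W), dims kept for
    # broadcasting -- this is why patch 2 sets keep_dims=True on ReduceMean.
    means = [x.mean(axis=(0, 2, 3), keepdims=True) for x in xs]
    mean_sqs = [(x ** 2).mean(axis=(0, 2, 3), keepdims=True) for x in xs]
    # All-reduce, then divide by the group size to get global-batch statistics.
    global_mean = all_reduce_sum(means) / group
    global_var = all_reduce_sum(mean_sqs) / group - global_mean ** 2  # E[x^2] - E[x]^2
    # Normalize each device's slice with the shared global statistics.
    ys = [(x - global_mean) / np.sqrt(global_var + eps)
          * gamma.reshape(re_shape) + beta.reshape(re_shape) for x in xs]
    # AssignSub-style update of the moving statistics.
    moving_mean -= (moving_mean - global_mean.reshape(-1)) * momentum
    moving_var -= (moving_var - global_var.reshape(-1)) * momentum
    return ys, moving_mean, moving_var

The key difference from per-device batch norm is that each device still normalizes only its own slice, but the mean and variance come from the whole global batch, so the result matches a single-device run over the concatenated batch.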