diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py index 7c46adc1e4..16b1e9b786 100644 --- a/mindspore/nn/layer/normalization.py +++ b/mindspore/nn/layer/normalization.py @@ -139,8 +139,9 @@ class _BatchNorm(Cell): tmp_mean = self.mul_mean(mean_sub, self.cast(self.momentum, self.dtype(mean_sub))) mean_sub2 = self.sub_var(self.reshape(self.moving_mean, re_shape), global_var) tmp_variance = self.mul_var(mean_sub2, self.cast(self.momentum, self.dtype(mean_sub2))) - y = F.depend(y, self.assign_sub_mean(self.reshape(self.moving_mean, re_shape), tmp_mean)) - y = F.depend(y, self.assign_sub_var(self.reshape(self.moving_variance, re_shape), tmp_variance)) + y = F.depend(y, self.assign_sub_mean(self.moving_mean, self.reshape(tmp_mean, self.shape(self.moving_mean)))) + y = F.depend(y, self.assign_sub_var(self.moving_variance, + self.reshape(tmp_variance, self.shape(self.moving_variance)))) return y def construct(self, x):