@ -128,8 +128,6 @@ class _BatchNorm(Cell):
def _global_sync(self, x, axes, re_shape):
"""calculate global batch normalization output"""
axes = (0, 2, 3)
re_shape = (1, self.num_features, 1, 1)
x_mean = self.reduce_mean(x, axes)
x_mean_square = self.reduce_mean(self.square(x), axes)
global_batch_mean = self.all_reduce(x_mean) / self.group