diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py index 3acac0b5df..748ab7be3d 100644 --- a/mindspore/nn/layer/normalization.py +++ b/mindspore/nn/layer/normalization.py @@ -81,7 +81,7 @@ class _BatchNorm(Cell): self.parallel_mode = context.get_auto_parallel_context("parallel_mode") global SYNC_BN_GROUP_NAME # for GlobalBatchNorm - if self.group_device_num != 1 and self.parallel_mode != context.ParallelMode.STAND_ALONE: + if self.group_device_num != 1: self.rank_id = get_rank() self.rank_size = get_group_size() self.device_list = [i for i in range(0, self.rank_size)]