From d2b04664cad59608eb4754d2df93eaeaf84d7aca Mon Sep 17 00:00:00 2001 From: zhaojichen Date: Thu, 16 Apr 2020 20:58:59 -0400 Subject: [PATCH] add global batch normalization --- mindspore/nn/layer/normalization.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py index 2b55147cf1..c85b945a0d 100644 --- a/mindspore/nn/layer/normalization.py +++ b/mindspore/nn/layer/normalization.py @@ -110,9 +110,9 @@ class _BatchNorm(Cell): raise NotImplementedError def list_group(self, world_rank, group_size): - if group_size > get_local_rank_size(): + if group_size > get_group_size(): raise ValueError("group size can not be greater than local rank size, group size is {}, " - "local_rank_size is {}".format(group_size, get_local_rank_size())) + "local_rank_size is {}".format(group_size, get_group_size())) if len(world_rank) % group_size != 0: raise ValueError("please make your group size correct.") world_rank_list = zip(*(iter(world_rank),) *group_size)