add global batch normalization

pull/399/head
zhaojichen 5 years ago
parent 27c3076849
commit b5e98042c5

@ -111,8 +111,8 @@ class _BatchNorm(Cell):
def list_group(self, world_rank, group_size):
if group_size > get_local_rank_size():
raise ValueError("group size can not be greater than local rank size, group size is {}, local_rank_size is {}".format(
group_size, get_local_rank_size()))
raise ValueError("group size can not be greater than local rank size, group size is {}, "
"local_rank_size is {}".format(group_size, get_local_rank_size()))
if len(world_rank) % group_size != 0:
raise ValueError("please make your group size correct.")
world_rank_list = zip(*(iter(world_rank),) *group_size)

Loading…
Cancel
Save