fix GlobalBatchNorm

pull/12804/head
yuchaojie 4 years ago
parent 6f834db9b2
commit 029df0def3

@ -81,7 +81,7 @@ class _BatchNorm(Cell):
self.parallel_mode = context.get_auto_parallel_context("parallel_mode") self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
global SYNC_BN_GROUP_NAME global SYNC_BN_GROUP_NAME
# for GlobalBatchNorm # for GlobalBatchNorm
if self.group_device_num != 1 and self.parallel_mode != context.ParallelMode.STAND_ALONE: if self.group_device_num != 1:
self.rank_id = get_rank() self.rank_id = get_rank()
self.rank_size = get_group_size() self.rank_size = get_group_size()
self.device_list = [i for i in range(0, self.rank_size)] self.device_list = [i for i in range(0, self.rank_size)]

Loading…
Cancel
Save