diff --git a/model_zoo/bert/run_pretrain.py b/model_zoo/bert/run_pretrain.py index ab3d7d63ba..28c021f56f 100644 --- a/model_zoo/bert/run_pretrain.py +++ b/model_zoo/bert/run_pretrain.py @@ -84,9 +84,15 @@ def run_pretrain(): device_num=device_num) from mindspore.parallel._auto_parallel_context import auto_parallel_context if bert_net_cfg.num_hidden_layers == 12: - auto_parallel_context().set_all_reduce_fusion_split_indices([28, 55, 82, 109, 136, 163, 190, 205]) + if bert_net_cfg.use_relative_positions: + auto_parallel_context().set_all_reduce_fusion_split_indices([29, 58, 87, 116, 145, 174, 203, 217]) + else: + auto_parallel_context().set_all_reduce_fusion_split_indices([28, 55, 82, 109, 136, 163, 190, 205]) elif bert_net_cfg.num_hidden_layers == 24: - auto_parallel_context().set_all_reduce_fusion_split_indices([38, 93, 148, 203, 258, 313, 368, 397]) + if bert_net_cfg.use_relative_positions: + auto_parallel_context().set_all_reduce_fusion_split_indices([30, 90, 150, 210, 270, 330, 390, 421]) + else: + auto_parallel_context().set_all_reduce_fusion_split_indices([38, 93, 148, 203, 258, 313, 368, 397]) D.init() rank = args_opt.device_id % device_num else: