!2061 set allreduce fusion split indices for NEZHA bert

Merge pull request !2061 from shibeiji/master
pull/2061/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit d251ba9e11

@ -84,8 +84,14 @@ def run_pretrain():
device_num=device_num) device_num=device_num)
from mindspore.parallel._auto_parallel_context import auto_parallel_context from mindspore.parallel._auto_parallel_context import auto_parallel_context
if bert_net_cfg.num_hidden_layers == 12: if bert_net_cfg.num_hidden_layers == 12:
if bert_net_cfg.use_relative_positions:
auto_parallel_context().set_all_reduce_fusion_split_indices([29, 58, 87, 116, 145, 174, 203, 217])
else:
auto_parallel_context().set_all_reduce_fusion_split_indices([28, 55, 82, 109, 136, 163, 190, 205]) auto_parallel_context().set_all_reduce_fusion_split_indices([28, 55, 82, 109, 136, 163, 190, 205])
elif bert_net_cfg.num_hidden_layers == 24: elif bert_net_cfg.num_hidden_layers == 24:
if bert_net_cfg.use_relative_positions:
auto_parallel_context().set_all_reduce_fusion_split_indices([30, 90, 150, 210, 270, 330, 390, 421])
else:
auto_parallel_context().set_all_reduce_fusion_split_indices([38, 93, 148, 203, 258, 313, 368, 397]) auto_parallel_context().set_all_reduce_fusion_split_indices([38, 93, 148, 203, 258, 313, 368, 397])
D.init() D.init()
rank = args_opt.device_id % device_num rank = args_opt.device_id % device_num

Loading…
Cancel
Save