|
|
|
@ -84,8 +84,14 @@ def run_pretrain():
|
|
|
|
|
device_num=device_num)
|
|
|
|
|
from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
|
|
|
|
if bert_net_cfg.num_hidden_layers == 12:
|
|
|
|
|
if bert_net_cfg.use_relative_positions:
|
|
|
|
|
auto_parallel_context().set_all_reduce_fusion_split_indices([29, 58, 87, 116, 145, 174, 203, 217])
|
|
|
|
|
else:
|
|
|
|
|
auto_parallel_context().set_all_reduce_fusion_split_indices([28, 55, 82, 109, 136, 163, 190, 205])
|
|
|
|
|
elif bert_net_cfg.num_hidden_layers == 24:
|
|
|
|
|
if bert_net_cfg.use_relative_positions:
|
|
|
|
|
auto_parallel_context().set_all_reduce_fusion_split_indices([30, 90, 150, 210, 270, 330, 390, 421])
|
|
|
|
|
else:
|
|
|
|
|
auto_parallel_context().set_all_reduce_fusion_split_indices([38, 93, 148, 203, 258, 313, 368, 397])
|
|
|
|
|
D.init()
|
|
|
|
|
rank = args_opt.device_id % device_num
|
|
|
|
|