From abfd3f97cb0915e8baa82d6e3d270530dce8e07b Mon Sep 17 00:00:00 2001 From: shibeiji Date: Fri, 12 Jun 2020 16:03:12 +0800 Subject: [PATCH] add allreduce split strategy for NEZHA bert --- model_zoo/bert/run_pretrain.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/model_zoo/bert/run_pretrain.py b/model_zoo/bert/run_pretrain.py index ab3d7d63ba..28c021f56f 100644 --- a/model_zoo/bert/run_pretrain.py +++ b/model_zoo/bert/run_pretrain.py @@ -84,9 +84,15 @@ def run_pretrain(): device_num=device_num) from mindspore.parallel._auto_parallel_context import auto_parallel_context if bert_net_cfg.num_hidden_layers == 12: - auto_parallel_context().set_all_reduce_fusion_split_indices([28, 55, 82, 109, 136, 163, 190, 205]) + if bert_net_cfg.use_relative_positions: + auto_parallel_context().set_all_reduce_fusion_split_indices([29, 58, 87, 116, 145, 174, 203, 217]) + else: + auto_parallel_context().set_all_reduce_fusion_split_indices([28, 55, 82, 109, 136, 163, 190, 205]) elif bert_net_cfg.num_hidden_layers == 24: - auto_parallel_context().set_all_reduce_fusion_split_indices([38, 93, 148, 203, 258, 313, 368, 397]) + if bert_net_cfg.use_relative_positions: + auto_parallel_context().set_all_reduce_fusion_split_indices([30, 90, 150, 210, 270, 330, 390, 421]) + else: + auto_parallel_context().set_all_reduce_fusion_split_indices([38, 93, 148, 203, 258, 313, 368, 397]) D.init() rank = args_opt.device_id % device_num else: