```diff
@@ -45,11 +45,18 @@ _current_dir = os.path.dirname(os.path.realpath(__file__))
 
 def _set_bert_all_reduce_split():
     """set bert all_reduce fusion split, support num_hidden_layers is 12 and 24."""
+    device_target = context.get_context('device_target')
+    enable_graph_kernel = context.get_context('enable_graph_kernel')
+    device_num = context.get_auto_parallel_context('device_num')
     if bert_net_cfg.num_hidden_layers == 12:
         if bert_net_cfg.use_relative_positions:
             context.set_auto_parallel_context(all_reduce_fusion_config=[29, 58, 87, 116, 145, 174, 203, 217])
         else:
             context.set_auto_parallel_context(all_reduce_fusion_config=[28, 55, 82, 109, 136, 163, 190, 205])
+            if device_target == 'GPU' and enable_graph_kernel and device_num == 8:
+                context.set_auto_parallel_context(all_reduce_fusion_config=[180, 205])
+            elif device_target == 'GPU' and enable_graph_kernel and device_num == 16:
+                context.set_auto_parallel_context(all_reduce_fusion_config=[120, 205])
     elif bert_net_cfg.num_hidden_layers == 24:
         if bert_net_cfg.use_relative_positions:
             context.set_auto_parallel_context(all_reduce_fusion_config=[30, 90, 150, 210, 270, 330, 390, 421])
```
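The GPU branches added above only fire when graph kernel is enabled and the cluster size is 8 or 16. `all_reduce_fusion_config` lists ascending split points at which consecutive gradient tensors are fused into one all-reduce, so replacing the eight-way split with `[180, 205]` or `[120, 205]` trades many small all-reduce operations for two large fused ones. A framework-free sketch of that bucketing; the helper name `split_fusion_buckets` and the gradient count of 205 (inferred from the final split index) are illustrative assumptions, not part of the patch:

```python
# Illustrative sketch of all_reduce_fusion_config semantics: for ascending
# split points s1 < s2 < ..., gradients [0, s1), [s1, s2), ... are each
# fused into a single all-reduce.
def split_fusion_buckets(num_grads, fusion_config):
    """Partition gradient indices into fused all-reduce buckets."""
    buckets, start = [], 0
    for split in fusion_config:
        buckets.append(range(start, min(split, num_grads)))
        start = split
    if start < num_grads:  # gradients past the last split form a final bucket
        buckets.append(range(start, num_grads))
    return buckets

# Assuming 205 gradient tensors for 12-layer BERT without relative positions:
print([len(b) for b in split_fusion_buckets(205, [28, 55, 82, 109, 136, 163, 190, 205])])
print([len(b) for b in split_fusion_buckets(205, [180, 205])])  # -> [180, 25]
```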
```diff
@@ -119,8 +126,7 @@ def _get_optimizer(args_opt, network):
 def _auto_enable_graph_kernel(device_target, graph_kernel_mode):
     """Judge whether is suitable to enable graph kernel."""
     return graph_kernel_mode in ("auto", "true") and device_target == 'GPU' and \
-        cfg.bert_network == 'base' and (cfg.batch_size == 32 or cfg.batch_size == 64) and \
-        cfg.optimizer == 'AdamWeightDecay'
+        cfg.bert_network == 'base' and cfg.optimizer == 'AdamWeightDecay'
 
 
 def _set_graph_kernel_context(device_target, enable_graph_kernel, is_auto_enable_graph_kernel):
```
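This hunk relaxes the auto-enable heuristic: the `batch_size == 32 or batch_size == 64` restriction is gone, so graph kernel is now considered for any batch size as long as the network is `base` on GPU with the `AdamWeightDecay` optimizer. A hedged sketch of the resulting predicate as a pure function; the standalone name and explicit parameters are assumptions made for testability, since the original reads a module-level `cfg`:

```python
def should_auto_enable_graph_kernel(device_target, graph_kernel_mode,
                                    bert_network, optimizer):
    """Mirror of the relaxed heuristic: graph kernel requested (or 'auto'),
    GPU target, 'base' network, AdamWeightDecay; no batch-size restriction."""
    return (graph_kernel_mode in ("auto", "true")
            and device_target == 'GPU'
            and bert_network == 'base'
            and optimizer == 'AdamWeightDecay')

assert should_auto_enable_graph_kernel('GPU', 'auto', 'base', 'AdamWeightDecay')
assert not should_auto_enable_graph_kernel('Ascend', 'auto', 'base', 'AdamWeightDecay')
assert not should_auto_enable_graph_kernel('GPU', 'false', 'base', 'AdamWeightDecay')
```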
```diff
@@ -131,10 +137,15 @@ def _set_graph_kernel_context(device_target, enable_graph_kernel, is_auto_enable
             logger.warning('Graph kernel only supports GPU back-end now, run with graph kernel off.')
 
 
-def _check_compute_type(device_target, is_auto_enable_graph_kernel):
-    if device_target == 'GPU' and bert_net_cfg.compute_type != mstype.float32 and not is_auto_enable_graph_kernel:
-        logger.warning('Gpu only support fp32 temporarily, run with fp32.')
+def _check_compute_type(args_opt, is_auto_enable_graph_kernel):
+    if args_opt.device_target == 'GPU' and bert_net_cfg.compute_type != mstype.float32 and \
+       not is_auto_enable_graph_kernel:
+        warning_message = 'Gpu only support fp32 temporarily, run with fp32.'
         bert_net_cfg.compute_type = mstype.float32
+        if args_opt.enable_lossscale == "true":
+            args_opt.enable_lossscale = "false"
+            warning_message = 'Gpu only support fp32 temporarily, run with fp32 and disable lossscale.'
+        logger.warning(warning_message)
 
 
 def argparse_init():
```
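Besides taking `args_opt` instead of a bare `device_target`, the rewritten `_check_compute_type` now also turns off loss scaling when it forces fp32. That is consistent: loss scaling exists to keep fp16 gradients representable, so it becomes redundant once the compute type falls back to float32. A dependency-free sketch of the new branch; the function name and the plain strings standing in for `mstype` values are illustrative:

```python
def fp32_fallback(device_target, compute_type, enable_lossscale,
                  is_auto_enable_graph_kernel=False):
    """Return (compute_type, enable_lossscale, warning) after the GPU fp32 check."""
    if device_target != 'GPU' or compute_type == 'float32' or is_auto_enable_graph_kernel:
        return compute_type, enable_lossscale, None  # nothing to do
    warning = 'Gpu only support fp32 temporarily, run with fp32.'
    if enable_lossscale == "true":  # loss scaling is pointless in fp32
        enable_lossscale = "false"
        warning = 'Gpu only support fp32 temporarily, run with fp32 and disable lossscale.'
    return 'float32', enable_lossscale, warning

print(fp32_fallback('GPU', 'float16', 'true'))
# -> ('float32', 'false', 'Gpu only support fp32 temporarily, run with fp32 and disable lossscale.')
```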
```diff
@@ -180,6 +191,8 @@ def run_pretrain():
     args_opt = parser.parse_args()
     context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
     context.set_context(reserve_class_name_in_scope=False)
+    is_auto_enable_graph_kernel = _auto_enable_graph_kernel(args_opt.device_target, args_opt.enable_graph_kernel)
+    _set_graph_kernel_context(args_opt.device_target, args_opt.enable_graph_kernel, is_auto_enable_graph_kernel)
    ckpt_save_dir = args_opt.save_checkpoint_path
     if args_opt.distribute == "true":
         if args_opt.device_target == 'Ascend':
```
```diff
@@ -195,15 +208,12 @@ def run_pretrain():
         context.reset_auto_parallel_context()
         context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
                                           device_num=device_num)
-        if args_opt.device_target == 'Ascend':
-            _set_bert_all_reduce_split()
+        _set_bert_all_reduce_split()
     else:
         rank = 0
         device_num = 1
 
-    is_auto_enable_graph_kernel = _auto_enable_graph_kernel(args_opt.device_target, args_opt.enable_graph_kernel)
-    _set_graph_kernel_context(args_opt.device_target, args_opt.enable_graph_kernel, is_auto_enable_graph_kernel)
-    _check_compute_type(args_opt.device_target, is_auto_enable_graph_kernel)
+    _check_compute_type(args_opt, is_auto_enable_graph_kernel)
 
     if args_opt.accumulation_steps > 1:
         logger.info("accumulation steps: {}".format(args_opt.accumulation_steps))
```
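The last two hunks hoist the graph kernel decision above the distribute branch and leave only `_check_compute_type` after it. The ordering matters because the patched `_set_bert_all_reduce_split` (first hunk) reads `enable_graph_kernel` back out of the context, so the graph kernel context has to be set before the parallel branch calls it. A runnable toy of the resulting control flow with every helper stubbed; only the call ordering is faithful to the patch:

```python
# Toy model: the stubs share a dict standing in for the MindSpore context.
_ctx = {}

def _auto_enable_graph_kernel(device_target, mode):            # stub
    return mode in ("auto", "true") and device_target == 'GPU'

def _set_graph_kernel_context(device_target, mode, is_auto):   # stub
    _ctx['enable_graph_kernel'] = is_auto or mode == "true"

def _set_bert_all_reduce_split():                              # stub
    # mirrors the first hunk: reads the flag back from the context
    print('fusion split sees enable_graph_kernel =', _ctx['enable_graph_kernel'])

def _check_compute_type(args_opt, is_auto):                    # stub
    print('compute type checked, auto graph kernel =', is_auto)

def run_pretrain(args_opt):
    # graph kernel is configured first, before the distribute branch ...
    is_auto = _auto_enable_graph_kernel(args_opt['device_target'],
                                        args_opt['enable_graph_kernel'])
    _set_graph_kernel_context(args_opt['device_target'],
                              args_opt['enable_graph_kernel'], is_auto)
    if args_opt['distribute'] == "true":
        _set_bert_all_reduce_split()   # ... so this call sees the final flag
    _check_compute_type(args_opt, is_auto)

run_pretrain({'device_target': 'GPU', 'enable_graph_kernel': 'auto', 'distribute': 'true'})
```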