@@ -71,7 +71,6 @@ class ParallelExecutor(object):
                  num_trainers=1,
                  trainer_id=0,
                  **kwargs):
-
         if len(kwargs) != 0:
             err_msg = ""
             for key in kwargs:
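Note on the hunk above: the constructor takes `num_trainers`, `trainer_id`, and a catch-all `**kwargs`, and the guard that follows rejects unexpected keywords; the hunk only removes a blank line and the loop body falls outside it. A minimal sketch of how such a guard is typically completed — the message text is illustrative, not Paddle's actual wording:

```python
# Illustrative sketch only; the real error text is not shown in this hunk.
def _check_kwargs(**kwargs):
    if len(kwargs) != 0:
        err_msg = ""
        for key in kwargs:
            # Hypothetical wording: list each unexpected keyword argument.
            err_msg += "'%s' is not a supported ParallelExecutor argument.\n" % key
        raise TypeError(err_msg)

_check_kwargs()          # passes silently
# _check_kwargs(foo=1)   # would raise TypeError listing 'foo'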
				
			
			
		
	
	
		
		
			
				
@@ -130,6 +129,11 @@ class ParallelExecutor(object):
         main = main_program
         main = main if main else framework.default_main_program()
         scope = executor.global_scope()
+        # FIXME(Yancey1989): temporary approach to detect a distributed train
+        # program: if one is found, self.bcast_params() runs after each mini-batch.
+        self.is_dist = True if "recv" in [
+            op.type for op in main.global_block().ops
+        ] else False
 
         if share_vars_from and not isinstance(share_vars_from,
                                               ParallelExecutor):
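The new `is_dist` flag comes from a plain op-type scan: a program transpiled for distributed training carries `recv` ops that pull parameters back from the parameter servers, so their presence is used as a proxy for "distributed train program". The same heuristic as a standalone sketch, assuming a built `fluid.Program` — `any(...)` is the idiomatic equivalent of the `True if ... in ... else False` expression in the diff:

```python
import paddle.fluid as fluid

def is_distributed_program(program):
    # A transpiled distributed train program contains "recv" ops in its
    # global block; plain single-process programs do not.
    return any(op.type == "recv" for op in program.global_block().ops)

# Usage sketch:
# print(is_distributed_program(fluid.default_main_program()))
```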
@@ -262,6 +266,10 @@ class ParallelExecutor(object):
         fetch_var_name = '@FETCHED_VAR_NAME@'
         self.executor.run(fetch_list, fetch_var_name)
         arr = self.scope.find_var(fetch_var_name).get_lod_tensor_array()
+
+        if self.is_dist:
+            self.bcast_params()
+
         return [arr[i] for i in range(len(arr))]
 
     def bcast_params(self):
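Taken together: when the program is distributed, every `run()` now ends by broadcasting the freshly pulled parameters across devices before results are returned, so callers need no explicit re-sync. A hedged usage sketch of the resulting loop — fluid-era API; `avg_cost` and `train_reader` are placeholders, and feeding is elided because the `feed` signature varied across versions:

```python
import paddle.fluid as fluid

# Placeholder training loop: with a distributed (transpiled) main program,
# each pe.run() call internally finishes with bcast_params(), so all
# devices start the next mini-batch with synchronized parameters.
pe = fluid.ParallelExecutor(use_cuda=True, loss_name=avg_cost.name)
for data in train_reader():
    loss_value, = pe.run(fetch_list=[avg_cost.name])
```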