| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				 | 
				
					@ -15,6 +15,7 @@ limitations under the License. */
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					#include <algorithm>
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					#include <atomic>
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					#include "paddle/framework/data_transform.h"
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					#include "paddle/framework/executor.h"
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					#include "paddle/framework/lod_tensor_array.h"
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					#include "paddle/framework/operator.h"
 | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				 | 
				
					@ -411,7 +412,38 @@ void OperatorWithKernel::Run(const Scope& scope,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                 expected_kernel_key);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  }
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  kernel_iter->second->Compute(ctx);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  if (actual_kernel_key == expected_kernel_key) {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    kernel_iter->second->Compute(ctx);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  } else {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    Scope& op_scope = scope.NewScope();
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    auto input_vars = this->InputVars();
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    for (auto var_name : input_vars) {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					      op_scope.Var(var_name);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    }
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    // TODO(qijun) get appropriate DeviceContext from DeviceContext pool
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    platform::DeviceContext* trans_dev_ctx = nullptr;
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    std::vector<platform::DeviceContext*> trans_dev_ctx_vec{trans_dev_ctx};
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    // TODO(qijun) get appropriate DataTransformFN from global map
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    framework::DataTransformFN trans_fun = nullptr;
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    // Wait for transform starting
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    dev_ctx->Wait();
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    for (auto var_name : input_vars) {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					      trans_fun(trans_dev_ctx_vec, *(scope.FindVar(var_name)),
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                op_scope.FindVar(var_name));
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    }
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    // Wait for data transform finishing
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    for (auto ctx : trans_dev_ctx_vec) {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					      ctx->Wait();
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    }
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    // Create a new ExecutionContext
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    ExecutionContext op_ctx(*this, op_scope, *dev_ctx);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    kernel_iter->second->Compute(op_ctx);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  }
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					}
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					OpKernelType OperatorWithKernel::GetActualKernelType(
 | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
				
				 | 
				 | 
				
					
 
 |