| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				 | 
				
					@ -179,11 +179,6 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  VLOG(3) << place << " " << DebugStringEx(&scope);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					}
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					void OperatorBase::RunPrepared(const RuntimeContext& ctx,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                               const platform::Place& place) {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  RunImplPrepared(ctx, place);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					}
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					bool OperatorBase::HasInputs(const std::string& name) const {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  return inputs_.find(name) != inputs_.end();
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					}
 | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				 | 
				
					@ -958,51 +953,6 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  }
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					}
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					void OperatorWithKernel::RunImplPrepared(const RuntimeContext& ctx,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                                         const platform::Place& place) const {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  Scope dummy_scope;
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto* dev_ctx = pool.Get(place);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  // check if op[type] has kernel registered.
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto& all_op_kernels = AllOpKernels();
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto kernels_iter = all_op_kernels.find(type_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  if (kernels_iter == all_op_kernels.end()) {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    PADDLE_THROW(
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					        "There are no kernels which are registered in the %s operator.", type_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  }
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  OpKernelMap& kernels = kernels_iter->second;
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto expected_kernel_key = this->GetExpectedKernelType(
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					      ExecutionContext(*this, dummy_scope, *dev_ctx, ctx));
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto kernel_iter = kernels.find(expected_kernel_key);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					#ifdef PADDLE_WITH_MKLDNN
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  // workaround for missing MKLDNN kernel when FLAGS_use_mkldnn env var is set
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  if (kernel_iter == kernels.end() &&
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					      expected_kernel_key.library_type_ == LibraryType::kMKLDNN) {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    VLOG(3) << "missing MKLDNN kernel: fallbacking to PLAIN one";
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    expected_kernel_key.library_type_ = LibraryType::kPlain;
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    expected_kernel_key.data_layout_ = DataLayout::kAnyLayout;
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    kernel_iter = kernels.find(expected_kernel_key);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  }
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					#endif
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  if (kernel_iter == kernels.end()) {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    PADDLE_THROW("op %s does not have kernel for %s", type_,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                 KernelTypeToString(expected_kernel_key));
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  }
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  if (!(expected_kernel_key.place_ == dev_ctx->GetPlace())) {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    dev_ctx = pool.Get(expected_kernel_key.place_);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  }
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  RuntimeInferShapeContext infer_shape_ctx(*this, dummy_scope, ctx);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  this->InferShape(&infer_shape_ctx);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  kernel_iter->second(ExecutionContext(*this, dummy_scope, *dev_ctx, ctx));
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					}
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					void OperatorWithKernel::TransferInplaceVarsBack(
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    const Scope& scope, const std::vector<std::string>& inplace_vars,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    const Scope& transfer_scope) const {
 | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
				
				 | 
				 | 
				
					
 
 |