@ -402,19 +402,28 @@ void OperatorWithKernel::Run(const Scope& scope,
OpKernelMap & kernels = kernels_iter - > second ;
ExecutionContext ctx ( * this , scope , * dev_ctx ) ;
auto kernel_key = GetKernelType ( ctx ) ;
auto kernel_iter = kernels . find ( kernel_key ) ;
auto actual_kernel_key = GetActualKernelType ( ctx ) ;
auto expected_kernel_key = GetExpectedKernelType ( actual_kernel_key ) ;
auto kernel_iter = kernels . find ( expected_kernel_key ) ;
if ( kernel_iter = = kernels . end ( ) ) {
PADDLE_THROW ( " The operator %s does not support %s " , type_ , kernel_key ) ;
PADDLE_THROW ( " The operator %s does not support %s " , type_ ,
expected_kernel_key ) ;
}
kernel_iter - > second - > Compute ( ctx ) ;
}
OpKernelType OperatorWithKernel : : GetKernelType (
OpKernelType OperatorWithKernel : : GetActualKernelType (
const ExecutionContext & ctx ) const {
return OpKernelType ( IndicateDataType ( ctx ) , ctx . GetPlace ( ) ) ;
}
OpKernelType OperatorWithKernel : : GetExpectedKernelType (
const OpKernelType & actual_kernel_type ) const {
return actual_kernel_type ;
}
proto : : DataType OperatorWithKernel : : IndicateDataType (
const ExecutionContext & ctx ) const {
auto & scope = ctx . scope ( ) ;