|
|
|
@ -478,9 +478,25 @@ class OperatorWithKernel : public OperatorBase {
|
|
|
|
|
this->InferShape(&infer_shape_ctx);
|
|
|
|
|
|
|
|
|
|
ExecutionContext ctx(*this, scope, dev_ctx);
|
|
|
|
|
auto& opKernel = AllOpKernels().at(type_).at(
|
|
|
|
|
OpKernelKey(IndicateDataType(ctx), dev_ctx));
|
|
|
|
|
opKernel->Compute(ctx);
|
|
|
|
|
|
|
|
|
|
// check if op[type] has kernel registered.
|
|
|
|
|
auto& all_op_kernels = AllOpKernels();
|
|
|
|
|
auto kernels_iter = all_op_kernels.find(type_);
|
|
|
|
|
if (kernels_iter == all_op_kernels.end()) {
|
|
|
|
|
PADDLE_THROW("op[%s] has no kernel", type_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// check if op[type] have kernel for kernel_key
|
|
|
|
|
OpKernelMap& kernels = kernels_iter->second;
|
|
|
|
|
auto kernel_key = OpKernelKey(IndicateDataType(ctx), dev_ctx);
|
|
|
|
|
auto kernel_iter = kernels.find(kernel_key);
|
|
|
|
|
|
|
|
|
|
if (kernel_iter == kernels.end()) {
|
|
|
|
|
PADDLE_THROW("op[%s] has no kernel with kernel_key[%s]", type_,
|
|
|
|
|
kernel_key);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
kernel_iter->second->Compute(ctx);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::unordered_map<std::string /* op_type */, OpKernelMap>&
|
|
|
|
@ -529,5 +545,8 @@ class OperatorWithKernel : public OperatorBase {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
std::ostream& operator<<(std::ostream& os,
|
|
|
|
|
const OperatorWithKernel::OpKernelKey& kernel_key);
|
|
|
|
|
|
|
|
|
|
} // namespace framework
|
|
|
|
|
} // namespace paddle
|
|
|
|
|