diff --git a/paddle/fluid/framework/op_registry.h b/paddle/fluid/framework/op_registry.h index f0278cc49c..751e150845 100644 --- a/paddle/fluid/framework/op_registry.h +++ b/paddle/fluid/framework/op_registry.h @@ -76,8 +76,9 @@ class OpRegistry { template struct OpKernelRegistrarFunctor; -template -inline void RegisterKernelClass(const char* op_type, const char* library_type) { +template +inline void RegisterKernelClass(const char* op_type, const char* library_type, + Func func) { std::string library(library_type); std::string data_layout = "ANYLAYOUT"; if (library == "MKLDNN") { @@ -86,7 +87,7 @@ inline void RegisterKernelClass(const char* op_type, const char* library_type) { OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType(), StringToDataLayout(data_layout), StringToLibraryType(library_type)); - OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KernelType()); + OperatorWithKernel::AllOpKernels()[op_type][key] = func; } template @@ -96,7 +97,10 @@ struct OpKernelRegistrarFunctor { void operator()(const char* op_type, const char* library_type) const { using T = typename KERNEL_TYPE::ELEMENT_TYPE; - RegisterKernelClass(op_type, library_type); + RegisterKernelClass( + op_type, library_type, [](const framework::ExecutionContext& ctx) { + KERNEL_TYPE().Compute(ctx); + }); constexpr auto size = std::tuple_size>::value; OpKernelRegistrarFunctor func; @@ -150,7 +154,10 @@ struct OpKernelRegistrarFunctorEx>::type; void operator()(const char* op_type, const char* library_type) const { - RegisterKernelClass(op_type, library_type); + RegisterKernelClass( + op_type, library_type, [](const framework::ExecutionContext& ctx) { + KERNEL_TYPE().Compute(ctx); + }); constexpr auto size = std::tuple_size>::value; diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 71cd5a3908..3cf8e8696d 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -651,7 +651,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, dev_ctx = pool.Get(expected_kernel_key.place_); } - kernel_iter->second->Compute(ExecutionContext(*this, exec_scope, *dev_ctx)); + kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx)); if (!transfered_inplace_vars.empty()) { // there is inplace variable has been transfered. diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 1550d5df17..01d750efbb 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -347,9 +347,9 @@ class OpKernel : public OpKernelBase { class OperatorWithKernel : public OperatorBase { public: + using OpKernelFunc = std::function; using OpKernelMap = - std::unordered_map, - OpKernelType::Hash>; + std::unordered_map; OperatorWithKernel(const std::string& type, const VariableNameMap& inputs, const VariableNameMap& outputs, const AttributeMap& attrs)