|
|
|
@ -76,8 +76,9 @@ class OpRegistry {
|
|
|
|
|
template <typename PlaceType, bool at_end, size_t I, typename... KernelType>
|
|
|
|
|
struct OpKernelRegistrarFunctor;
|
|
|
|
|
|
|
|
|
|
template <typename PlaceType, typename T, typename KernelType>
|
|
|
|
|
inline void RegisterKernelClass(const char* op_type, const char* library_type) {
|
|
|
|
|
template <typename PlaceType, typename T, typename Func>
|
|
|
|
|
inline void RegisterKernelClass(const char* op_type, const char* library_type,
|
|
|
|
|
Func func) {
|
|
|
|
|
std::string library(library_type);
|
|
|
|
|
std::string data_layout = "ANYLAYOUT";
|
|
|
|
|
if (library == "MKLDNN") {
|
|
|
|
@ -86,7 +87,7 @@ inline void RegisterKernelClass(const char* op_type, const char* library_type) {
|
|
|
|
|
OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType(),
|
|
|
|
|
StringToDataLayout(data_layout),
|
|
|
|
|
StringToLibraryType(library_type));
|
|
|
|
|
OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KernelType());
|
|
|
|
|
OperatorWithKernel::AllOpKernels()[op_type][key] = func;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename PlaceType, size_t I, typename... KernelTypes>
|
|
|
|
@ -96,7 +97,10 @@ struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
|
|
|
|
|
|
|
|
|
|
void operator()(const char* op_type, const char* library_type) const {
|
|
|
|
|
using T = typename KERNEL_TYPE::ELEMENT_TYPE;
|
|
|
|
|
RegisterKernelClass<PlaceType, T, KERNEL_TYPE>(op_type, library_type);
|
|
|
|
|
RegisterKernelClass<PlaceType, T>(
|
|
|
|
|
op_type, library_type, [](const framework::ExecutionContext& ctx) {
|
|
|
|
|
KERNEL_TYPE().Compute(ctx);
|
|
|
|
|
});
|
|
|
|
|
constexpr auto size = std::tuple_size<std::tuple<KernelTypes...>>::value;
|
|
|
|
|
OpKernelRegistrarFunctor<PlaceType, I + 1 == size, I + 1, KernelTypes...>
|
|
|
|
|
func;
|
|
|
|
@ -150,7 +154,10 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
|
|
|
|
|
std::tuple<DataTypeAndKernelType...>>::type;
|
|
|
|
|
|
|
|
|
|
void operator()(const char* op_type, const char* library_type) const {
|
|
|
|
|
RegisterKernelClass<PlaceType, T, KERNEL_TYPE>(op_type, library_type);
|
|
|
|
|
RegisterKernelClass<PlaceType, T>(
|
|
|
|
|
op_type, library_type, [](const framework::ExecutionContext& ctx) {
|
|
|
|
|
KERNEL_TYPE().Compute(ctx);
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
constexpr auto size =
|
|
|
|
|
std::tuple_size<std::tuple<DataTypeAndKernelType...>>::value;
|
|
|
|
|