|
|
@ -79,30 +79,31 @@ struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
|
|
|
|
using KERNEL_TYPE =
|
|
|
|
using KERNEL_TYPE =
|
|
|
|
typename std::tuple_element<I, std::tuple<KernelTypes...>>::type;
|
|
|
|
typename std::tuple_element<I, std::tuple<KernelTypes...>>::type;
|
|
|
|
|
|
|
|
|
|
|
|
void operator()(const char* op_type) const {
|
|
|
|
void operator()(const char* op_type, const char* library_type) const {
|
|
|
|
using T = typename KERNEL_TYPE::ELEMENT_TYPE;
|
|
|
|
using T = typename KERNEL_TYPE::ELEMENT_TYPE;
|
|
|
|
OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType());
|
|
|
|
OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType(),
|
|
|
|
|
|
|
|
DataLayout::kAnyLayout, StringToLibraryType(library_type));
|
|
|
|
OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KERNEL_TYPE);
|
|
|
|
OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KERNEL_TYPE);
|
|
|
|
|
|
|
|
|
|
|
|
constexpr auto size = std::tuple_size<std::tuple<KernelTypes...>>::value;
|
|
|
|
constexpr auto size = std::tuple_size<std::tuple<KernelTypes...>>::value;
|
|
|
|
OpKernelRegistrarFunctor<PlaceType, I + 1 == size, I + 1, KernelTypes...>
|
|
|
|
OpKernelRegistrarFunctor<PlaceType, I + 1 == size, I + 1, KernelTypes...>
|
|
|
|
func;
|
|
|
|
func;
|
|
|
|
func(op_type);
|
|
|
|
func(op_type, library_type);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
template <typename PlaceType, size_t I, typename... KernelType>
|
|
|
|
template <typename PlaceType, size_t I, typename... KernelType>
|
|
|
|
struct OpKernelRegistrarFunctor<PlaceType, true, I, KernelType...> {
|
|
|
|
struct OpKernelRegistrarFunctor<PlaceType, true, I, KernelType...> {
|
|
|
|
void operator()(const char* op_type) const {}
|
|
|
|
void operator()(const char* op_type, const char* library_type) const {}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
// User can register many kernel in one place. The data type could be different.
|
|
|
|
// User can register many kernel in one place. The data type could be different.
|
|
|
|
template <typename PlaceType, typename... KernelType>
|
|
|
|
template <typename PlaceType, typename... KernelType>
|
|
|
|
class OpKernelRegistrar : public Registrar {
|
|
|
|
class OpKernelRegistrar : public Registrar {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
explicit OpKernelRegistrar(const char* op_type) {
|
|
|
|
explicit OpKernelRegistrar(const char* op_type, const char* library_type) {
|
|
|
|
OpKernelRegistrarFunctor<PlaceType, false, 0, KernelType...> func;
|
|
|
|
OpKernelRegistrarFunctor<PlaceType, false, 0, KernelType...> func;
|
|
|
|
func(op_type);
|
|
|
|
func(op_type, library_type);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
@ -181,7 +182,8 @@ class OpKernelRegistrar : public Registrar {
|
|
|
|
__reg_op_kernel_##op_type##_##DEVICE_TYPE##__, \
|
|
|
|
__reg_op_kernel_##op_type##_##DEVICE_TYPE##__, \
|
|
|
|
"REGISTER_OP_KERNEL must be called in global namespace"); \
|
|
|
|
"REGISTER_OP_KERNEL must be called in global namespace"); \
|
|
|
|
static ::paddle::framework::OpKernelRegistrar<place_class, __VA_ARGS__> \
|
|
|
|
static ::paddle::framework::OpKernelRegistrar<place_class, __VA_ARGS__> \
|
|
|
|
__op_kernel_registrar_##op_type##_##DEVICE_TYPE##__(#op_type); \
|
|
|
|
__op_kernel_registrar_##op_type##_##DEVICE_TYPE##__(#op_type, \
|
|
|
|
|
|
|
|
#DEVICE_TYPE); \
|
|
|
|
int TouchOpKernelRegistrar_##op_type##_##DEVICE_TYPE() { \
|
|
|
|
int TouchOpKernelRegistrar_##op_type##_##DEVICE_TYPE() { \
|
|
|
|
__op_kernel_registrar_##op_type##_##DEVICE_TYPE##__.Touch(); \
|
|
|
|
__op_kernel_registrar_##op_type##_##DEVICE_TYPE##__.Touch(); \
|
|
|
|
return 0; \
|
|
|
|
return 0; \
|
|
|
|