|
|
|
@ -76,13 +76,8 @@ class OpRegistry {
|
|
|
|
|
// Primary template of the kernel-registration functor. It is only declared
// here; behavior is supplied by partial specializations. The specialization
// for at_end == false visible further down picks the I-th entry of the
// kernel-type pack, registers it, and recurses with I + 1.
//
//   PlaceType  - place class forwarded to the registration helper.
//   at_end     - selects the specialization; the recursive step passes
//                (I + 1 == number of kernel types).
//   I          - index of the kernel type handled by this step.
//   KernelType - pack of kernel classes to register.
template <typename PlaceType, bool at_end, size_t I, typename... KernelType>
|
|
|
|
|
struct OpKernelRegistrarFunctor;
|
|
|
|
|
|
|
|
|
|
template <typename PlaceType, size_t I, typename... KernelTypes>
|
|
|
|
|
struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
|
|
|
|
|
using KERNEL_TYPE =
|
|
|
|
|
typename std::tuple_element<I, std::tuple<KernelTypes...>>::type;
|
|
|
|
|
|
|
|
|
|
void operator()(const char* op_type, const char* library_type) const {
|
|
|
|
|
using T = typename KERNEL_TYPE::ELEMENT_TYPE;
|
|
|
|
|
template <typename PlaceType, typename T, typename KernelType>
|
|
|
|
|
inline void RegisterKernelClass(const char* op_type, const char* library_type) {
|
|
|
|
|
std::string library(library_type);
|
|
|
|
|
std::string data_layout = "ANYLAYOUT";
|
|
|
|
|
if (library == "MKLDNN") {
|
|
|
|
@ -91,8 +86,17 @@ struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
|
|
|
|
|
OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType(),
|
|
|
|
|
StringToDataLayout(data_layout),
|
|
|
|
|
StringToLibraryType(library_type));
|
|
|
|
|
OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KERNEL_TYPE);
|
|
|
|
|
OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KernelType());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename PlaceType, size_t I, typename... KernelTypes>
|
|
|
|
|
struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
|
|
|
|
|
using KERNEL_TYPE =
|
|
|
|
|
typename std::tuple_element<I, std::tuple<KernelTypes...>>::type;
|
|
|
|
|
|
|
|
|
|
void operator()(const char* op_type, const char* library_type) const {
|
|
|
|
|
using T = typename KERNEL_TYPE::ELEMENT_TYPE;
|
|
|
|
|
RegisterKernelClass<PlaceType, T, KERNEL_TYPE>(op_type, library_type);
|
|
|
|
|
constexpr auto size = std::tuple_size<std::tuple<KernelTypes...>>::value;
|
|
|
|
|
OpKernelRegistrarFunctor<PlaceType, I + 1 == size, I + 1, KernelTypes...>
|
|
|
|
|
func;
|
|
|
|
@ -116,6 +120,47 @@ class OpKernelRegistrar : public Registrar {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Primary template of the "Ex" kernel-registration functor. It is only
// declared here; the at_end == true and at_end == false partial
// specializations that follow provide the behavior.
//
//   PlaceType  - place class forwarded to RegisterKernelClass.
//   at_end     - true selects the no-op terminal case, false the step that
//                registers one entry and recurses.
//   I          - index into the pack of the data type handled by this step
//                (the matching kernel class sits at I + 1).
//   KernelType - flat pack alternating data type and kernel class.
template <typename PlaceType, bool at_end, size_t I, typename... KernelType>
|
|
|
|
|
struct OpKernelRegistrarFunctorEx;
|
|
|
|
|
|
|
|
|
|
template <typename PlaceType, typename... DataTypeAndKernelType>
|
|
|
|
|
class OpKernelRegistrarEx : public Registrar {
|
|
|
|
|
public:
|
|
|
|
|
explicit OpKernelRegistrarEx(const char* op_type, const char* library_type) {
|
|
|
|
|
OpKernelRegistrarFunctorEx<PlaceType, false, 0, DataTypeAndKernelType...>
|
|
|
|
|
func;
|
|
|
|
|
func(op_type, library_type);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename PlaceType, size_t I, typename... DataTypeAndKernelType>
|
|
|
|
|
struct OpKernelRegistrarFunctorEx<PlaceType, true, I,
|
|
|
|
|
DataTypeAndKernelType...> {
|
|
|
|
|
void operator()(const char* op_type, const char* library_type) const {}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename PlaceType, size_t I, typename... DataTypeAndKernelType>
|
|
|
|
|
struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
|
|
|
|
|
DataTypeAndKernelType...> {
|
|
|
|
|
using KERNEL_TYPE =
|
|
|
|
|
typename std::tuple_element<I + 1,
|
|
|
|
|
std::tuple<DataTypeAndKernelType...>>::type;
|
|
|
|
|
using T =
|
|
|
|
|
typename std::tuple_element<I,
|
|
|
|
|
std::tuple<DataTypeAndKernelType...>>::type;
|
|
|
|
|
|
|
|
|
|
void operator()(const char* op_type, const char* library_type) const {
|
|
|
|
|
RegisterKernelClass<PlaceType, T, KERNEL_TYPE>(op_type, library_type);
|
|
|
|
|
|
|
|
|
|
constexpr auto size =
|
|
|
|
|
std::tuple_size<std::tuple<DataTypeAndKernelType...>>::value;
|
|
|
|
|
OpKernelRegistrarFunctorEx<PlaceType, I + 2 >= size, I + 2,
|
|
|
|
|
DataTypeAndKernelType...>
|
|
|
|
|
func;
|
|
|
|
|
func(op_type, library_type);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* check if MACRO is used in GLOBAL NAMESPACE.
|
|
|
|
|
*/
|
|
|
|
@ -174,6 +219,25 @@ class OpKernelRegistrar : public Registrar {
|
|
|
|
|
/**
 * Convenience wrapper: forwards op_type and the variadic kernel list to
 * REGISTER_OP_KERNEL with the library tag CPU and the place class
 * ::paddle::platform::CPUPlace.
 */
#define REGISTER_OP_CPU_KERNEL(op_type, ...) \
|
|
|
|
|
REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
|
|
|
|
|
|
|
|
|
|
/**
 * Registers kernels for op_type through OpKernelRegistrarEx.
 *
 *   op_type      - operator name; pasted into the generated identifiers and
 *                  stringified for the registrar constructor.
 *   library_type - library tag; pasted and stringified the same way.
 *   place_class  - first template argument of OpKernelRegistrarEx.
 *   ...          - remaining template arguments of OpKernelRegistrarEx
 *                  (alternating data type and kernel class).
 *
 * The expansion (1) checks via STATIC_ASSERT_GLOBAL_NAMESPACE that the macro
 * is used in the global namespace, (2) defines a static OpKernelRegistrarEx
 * object named __op_kernel_registrar_<op_type>_<library_type>__, and
 * (3) defines int TouchOpKernelRegistrar_<op_type>_<library_type>(), which
 * calls Touch() on that object and returns 0.
 */
#define REGISTER_OP_KERNEL_EX(op_type, library_type, place_class, ...) \
|
|
|
|
|
STATIC_ASSERT_GLOBAL_NAMESPACE( \
|
|
|
|
|
__reg_op_kernel_##op_type##_##library_type##__, \
|
|
|
|
|
"REGISTER_OP_KERNEL_EX must be called in global namespace"); \
|
|
|
|
|
static ::paddle::framework::OpKernelRegistrarEx<place_class, __VA_ARGS__> \
|
|
|
|
|
__op_kernel_registrar_##op_type##_##library_type##__(#op_type, \
|
|
|
|
|
#library_type); \
|
|
|
|
|
int TouchOpKernelRegistrar_##op_type##_##library_type() { \
|
|
|
|
|
__op_kernel_registrar_##op_type##_##library_type##__.Touch(); \
|
|
|
|
|
return 0; \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
 * Convenience wrapper: forwards op_type and the variadic list (alternating
 * data type and kernel class) to REGISTER_OP_KERNEL_EX with the library tag
 * CUDA and the place class ::paddle::platform::CUDAPlace.
 *
 * The first forwarded argument must be the op_type parameter itself:
 * REGISTER_OP_KERNEL_EX stringifies it and pastes it into the generated
 * identifiers, so any other token would register every operator under that
 * literal name.
 */
#define REGISTER_OP_CUDA_KERNEL_EX(op_type, ...)                        \
  REGISTER_OP_KERNEL_EX(op_type, CUDA, ::paddle::platform::CUDAPlace, \
                        __VA_ARGS__)
|
|
|
|
|
|
|
|
|
|
/**
 * Convenience wrapper: forwards op_type and the variadic list (alternating
 * data type and kernel class) to REGISTER_OP_KERNEL_EX with the library tag
 * CPU and the place class ::paddle::platform::CPUPlace.
 */
#define REGISTER_OP_CPU_KERNEL_EX(op_type, ...) \
|
|
|
|
|
REGISTER_OP_KERNEL_EX(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Macro to mark what Operator and Kernel
|
|
|
|
|
* we will use and tell the compiler to
|
|
|
|
|