|
|
|
@ -311,7 +311,7 @@ class OpRegisterHelper {
|
|
|
|
|
/**
|
|
|
|
|
* Macro to Register OperatorKernel.
|
|
|
|
|
*/
|
|
|
|
|
#define REGISTER_OP_KERNEL(type, DEVICE_TYPE, PlaceType, KernelType) \
|
|
|
|
|
#define REGISTER_OP_KERNEL(type, DEVICE_TYPE, PlaceType, ...) \
|
|
|
|
|
STATIC_ASSERT_GLOBAL_NAMESPACE( \
|
|
|
|
|
__reg_op_kernel_##type##_##DEVICE_TYPE##__, \
|
|
|
|
|
"REGISTER_OP_KERNEL must be in global namespace"); \
|
|
|
|
@ -320,17 +320,19 @@ class OpRegisterHelper {
|
|
|
|
|
::paddle::framework::OperatorWithKernel::OpKernelKey key; \
|
|
|
|
|
key.place_ = PlaceType(); \
|
|
|
|
|
::paddle::framework::OperatorWithKernel::AllOpKernels()[#type][key] \
|
|
|
|
|
.reset(new KernelType()); \
|
|
|
|
|
.reset(new __VA_ARGS__()); \
|
|
|
|
|
} \
|
|
|
|
|
}; \
|
|
|
|
|
static __op_kernel_register__##type##__ __reg_kernel_##type##__; \
|
|
|
|
|
int __op_kernel_register_##type##_handle_##DEVICE_TYPE##__() { return 0; }
|
|
|
|
|
|
|
|
|
|
#define REGISTER_OP_GPU_KERNEL(type, KernelType) \
|
|
|
|
|
REGISTER_OP_KERNEL(type, GPU, ::paddle::platform::GPUPlace, KernelType)
|
|
|
|
|
// (type, KernelType)
|
|
|
|
|
#define REGISTER_OP_GPU_KERNEL(type, ...) \
|
|
|
|
|
REGISTER_OP_KERNEL(type, GPU, ::paddle::platform::GPUPlace, __VA_ARGS__)
|
|
|
|
|
|
|
|
|
|
#define REGISTER_OP_CPU_KERNEL(type, KernelType) \
|
|
|
|
|
REGISTER_OP_KERNEL(type, CPU, ::paddle::platform::CPUPlace, KernelType)
|
|
|
|
|
// (type, KernelType)
|
|
|
|
|
#define REGISTER_OP_CPU_KERNEL(type, ...) \
|
|
|
|
|
REGISTER_OP_KERNEL(type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Macro to mark what Operator and Kernel we will use and tell the compiler to
|
|
|
|
|