|
|
|
@ -307,10 +307,7 @@ class OpRegistry {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class Registrar {
|
|
|
|
|
public:
|
|
|
|
|
void Touch() {}
|
|
|
|
|
};
|
|
|
|
|
class Registrar {};
|
|
|
|
|
|
|
|
|
|
template <typename OpType, typename ProtoMakerType>
|
|
|
|
|
class OpRegistrar : public Registrar {
|
|
|
|
@ -354,40 +351,37 @@ class OpKernelRegistrar : public Registrar {
|
|
|
|
|
#define REGISTER_OP(op_type, op_class, op_maker_class) \
|
|
|
|
|
STATIC_ASSERT_GLOBAL_NAMESPACE( \
|
|
|
|
|
__reg_op__##op_type, "REGISTER_OP must be called in global namespace"); \
|
|
|
|
|
static ::paddle::framework::OpRegistrar<op_class, op_maker_class> \
|
|
|
|
|
__op_registrar_##op_type##__(#op_type); \
|
|
|
|
|
int TouchOpRegistrar_##op_type() { \
|
|
|
|
|
__op_registrar_##op_type##__.Touch(); \
|
|
|
|
|
static ::paddle::framework::OpRegistrar<op_class, op_maker_class> \
|
|
|
|
|
__op_registrar_##op_type##__(#op_type); \
|
|
|
|
|
return 0; \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Macro to Register Gradient Operator.
|
|
|
|
|
*/
|
|
|
|
|
#define REGISTER_GRADIENT_OP(op_type, grad_op_type, grad_op_class) \
|
|
|
|
|
STATIC_ASSERT_GLOBAL_NAMESPACE( \
|
|
|
|
|
__reg_gradient_op__##op_type##_##grad_op_type, \
|
|
|
|
|
"REGISTER_GRADIENT_OP must be called in global namespace"); \
|
|
|
|
|
static ::paddle::framework::GradOpRegistrar<grad_op_class> \
|
|
|
|
|
__op_gradient_registrar_##op_type##_##grad_op_type##__(#op_type, \
|
|
|
|
|
#grad_op_type); \
|
|
|
|
|
int TouchOpGradientRegistrar_##op_type() { \
|
|
|
|
|
__op_gradient_registrar_##op_type##_##grad_op_type##__.Touch(); \
|
|
|
|
|
return 0; \
|
|
|
|
|
#define REGISTER_GRADIENT_OP(op_type, grad_op_type, grad_op_class) \
|
|
|
|
|
STATIC_ASSERT_GLOBAL_NAMESPACE( \
|
|
|
|
|
__reg_gradient_op__##op_type##_##grad_op_type, \
|
|
|
|
|
"REGISTER_GRADIENT_OP must be called in global namespace"); \
|
|
|
|
|
int TouchOpGradientRegistrar_##op_type() { \
|
|
|
|
|
static ::paddle::framework::GradOpRegistrar<grad_op_class> \
|
|
|
|
|
__op_gradient_registrar_##op_type##_##grad_op_type##__(#op_type, \
|
|
|
|
|
#grad_op_type); \
|
|
|
|
|
return 0; \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Macro to Register OperatorKernel.
|
|
|
|
|
*/
|
|
|
|
|
#define REGISTER_OP_KERNEL(op_type, DEVICE_TYPE, place_class, ...) \
|
|
|
|
|
STATIC_ASSERT_GLOBAL_NAMESPACE( \
|
|
|
|
|
__reg_op_kernel_##op_type##_##DEVICE_TYPE##__, \
|
|
|
|
|
"REGISTER_OP_KERNEL must be called in global namespace"); \
|
|
|
|
|
static ::paddle::framework::OpKernelRegistrar<place_class, __VA_ARGS__> \
|
|
|
|
|
__op_kernel_registrar_##op_type##_##DEVICE_TYPE##__(#op_type); \
|
|
|
|
|
int TouchOpKernelRegistrar_##op_type##_##DEVICE_TYPE() { \
|
|
|
|
|
__op_kernel_registrar_##op_type##_##DEVICE_TYPE##__.Touch(); \
|
|
|
|
|
return 0; \
|
|
|
|
|
#define REGISTER_OP_KERNEL(op_type, DEVICE_TYPE, place_class, ...) \
|
|
|
|
|
STATIC_ASSERT_GLOBAL_NAMESPACE( \
|
|
|
|
|
__reg_op_kernel_##op_type##_##DEVICE_TYPE##__, \
|
|
|
|
|
"REGISTER_OP_KERNEL must be called in global namespace"); \
|
|
|
|
|
int TouchOpKernelRegistrar_##op_type##_##DEVICE_TYPE() { \
|
|
|
|
|
static ::paddle::framework::OpKernelRegistrar<place_class, __VA_ARGS__> \
|
|
|
|
|
__op_kernel_registrar_##op_type##_##DEVICE_TYPE##__(#op_type); \
|
|
|
|
|
return 0; \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|