|
|
|
@ -137,23 +137,21 @@ class OpKernelRegistrar : public Registrar {
|
|
|
|
|
__test_global_namespace_##uniq_name##__>::value, \
|
|
|
|
|
msg)
|
|
|
|
|
|
|
|
|
|
#define VA_ARGS(...) , ##__VA_ARGS__
|
|
|
|
|
|
|
|
|
|
#define REGISTER_OPERATOR(op_type, op_class, ...) \
|
|
|
|
|
STATIC_ASSERT_GLOBAL_NAMESPACE( \
|
|
|
|
|
__reg_op__##op_type, \
|
|
|
|
|
"REGISTER_OPERATOR must be called in global namespace"); \
|
|
|
|
|
class _OpClass_##op_type##_ : public op_class { \
|
|
|
|
|
public: \
|
|
|
|
|
DEFINE_OP_CLONE_METHOD(_OpClass_##op_type##_); \
|
|
|
|
|
DEFINE_OP_CONSTRUCTOR(_OpClass_##op_type##_, op_class); \
|
|
|
|
|
}; \
|
|
|
|
|
static ::paddle::framework::OperatorRegistrar<_OpClass_##op_type##_ VA_ARGS( \
|
|
|
|
|
__VA_ARGS__)> \
|
|
|
|
|
__op_registrar_##op_type##__(#op_type); \
|
|
|
|
|
int TouchOpRegistrar_##op_type() { \
|
|
|
|
|
__op_registrar_##op_type##__.Touch(); \
|
|
|
|
|
return 0; \
|
|
|
|
|
#define REGISTER_OPERATOR(op_type, op_class, ...) \
|
|
|
|
|
STATIC_ASSERT_GLOBAL_NAMESPACE( \
|
|
|
|
|
__reg_op__##op_type, \
|
|
|
|
|
"REGISTER_OPERATOR must be called in global namespace"); \
|
|
|
|
|
class _OpClass_##op_type##_ : public op_class { \
|
|
|
|
|
public: \
|
|
|
|
|
DEFINE_OP_CLONE_METHOD(_OpClass_##op_type##_); \
|
|
|
|
|
DEFINE_OP_CONSTRUCTOR(_OpClass_##op_type##_, op_class); \
|
|
|
|
|
}; \
|
|
|
|
|
static ::paddle::framework::OperatorRegistrar<_OpClass_##op_type##_, \
|
|
|
|
|
##__VA_ARGS__> \
|
|
|
|
|
__op_registrar_##op_type##__(#op_type); \
|
|
|
|
|
int TouchOpRegistrar_##op_type() { \
|
|
|
|
|
__op_registrar_##op_type##__.Touch(); \
|
|
|
|
|
return 0; \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
@ -170,7 +168,7 @@ class OpKernelRegistrar : public Registrar {
|
|
|
|
|
virtual std::string GradOpType() const { return #grad_op_type; } \
|
|
|
|
|
}; \
|
|
|
|
|
REGISTER_OPERATOR(op_type, op_class, _GradOpDescMaker_##grad_op_type##_, \
|
|
|
|
|
op_maker_class)
|
|
|
|
|
op_maker_class);
|
|
|
|
|
|
|
|
|
|
#define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \
|
|
|
|
|
REGISTER_OPERATOR(op_type, op_class, op_maker_class)
|
|
|
|
|