|
|
|
@ -24,14 +24,27 @@ limitations under the License. */
|
|
|
|
|
#include "paddle/framework/details/op_registry.h"
|
|
|
|
|
#include "paddle/framework/framework.pb.h"
|
|
|
|
|
#include "paddle/framework/grad_op_builder.h"
|
|
|
|
|
#include "paddle/framework/grad_op_desc_maker.h"
|
|
|
|
|
#include "paddle/framework/operator.h"
|
|
|
|
|
#include "paddle/framework/scope.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
|
class Registrar {
|
|
|
|
|
public:
|
|
|
|
|
// In our design, various kinds of classes, e.g., operators and kernels,
|
|
|
|
|
// have their corresponding registry and registrar. The action of
|
|
|
|
|
// registration is in the constructor of a global registrar variable, which,
|
|
|
|
|
// however, are not used in the code that calls package framework, and would
|
|
|
|
|
// be removed from the generated binary file by the linker. To avoid such
|
|
|
|
|
// removal, we add Touch to all registrar classes and make USE_OP macros to
|
|
|
|
|
// call this method. So, as long as the callee code calls USE_OP, the global
|
|
|
|
|
// registrar variable won't be removed by the linker.
|
|
|
|
|
void Touch() {}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename... ARGS>
|
|
|
|
|
struct OperatorRegistrar {
|
|
|
|
|
struct OperatorRegistrar : public Registrar {
|
|
|
|
|
explicit OperatorRegistrar(const char* op_type) : op_type(op_type) {
|
|
|
|
|
PADDLE_ENFORCE(!OpInfoMap::Instance().Has(op_type),
|
|
|
|
|
"'%s' is registered more than once.", op_type);
|
|
|
|
@ -70,19 +83,6 @@ class OpRegistry {
|
|
|
|
|
static std::unique_ptr<OperatorBase> CreateGradOp(const OperatorBase& op);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class Registrar {
|
|
|
|
|
public:
|
|
|
|
|
// In our design, various kinds of classes, e.g., operators and kernels,
|
|
|
|
|
// have their corresponding registry and registrar. The action of
|
|
|
|
|
// registration is in the constructor of a global registrar variable, which,
|
|
|
|
|
// however, are not used in the code that calls package framework, and would
|
|
|
|
|
// be removed from the generated binary file by the linker. To avoid such
|
|
|
|
|
// removal, we add Touch to all registrar classes and make USE_OP macros to
|
|
|
|
|
// call this method. So, as long as the callee code calls USE_OP, the global
|
|
|
|
|
// registrar variable won't be removed by the linker.
|
|
|
|
|
void Touch() {}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename OpType, typename ProtoMakerType, typename GradOpType>
|
|
|
|
|
class OpRegistrar : public Registrar {
|
|
|
|
|
public:
|
|
|
|
@ -138,33 +138,43 @@ class OpKernelRegistrar : public Registrar {
|
|
|
|
|
__test_global_namespace_##uniq_name##__>::value, \
|
|
|
|
|
msg)
|
|
|
|
|
|
|
|
|
|
#define VA_ARGS(...) , ##__VA_ARGS__
|
|
|
|
|
|
|
|
|
|
#define REGISTER_OPERATOR(op_type, op_class, ...) \
|
|
|
|
|
STATIC_ASSERT_GLOBAL_NAMESPACE( \
|
|
|
|
|
__reg_op__##op_type, \
|
|
|
|
|
"REGISTER_OPERATOR must be called in global namespace"); \
|
|
|
|
|
class _OpClass_##op_type##_ : public op_class { \
|
|
|
|
|
public: \
|
|
|
|
|
DEFINE_OP_CLONE_METHOD(_OpClass_##op_type##_); \
|
|
|
|
|
DEFINE_OP_CONSTRUCTOR(_OpClass_##op_type##_, op_class); \
|
|
|
|
|
}; \
|
|
|
|
|
static ::paddle::framework::OperatorRegistrar<_OpClass_##op_type##_ VA_ARGS( \
|
|
|
|
|
__VA_ARGS__)> \
|
|
|
|
|
__op_registrar_##op_type##__(#op_type); \
|
|
|
|
|
int TouchOpRegistrar_##op_type() { \
|
|
|
|
|
__op_registrar_##op_type##__.Touch(); \
|
|
|
|
|
return 0; \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Macro to register Operator.
|
|
|
|
|
*/
|
|
|
|
|
#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type, \
|
|
|
|
|
grad_op_class) \
|
|
|
|
|
STATIC_ASSERT_GLOBAL_NAMESPACE( \
|
|
|
|
|
__reg_op__##op_type, "REGISTER_OP must be called in global namespace"); \
|
|
|
|
|
class _OpClass_##op_type##_ : public op_class { \
|
|
|
|
|
public: \
|
|
|
|
|
DEFINE_OP_CLONE_METHOD(_OpClass_##op_type##_); \
|
|
|
|
|
DEFINE_OP_CONSTRUCTOR(_OpClass_##op_type##_, op_class); \
|
|
|
|
|
}; \
|
|
|
|
|
class _OpGradClass_##op_type##_ : public grad_op_class { \
|
|
|
|
|
public: \
|
|
|
|
|
DEFINE_OP_CLONE_METHOD(_OpGradClass_##op_type##_); \
|
|
|
|
|
DEFINE_OP_CONSTRUCTOR(_OpGradClass_##op_type##_, grad_op_class); \
|
|
|
|
|
}; \
|
|
|
|
|
static ::paddle::framework::OpRegistrar< \
|
|
|
|
|
_OpClass_##op_type##_, op_maker_class, _OpGradClass_##op_type##_> \
|
|
|
|
|
__op_registrar_##op_type##__(#op_type, #grad_op_type); \
|
|
|
|
|
int TouchOpRegistrar_##op_type() { \
|
|
|
|
|
__op_registrar_##op_type##__.Touch(); \
|
|
|
|
|
return 0; \
|
|
|
|
|
}
|
|
|
|
|
#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type, \
|
|
|
|
|
grad_op_class) \
|
|
|
|
|
REGISTER_OPERATOR(grad_op_type, grad_op_class); \
|
|
|
|
|
class _GradOpDescMaker_##grad_op_type##_ \
|
|
|
|
|
: public ::paddle::framework::DefaultGradOpDescMaker { \
|
|
|
|
|
using ::paddle::framework::DefaultGradOpDescMaker::DefaultGradOpDescMaker; \
|
|
|
|
|
\
|
|
|
|
|
protected: \
|
|
|
|
|
virtual std::string GradOpType() const { return #grad_op_type; } \
|
|
|
|
|
}; \
|
|
|
|
|
REGISTER_OPERATOR(op_type, op_class, _GradOpDescMaker_##grad_op_type##_, \
|
|
|
|
|
op_maker_class)
|
|
|
|
|
|
|
|
|
|
#define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \
|
|
|
|
|
REGISTER_OP(op_type, op_class, op_maker_class, , ::paddle::framework::NOP)
|
|
|
|
|
REGISTER_OPERATOR(op_type, op_class, op_maker_class)
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Macro to register OperatorKernel.
|
|
|
|
|