|
|
|
@ -309,6 +309,14 @@ class OpRegistry {
|
|
|
|
|
|
|
|
|
|
class Registrar {
|
|
|
|
|
public:
|
|
|
|
|
// In our design, various kinds of classes, e.g., operators and kernels, have
|
|
|
|
|
// their corresponding registry and registrar. The action of registration is
|
|
|
|
|
// in the constructor of a global registrar variable, which, however, are not
|
|
|
|
|
// used in the code that calls package framework, and would be removed from
|
|
|
|
|
// the generated binary file by the linker. To avoid such removal, we add
|
|
|
|
|
// Touch to all registrar classes and make USE_OP macros to call this
|
|
|
|
|
// method. So, as long as the callee code calls USE_OP, the global
|
|
|
|
|
// registrar variable won't be removed by the linker.
|
|
|
|
|
void Touch() {}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -332,10 +340,9 @@ template <typename PlaceType, typename KernelType>
|
|
|
|
|
class OpKernelRegistrar : public Registrar {
|
|
|
|
|
public:
|
|
|
|
|
explicit OpKernelRegistrar(const char* op_type) {
|
|
|
|
|
::paddle::framework::OperatorWithKernel::OpKernelKey key;
|
|
|
|
|
OperatorWithKernel::OpKernelKey key;
|
|
|
|
|
key.place_ = PlaceType();
|
|
|
|
|
::paddle::framework::OperatorWithKernel::AllOpKernels()[op_type][key].reset(
|
|
|
|
|
new KernelType);
|
|
|
|
|
OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KernelType);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -349,7 +356,7 @@ class OpKernelRegistrar : public Registrar {
|
|
|
|
|
msg)
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Macro to Register Operator.
|
|
|
|
|
* Macro to register Operator.
|
|
|
|
|
*/
|
|
|
|
|
#define REGISTER_OP(op_type, op_class, op_maker_class) \
|
|
|
|
|
STATIC_ASSERT_GLOBAL_NAMESPACE( \
|
|
|
|
@ -362,7 +369,7 @@ class OpKernelRegistrar : public Registrar {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Macro to Register Gradient Operator.
|
|
|
|
|
* Macro to register Gradient Operator.
|
|
|
|
|
*/
|
|
|
|
|
#define REGISTER_GRADIENT_OP(op_type, grad_op_type, grad_op_class) \
|
|
|
|
|
STATIC_ASSERT_GLOBAL_NAMESPACE( \
|
|
|
|
@ -377,7 +384,7 @@ class OpKernelRegistrar : public Registrar {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Macro to Register OperatorKernel.
|
|
|
|
|
* Macro to register OperatorKernel.
|
|
|
|
|
*/
|
|
|
|
|
#define REGISTER_OP_KERNEL(op_type, DEVICE_TYPE, place_class, ...) \
|
|
|
|
|
STATIC_ASSERT_GLOBAL_NAMESPACE( \
|
|
|
|
|