|
|
|
@ -193,7 +193,7 @@ class OpRegistry {
|
|
|
|
|
using VarNameList = std::vector<std::string>;
|
|
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
template <typename OpType, typename ProtoMakerType>
|
|
|
|
|
template <typename OpType, typename ProtoMakerType, typename GradOpType>
|
|
|
|
|
static void RegisterOp(const std::string& op_type,
|
|
|
|
|
const std::string& grad_op_type) {
|
|
|
|
|
PADDLE_ENFORCE(op_info_map().count(op_type) == 0,
|
|
|
|
@ -226,6 +226,10 @@ class OpRegistry {
|
|
|
|
|
// ================================================ //
|
|
|
|
|
}
|
|
|
|
|
op_info_map().insert(std::make_pair(op_type, op_info));
|
|
|
|
|
// register gradient op
|
|
|
|
|
if (!grad_op_type.empty()) {
|
|
|
|
|
RegisterOp<GradOpType, NOPMaker, NOP>(grad_op_type, "");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
|
|
|
|
@ -321,12 +325,13 @@ class Registrar {
|
|
|
|
|
void Touch() {}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename OpType, typename ProtoMakerType>
|
|
|
|
|
template <typename OpType, typename ProtoMakerType, typename GradOpType>
|
|
|
|
|
class OpRegistrar : public Registrar {
|
|
|
|
|
public:
|
|
|
|
|
explicit OpRegistrar(const char* op_type) { OpRegistrar(op_type, ""); }
|
|
|
|
|
OpRegistrar(const char* op_type, const char* grad_op_type) {
|
|
|
|
|
OpRegistry::RegisterOp<OpType, ProtoMakerType>(op_type, grad_op_type);
|
|
|
|
|
OpRegistry::RegisterOp<OpType, ProtoMakerType, GradOpType>(op_type,
|
|
|
|
|
grad_op_type);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -352,10 +357,12 @@ class OpKernelRegistrar : public Registrar {
|
|
|
|
|
/**
|
|
|
|
|
* Macro to register Operator.
|
|
|
|
|
*/
|
|
|
|
|
#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type) \
|
|
|
|
|
#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type, \
|
|
|
|
|
grad_op_class) \
|
|
|
|
|
STATIC_ASSERT_GLOBAL_NAMESPACE( \
|
|
|
|
|
__reg_op__##op_type, "REGISTER_OP must be called in global namespace"); \
|
|
|
|
|
static ::paddle::framework::OpRegistrar<op_class, op_maker_class> \
|
|
|
|
|
static ::paddle::framework::OpRegistrar<op_class, op_maker_class, \
|
|
|
|
|
grad_op_class> \
|
|
|
|
|
__op_registrar_##op_type##__(#op_type, #grad_op_type); \
|
|
|
|
|
int TouchOpRegistrar_##op_type() { \
|
|
|
|
|
__op_registrar_##op_type##__.Touch(); \
|
|
|
|
@ -363,10 +370,7 @@ class OpKernelRegistrar : public Registrar {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \
|
|
|
|
|
REGISTER_OP(op_type, op_class, op_maker_class, )
|
|
|
|
|
|
|
|
|
|
#define REGISTER_GRADIENT_OP(op_type, op_class) \
|
|
|
|
|
REGISTER_OP(op_type, op_class, ::paddle::framework::NOPMaker, )
|
|
|
|
|
REGISTER_OP(op_type, op_class, op_maker_class, , ::paddle::framework::NOP)
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Macro to register OperatorKernel.
|
|
|
|
|