|
|
|
@ -33,8 +33,7 @@ namespace framework {
|
|
|
|
|
class OpRegistry {
|
|
|
|
|
public:
|
|
|
|
|
template <typename OpType, typename ProtoMakerType, typename GradOpType>
|
|
|
|
|
static void RegisterOp(const std::string& op_type,
|
|
|
|
|
const std::string& grad_op_type) {
|
|
|
|
|
static void RegisterOp(const std::string& op_type) {
|
|
|
|
|
PADDLE_ENFORCE(!OpInfoMap::Instance().Has(op_type),
|
|
|
|
|
"'%s' is registered more than once.", op_type);
|
|
|
|
|
OpInfo op_info;
|
|
|
|
@ -43,9 +42,9 @@ class OpRegistry {
|
|
|
|
|
const VariableNameMap& outputs, const AttributeMap& attrs) {
|
|
|
|
|
return new OpType(type, inputs, outputs, attrs);
|
|
|
|
|
};
|
|
|
|
|
op_info.grad_op_type_ = grad_op_type;
|
|
|
|
|
if (std::type_index(typeid(ProtoMakerType)) !=
|
|
|
|
|
std::type_index(typeid(NOPMaker))) {
|
|
|
|
|
op_info.grad_op_type_ = op_type + "_grad";
|
|
|
|
|
op_info.proto_ = new OpProto;
|
|
|
|
|
op_info.checker_ = new OpAttrChecker;
|
|
|
|
|
auto maker = ProtoMakerType(op_info.proto_, op_info.checker_);
|
|
|
|
@ -55,15 +54,14 @@ class OpRegistry {
|
|
|
|
|
op_info.proto_->IsInitialized(),
|
|
|
|
|
"Fail to initialize %s's OpProto, because %s is not initialized",
|
|
|
|
|
op_type, op_info.proto_->InitializationErrorString());
|
|
|
|
|
// register gradient op
|
|
|
|
|
RegisterOp<GradOpType, NOPMaker, NOP>(op_info.grad_op_type_);
|
|
|
|
|
} else {
|
|
|
|
|
op_info.grad_op_type_ = "";
|
|
|
|
|
op_info.proto_ = nullptr;
|
|
|
|
|
op_info.checker_ = nullptr;
|
|
|
|
|
}
|
|
|
|
|
OpInfoMap::Instance().Insert(op_type, op_info);
|
|
|
|
|
// register gradient op
|
|
|
|
|
if (!grad_op_type.empty()) {
|
|
|
|
|
RegisterOp<GradOpType, NOPMaker, NOP>(grad_op_type, "");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::unique_ptr<OperatorBase> CreateOp(const std::string& type,
|
|
|
|
@ -92,10 +90,8 @@ class Registrar {
|
|
|
|
|
template <typename OpType, typename ProtoMakerType, typename GradOpType>
|
|
|
|
|
class OpRegistrar : public Registrar {
|
|
|
|
|
public:
|
|
|
|
|
explicit OpRegistrar(const char* op_type) { OpRegistrar(op_type, ""); }
|
|
|
|
|
OpRegistrar(const char* op_type, const char* grad_op_type) {
|
|
|
|
|
OpRegistry::RegisterOp<OpType, ProtoMakerType, GradOpType>(op_type,
|
|
|
|
|
grad_op_type);
|
|
|
|
|
explicit OpRegistrar(const char* op_type) {
|
|
|
|
|
OpRegistry::RegisterOp<OpType, ProtoMakerType, GradOpType>(op_type);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -121,8 +117,7 @@ class OpKernelRegistrar : public Registrar {
|
|
|
|
|
/**
|
|
|
|
|
* Macro to register Operator.
|
|
|
|
|
*/
|
|
|
|
|
#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type, \
|
|
|
|
|
grad_op_class) \
|
|
|
|
|
#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_class) \
|
|
|
|
|
STATIC_ASSERT_GLOBAL_NAMESPACE( \
|
|
|
|
|
__reg_op__##op_type, "REGISTER_OP must be called in global namespace"); \
|
|
|
|
|
class _OpClass_##op_type##_ : public op_class { \
|
|
|
|
@ -137,14 +132,14 @@ class OpKernelRegistrar : public Registrar {
|
|
|
|
|
}; \
|
|
|
|
|
static ::paddle::framework::OpRegistrar< \
|
|
|
|
|
_OpClass_##op_type##_, op_maker_class, _OpGradClass_##op_type##_> \
|
|
|
|
|
__op_registrar_##op_type##__(#op_type, #grad_op_type); \
|
|
|
|
|
__op_registrar_##op_type##__(#op_type); \
|
|
|
|
|
int TouchOpRegistrar_##op_type() { \
|
|
|
|
|
__op_registrar_##op_type##__.Touch(); \
|
|
|
|
|
return 0; \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \
|
|
|
|
|
REGISTER_OP(op_type, op_class, op_maker_class, , ::paddle::framework::NOP)
|
|
|
|
|
REGISTER_OP(op_type, op_class, op_maker_class, ::paddle::framework::NOP)
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Macro to register OperatorKernel.
|
|
|
|
|