|
|
|
@ -222,7 +222,7 @@ class OpRegistry {
|
|
|
|
|
public:
|
|
|
|
|
template <typename OpType, typename ProtoMakerType>
|
|
|
|
|
static void RegisterOp(const std::string& op_type) {
|
|
|
|
|
creators()[op_type] = [] { return new OpType; };
|
|
|
|
|
op_creators()[op_type] = [] { return new OpType; };
|
|
|
|
|
OpAttrChecker& op_checker = op_checkers()[op_type];
|
|
|
|
|
OpProto& op_proto = protos()[op_type];
|
|
|
|
|
auto maker = ProtoMakerType(&op_proto, &op_checker);
|
|
|
|
@ -245,17 +245,19 @@ class OpRegistry {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename OpType>
|
|
|
|
|
static void RegisterGradOp(const std::string& op_type) {
|
|
|
|
|
grad_creators()[op_type] = [] { return new OpType; };
|
|
|
|
|
template <typename GradOpType>
|
|
|
|
|
static void RegisterGradOp(const std::string& op_type,
|
|
|
|
|
const std::string& grad_op_type) {
|
|
|
|
|
op_creators()[grad_op_type] = [] { return new GradOpType; };
|
|
|
|
|
grad_ops()[op_type] = grad_op_type;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
|
|
|
|
|
const VarNameList& inputs,
|
|
|
|
|
const VarNameList& outputs,
|
|
|
|
|
const AttributeMap& attrs) {
|
|
|
|
|
auto op_create_it = creators().find(type);
|
|
|
|
|
PADDLE_ENFORCE(op_create_it != creators().end(),
|
|
|
|
|
auto op_create_it = op_creators().find(type);
|
|
|
|
|
PADDLE_ENFORCE(op_create_it != op_creators().end(),
|
|
|
|
|
"Operator %s cannot be found.", type);
|
|
|
|
|
|
|
|
|
|
auto op = op_create_it->second();
|
|
|
|
@ -300,8 +302,8 @@ class OpRegistry {
|
|
|
|
|
|
|
|
|
|
static std::shared_ptr<OperatorBase> CreateGradOp(
|
|
|
|
|
std::shared_ptr<OperatorBase> op) {
|
|
|
|
|
GradOpCreator creator(op.get());
|
|
|
|
|
std::shared_ptr<OperatorBase> grad_op(creator.Create());
|
|
|
|
|
GradOpBuilder builder(op.get());
|
|
|
|
|
std::shared_ptr<OperatorBase> grad_op(builder.Build());
|
|
|
|
|
grad_op->Init();
|
|
|
|
|
return grad_op;
|
|
|
|
|
}
|
|
|
|
@ -311,9 +313,9 @@ class OpRegistry {
|
|
|
|
|
return protos_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
static std::unordered_map<std::string, OpCreator>& grad_creators() {
|
|
|
|
|
static std::unordered_map<std::string, OpCreator> grad_creators_;
|
|
|
|
|
return grad_creators_;
|
|
|
|
|
static std::unordered_map<std::string, std::string>& grad_ops() {
|
|
|
|
|
static std::unordered_map<std::string, std::string> grad_ops_;
|
|
|
|
|
return grad_ops_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::unordered_map<std::string, std::shared_ptr<VarIndexMap>>&
|
|
|
|
@ -322,12 +324,12 @@ class OpRegistry {
|
|
|
|
|
return maps_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
static std::unordered_map<std::string, OpCreator>& creators() {
|
|
|
|
|
static std::unordered_map<std::string, OpCreator> creators_;
|
|
|
|
|
return creators_;
|
|
|
|
|
static std::unordered_map<std::string, OpCreator>& op_creators() {
|
|
|
|
|
static std::unordered_map<std::string, OpCreator> op_creators_;
|
|
|
|
|
return op_creators_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
static std::unordered_map<std::string, OpAttrChecker>& op_checkers() {
|
|
|
|
|
static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
|
|
|
|
|
return op_checkers_;
|
|
|
|
@ -353,11 +355,11 @@ class OpRegisterHelper {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename OpType>
|
|
|
|
|
template <typename GradOpType>
|
|
|
|
|
class GradOpRegisterHelper {
|
|
|
|
|
public:
|
|
|
|
|
GradOpRegisterHelper(const char* op_type) {
|
|
|
|
|
OpRegistry::RegisterGradOp<OpType>(op_type);
|
|
|
|
|
GradOpRegisterHelper(const char* op_type, const char* grad_op_type) {
|
|
|
|
|
OpRegistry::RegisterGradOp<GradOpType>(op_type, grad_op_type);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -383,13 +385,16 @@ class GradOpRegisterHelper {
|
|
|
|
|
/**
|
|
|
|
|
* Macro to Register Gradient Operator.
|
|
|
|
|
*/
|
|
|
|
|
#define REGISTER_GRADIENT_OP(__op_type, __op_class) \
|
|
|
|
|
STATIC_ASSERT_GLOBAL_NAMESPACE( \
|
|
|
|
|
__reg_gradient_op__##__op_type, \
|
|
|
|
|
"REGISTER_GRADIENT_OP must be in global namespace"); \
|
|
|
|
|
static ::paddle::framework::GradOpRegisterHelper<__op_class> \
|
|
|
|
|
__op_gradient_register_##__op_type##__(#__op_type); \
|
|
|
|
|
int __op_gradient_register_##__op_type##_handle__() { return 0; }
|
|
|
|
|
#define REGISTER_GRADIENT_OP(__op_type, __grad_op_type, __grad_op_class) \
|
|
|
|
|
STATIC_ASSERT_GLOBAL_NAMESPACE( \
|
|
|
|
|
__reg_gradient_op__##__op_type##__grad_op_type, \
|
|
|
|
|
"REGISTER_GRADIENT_OP must be in global namespace"); \
|
|
|
|
|
static ::paddle::framework::GradOpRegisterHelper<__grad_op_class> \
|
|
|
|
|
__op_gradient_register_##__op_type##__grad_op_type##__(#__op_type, \
|
|
|
|
|
#__grad_op_type); \
|
|
|
|
|
int __op_gradient_register_##__op_type##__grad_op_type##_handle__() { \
|
|
|
|
|
return 0; \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Macro to Register OperatorKernel.
|
|
|
|
|