|
|
|
@ -126,13 +126,6 @@ class NOPMaker : public OpProtoAndCheckerMaker {
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct OpInfo {
|
|
|
|
|
std::function<OperatorBase*()> creator_;
|
|
|
|
|
std::string grad_op_type_;
|
|
|
|
|
OpProto* proto_;
|
|
|
|
|
OpAttrChecker* checker_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class OpRegistry {
|
|
|
|
|
using VarNameMap = OperatorBase::VarNameMap;
|
|
|
|
|
using OpCreator = std::function<OperatorBase*(
|
|
|
|
@ -140,6 +133,13 @@ class OpRegistry {
|
|
|
|
|
const VarNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>;
|
|
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
struct OpInfo {
|
|
|
|
|
OpCreator creator_;
|
|
|
|
|
std::string grad_op_type_;
|
|
|
|
|
OpProto* proto_;
|
|
|
|
|
OpAttrChecker* checker_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename OpType, typename ProtoMakerType, typename GradOpType>
|
|
|
|
|
static void RegisterOp(const std::string& op_type,
|
|
|
|
|
const std::string& grad_op_type) {
|
|
|
|
@ -175,9 +175,9 @@ class OpRegistry {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
|
|
|
|
|
const VarNameList& inputs,
|
|
|
|
|
const VarNameList& outputs,
|
|
|
|
|
const AttributeMap& attrs) {
|
|
|
|
|
const VarNameMap& inputs,
|
|
|
|
|
const VarNameMap& outputs,
|
|
|
|
|
AttributeMap attrs) {
|
|
|
|
|
auto it = op_info_map().find(type);
|
|
|
|
|
PADDLE_ENFORCE(it != op_info_map().end(),
|
|
|
|
|
"Operator '%s' has not been registered.", type);
|
|
|
|
|