diff --git a/paddle/framework/grad_op_builder.cc b/paddle/framework/grad_op_builder.cc index fcc5d7a216..b02a599a80 100644 --- a/paddle/framework/grad_op_builder.cc +++ b/paddle/framework/grad_op_builder.cc @@ -24,9 +24,9 @@ static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type, const auto& src_inout = src_type == OpArgType::IN ? src_op->Inputs() : src_op->Outputs(); auto& dst_inout = *vars; - const OpProto* proto = OpInfoMap().at(src_op->Type()).proto_; + auto& proto = OpInfoMap::Instance().Get(src_op->Type()).Proto(); const auto& src_arg_list = - src_type == OpArgType::IN ? proto->inputs() : proto->outputs(); + src_type == OpArgType::IN ? proto.inputs() : proto.outputs(); for (const auto& arg : src_arg_list) { if (arg.not_in_gradient() && !is_grad) continue; const std::string src_name = arg.name(); @@ -40,14 +40,8 @@ static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type, } OperatorBase* BuildGradOp(const OperatorBase* op) { - auto it = OpInfoMap().find(op->Type()); - PADDLE_ENFORCE(it != OpInfoMap().end(), "'%s' has not been registered.", - op->Type()); - PADDLE_ENFORCE(it->second.proto_ != nullptr, "'%s' has no OpProto.", - op->Type()); - std::string grad_op_type = it->second.grad_op_type_; - PADDLE_ENFORCE(!grad_op_type.empty(), "'%s' has no gradient operator.", - op->Type()); + auto& info = OpInfoMap::Instance().Get(op->Type()); + PADDLE_ENFORCE(info.HasGradientOp()); VariableNameMap inputs; VariableNameMap outputs; @@ -56,10 +50,8 @@ OperatorBase* BuildGradOp(const OperatorBase* op) { TransOpArg(op, OpArgType::OUT, true, &inputs); // OG TransOpArg(op, OpArgType::IN, true, &outputs); // IG - it = OpInfoMap().find(grad_op_type); - PADDLE_ENFORCE(it != OpInfoMap().end(), "'%s' has not been registered.", - grad_op_type); - return it->second.creator_(grad_op_type, inputs, outputs, op->Attrs()); + auto& grad_info = OpInfoMap::Instance().Get(info.grad_op_type_); + return grad_info.Creator()(info.grad_op_type_, inputs, outputs, op->Attrs()); } } // namespace framework diff --git a/paddle/framework/op_info.cc b/paddle/framework/op_info.cc index f928ac6473..81ba29797c 100644 --- a/paddle/framework/op_info.cc +++ b/paddle/framework/op_info.cc @@ -17,12 +17,11 @@ namespace paddle { namespace framework { -static std::unordered_map* - g_op_info_map = nullptr; -std::unordered_map& OpInfoMap() { +static OpInfoMap* g_op_info_map = nullptr; + +OpInfoMap& OpInfoMap::Instance() { if (g_op_info_map == nullptr) { - g_op_info_map = - new std::unordered_map(); + g_op_info_map = new OpInfoMap(); } return *g_op_info_map; } diff --git a/paddle/framework/op_info.h b/paddle/framework/op_info.h index fdd0ed77d4..94245c6c44 100644 --- a/paddle/framework/op_info.h +++ b/paddle/framework/op_info.h @@ -34,9 +34,68 @@ struct OpInfo { std::string grad_op_type_; OpProto* proto_; OpAttrChecker* checker_; + + bool HasOpProtoAndChecker() const { + return proto_ != nullptr && checker_ != nullptr; + } + + const OpProto& Proto() const { + PADDLE_ENFORCE_NOT_NULL(proto_, "Operator Proto has not been registered"); + PADDLE_ENFORCE(proto_->IsInitialized(), + "Operator Proto must be initialized in op info"); + return *proto_; + } + + const OpAttrChecker& Checker() const { + PADDLE_ENFORCE_NOT_NULL(checker_, + "Operator Checker has not been registered"); + return *checker_; + } + + const OpCreator& Creator() const { + PADDLE_ENFORCE_NOT_NULL(creator_, + "Operator Creator has not been registered"); + return creator_; + } + + bool HasGradientOp() const { return !grad_op_type_.empty(); } }; -extern std::unordered_map& OpInfoMap(); +class OpInfoMap { + public: + static OpInfoMap& Instance(); + + OpInfoMap(const OpInfoMap& o) = delete; + OpInfoMap(OpInfoMap&& o) = delete; + OpInfoMap& operator=(const OpInfoMap& o) = delete; + OpInfoMap& operator=(OpInfoMap&& o) = delete; + + bool Has(const std::string& op_type) const { + return map_.find(op_type) != map_.end(); + } + + void Insert(const std::string& type, const OpInfo& info) { + PADDLE_ENFORCE(!Has(type), "Operator %s has been registered", type); + map_.insert({type, info}); + } + + const OpInfo& Get(const std::string& type) const { + auto it = map_.find(type); + PADDLE_ENFORCE(it != map_.end(), "Operator %s are not found", type); + return it->second; + } + + template + void IterAllInfo(Callback callback) { + for (auto& it : map_) { + callback(it.first, it.second); + } + } + + private: + OpInfoMap() = default; + std::unordered_map map_; +}; } // namespace framework } // namespace paddle diff --git a/paddle/framework/op_registry.cc b/paddle/framework/op_registry.cc index e03dc3a73d..b0e85dd49f 100644 --- a/paddle/framework/op_registry.cc +++ b/paddle/framework/op_registry.cc @@ -22,11 +22,9 @@ namespace framework { std::unique_ptr OpRegistry::CreateOp( const std::string& type, const VariableNameMap& inputs, const VariableNameMap& outputs, AttributeMap attrs) { - auto it = OpInfoMap().find(type); - PADDLE_ENFORCE(it != OpInfoMap().end(), - "Operator '%s' has not been registered.", type); - it->second.checker_->Check(attrs); - auto op = it->second.creator_(type, inputs, outputs, attrs); + auto& info = OpInfoMap::Instance().Get(type); + info.Checker().Check(attrs); + auto op = info.Creator()(type, inputs, outputs, attrs); return std::unique_ptr(op); } diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 06530bc7d0..2d09cde41e 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -35,7 +35,7 @@ class OpRegistry { template static void RegisterOp(const std::string& op_type, const std::string& grad_op_type) { - PADDLE_ENFORCE(OpInfoMap().count(op_type) == 0, + PADDLE_ENFORCE(!OpInfoMap::Instance().Has(op_type), "'%s' is registered more than once.", op_type); OpInfo op_info; op_info.creator_ = []( @@ -59,7 +59,7 @@ class OpRegistry { op_info.proto_ = nullptr; op_info.checker_ = nullptr; } - OpInfoMap().insert(std::make_pair(op_type, op_info)); + OpInfoMap::Instance().Insert(op_type, op_info); // register gradient op if (!grad_op_type.empty()) { RegisterOp(grad_op_type, ""); diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 48a7fe64ac..7abbde610f 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -141,18 +141,10 @@ std::vector OperatorBase::OutputVars(bool has_intermediate) const { } return ret_val; } - auto it = OpInfoMap().find(type_); - PADDLE_ENFORCE( - it != OpInfoMap().end(), - "Operator %s not registered, cannot figure out intermediate outputs", - type_); - PADDLE_ENFORCE( - it->second.proto_ != nullptr, - "Operator %s has no OpProto, cannot figure out intermediate outputs", - type_); + auto& info = OpInfoMap::Instance().Get(Type()); // get all OpProto::Var for outputs - for (auto& o : it->second.proto_->outputs()) { + for (auto& o : info.Proto().outputs()) { // ignore all intermediate output if (o.intermediate()) continue; auto out = outputs_.find(o.name()); diff --git a/paddle/framework/pybind.cc b/paddle/framework/pybind.cc index 1aec483573..6212c84909 100644 --- a/paddle/framework/pybind.cc +++ b/paddle/framework/pybind.cc @@ -138,19 +138,16 @@ All parameter, weight, gradient are variables in Paddle. //! @note: Be careful! PyBind will return std::string as an unicode, not //! Python str. If you want a str object, you should cast them in Python. m.def("get_all_op_protos", []() -> std::vector { - auto &op_info_map = OpInfoMap(); std::vector ret_values; - for (auto it = op_info_map.begin(); it != op_info_map.end(); ++it) { - const OpProto *proto = it->second.proto_; - if (proto == nullptr) { - continue; - } - PADDLE_ENFORCE(proto->IsInitialized(), "OpProto must all be initialized"); + + OpInfoMap::Instance().IterAllInfo([&ret_values](const std::string &type, + const OpInfo &info) { + if (!info.HasOpProtoAndChecker()) return; std::string str; - PADDLE_ENFORCE(proto->SerializeToString(&str), + PADDLE_ENFORCE(info.Proto().SerializeToString(&str), "Serialize OpProto Error. This could be a bug of Paddle."); - ret_values.push_back(py::bytes(str)); - } + ret_values.emplace_back(str); + }); return ret_values; }); m.def_submodule(