/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef INC_GRAPH_OP_DESC_H_ #define INC_GRAPH_OP_DESC_H_ #include #include #include #include #include #include #include #include "detail/attributes_holder.h" #include "graph/range_vistor.h" #define DYNAMIN_INPUT_NAME(name, index) (((name)) + std::to_string((index))) #define DYNAMIN_OUTPUT_NAME(name, index) (((name)) + std::to_string((index))) namespace ge { using std::map; using std::pair; using std::shared_ptr; using std::string; using std::vector; class Operator; class GeTensorDesc; using GeTensorDescPtr = shared_ptr; using ConstGeTensorDescPtr = shared_ptr; class OpDesc; using OpDescPtr = shared_ptr; using ConstOpDescPtr = shared_ptr; class GeAttrValue; using ConstOpDesc = const OpDesc; enum SubgraphType { kStatic, kDynamic, kSubgraphTypeEnd }; class OpDesc : public std::enable_shared_from_this, public AttrHolder { public: template using Vistor = RangeVistor>; friend class GraphBuilderImpl; friend class OperatorImpl; OpDesc(const string &name, const string &type); OpDesc(); ~OpDesc(); bool operator==(const OpDesc &r_op_desc) const; string GetName() const; void SetName(const string &name); string GetType() const; void SetType(const string &type); graphStatus AddInputDesc(const GeTensorDesc &input_desc); graphStatus AddInputDesc(const string &name, const GeTensorDesc &input_desc); graphStatus AddInputDesc(uint32_t index, const ge::GeTensorDesc &input_desc); graphStatus AddInputDescForward(const string &name, const unsigned int num); graphStatus AddInputDescMiddle(const string &name, const unsigned int num, size_t index); graphStatus AddOutputDescMiddle(const string &name, const unsigned int num, size_t index); graphStatus AddOutputDescForward(const string &name, const unsigned int num); graphStatus AddOptionalInputDesc(const string &name, const GeTensorDesc &input_desc); graphStatus UpdateInputDesc(uint32_t index, const GeTensorDesc &tensor_desc); graphStatus UpdateInputDesc(const string &name, const GeTensorDesc &tensor_desc); bool InputIsSet(const string &name) const; GeTensorDesc GetInputDesc(uint32_t index) const; GeTensorDesc GetInputDesc(const string &name) const; Vistor GetAllInputNames() const; GeTensorDescPtr MutableInputDesc(uint32_t index) const; GeTensorDescPtr MutableInputDesc(const string &name) const; Vistor GetAllInputsDesc() const; Vistor GetAllInputsDescPtr() const; size_t GetInputsSize() const; size_t GetAllInputsSize() const; graphStatus AddOutputDesc(const GeTensorDesc &output_desc); graphStatus AddOutputDesc(const string &name, const GeTensorDesc &output_desc); graphStatus UpdateOutputDesc(uint32_t index, const GeTensorDesc &tensor_desc); graphStatus UpdateOutputDesc(const string &name, const GeTensorDesc &tensor_desc); GeTensorDesc GetOutputDesc(uint32_t index) const; GeTensorDesc GetOutputDesc(const string &name) const; GeTensorDescPtr MutableOutputDesc(uint32_t index) const; GeTensorDescPtr MutableOutputDesc(const string &name) const; uint32_t GetAllOutputsDescSize() const; Vistor GetAllOutputsDesc() const; Vistor GetAllOutputsDescPtr() const; size_t GetOutputsSize() const; ConstGeTensorDescPtr GetOutputDescPtr(uint32_t index) const; ConstGeTensorDescPtr GetInputDescPtr(uint32_t index) const; ConstGeTensorDescPtr GetInputDescPtrDfault(uint32_t index) const; ConstGeTensorDescPtr GetInputDescPtr(const string &name) const; graphStatus AddDynamicInputDesc(const string &name, const unsigned int num, bool isPushBack = true); graphStatus AddDynamicInputDescByIndex(const string &name, const unsigned int num, size_t index); graphStatus AddDynamicOutputDesc(const string &name, const unsigned int num, bool isPushBack = true); bool IsOptionalInput(const string &name) const; bool IsOptionalInput(uint32_t index) const; std::map GetAllInputName() const; std::map GetAllOutputName(); std::map& MutableAllInputName(); std::map& MutableAllOutputName(); bool UpdateInputName(std::map inputNameIdx); bool UpdateOutputName(std::map outputNameIdx); void AddInferFunc(const std::function &func); std::function GetInferFunc() const; graphStatus InferShapeAndType(); void AddInferFormatFunc(const std::function &func); std::function GetInferFormatFunc() const; graphStatus DefaultInferFormat(); std::function GetVerifyFunc() const; void AddVerifierFunc(const std::function &func); graphStatus CallInferFormatFunc(Operator &op); graphStatus OpVerify(); graphStatus CommonVerify() const; graphStatus AddRegisterInputName(const string &name); graphStatus AddRegisterOutputName(const string &name); vector GetRegisterInputName() const; vector GetRegisterOutputName() const; using AttrHolder::AddRequiredAttr; using AttrHolder::DelAttr; using AttrHolder::GetAllAttrNames; using AttrHolder::GetAllAttrs; using AttrHolder::GetAttr; using AttrHolder::HasAttr; using AttrHolder::SetAttr; void SetId(int64_t id); int64_t GetId() const; void SetStreamId(int64_t stream_id); int64_t GetStreamId() const; void SetInputName(const vector &input_name); vector GetInputName() const; void SetSrcName(const vector &src_name); vector GetSrcName() const; void SetSrcIndex(const vector &src_index); vector GetSrcIndex() const; void SetInputOffset(const vector &input); vector GetInputOffset() const; void SetOutputOffset(const vector &input); vector GetOutputOffset() const; void SetDstName(const vector &dst_name); vector GetDstName() const; void SetDstIndex(const vector &dst_index); vector GetDstIndex() const; void SetWorkspace(const vector &workspace); vector GetWorkspace() const; void SetWorkspaceBytes(const vector &workspace_bytes); vector GetWorkspaceBytes() const; void SetIsInputConst(const vector &is_input_const); vector GetIsInputConst() const; void SetOpInferDepends(const vector &depend_names); vector GetOpInferDepends() const; string GetInputNameByIndex(uint32_t index) const; string GetValidInputNameByIndex(uint32_t index) const; int GetValidInputIndexByName(const string &name) const; int GetInputIndexByName(const string &name) const; string GetOutputNameByIndex(uint32_t index) const; int GetOutputIndexByName(const string &name) const; graphStatus RestoreInputNameIdx(const string &name, const int &index); graphStatus RestoreOutputNameIdx(const string &name, const int &index); graphStatus CallInferFunc(Operator &op); void SetOpKernelLibName(const std::string &name); std::string GetOpKernelLibName() const; void SetOpEngineName(const std::string &name); std::string GetOpEngineName() const; void RegisterSubgraphIrName(const std::string &name, SubgraphType type); const std::map &GetSubgraphIrNames() const; SubgraphType GetSubgraphTypeByIrName(const std::string &name) const; graphStatus AddSubgraphName(const std::string &name); const std::map &GetSubgraphNameIndexes() const; std::string GetSubgraphInstanceName(uint32_t index) const; const std::vector &GetSubgraphInstanceNames() const; /// Does not provide functions `AddSubgraphInstance` or `AppendSubgraphInstance`, /// because this kind of functions will only append a new subgraph instance name /// at the tail of `subgraph_instance_names_` and ignore the synchronous change of `subgraph_names_to_index_`. /// If we want to append a new subgraph instance name, the function `AddSubgraphName` should be called first. /// \param index /// \param name /// \return graphStatus SetSubgraphInstanceName(uint32_t index, const std::string &name); void RemoveSubgraphInstanceName(const std::string &name); graphStatus GetSubgraphNameByInstanceName(const std::string &instance_name, std::string &subgraph_name) const; graphStatus InferDataSlice(); protected: ProtoAttrMapHelper MutableAttrMap() override; ConstProtoAttrMapHelper GetAttrMap() const override; private: OpDesc(const ProtoMsgOwner &proto_msg_owner, ge::proto::OpDef *op_def); bool OpDescMembersAreEqual(const OpDesc &r_op_desc) const; bool OpDescAttrsAreEqual(const OpDesc &r_op_desc) const; bool OpDescGenTensorDescsAreEqual(const OpDesc &r_op_desc) const; GeIrProtoHelper op_def_; std::vector subgraph_instance_names_; // subgraph names to index, for a `if` operator: // then_branch: 0 // else_branch: 1 // or for a `case` node: // branches0: 0 // branches1: 1 // branches2: 2 std::map subgraph_names_to_index_; // subgraph ir names to type, for a `if` operator: // then_branch: static // else_branch: static // or for a `case` op: // branches: dynamic std::map subgraph_ir_names_to_type_; vector inputs_desc_{}; map input_name_idx_{}; vector register_input_name_{}; std::unordered_set optional_input_names_{}; vector outputs_desc_{}; map output_name_idx_{}; vector register_output_name_{}; std::function infer_func_ = nullptr; std::function infer_format_func_ = nullptr; std::function verifier_func_ = nullptr; std::function infer_data_slice_func_ = nullptr; string op_kernel_lib_name_; string engine_name_; friend class OpDescUtils; friend class ModelSerializeImp; friend class AttrUtils; friend class GeAttrValueImp; friend class OnnxUtils; }; } // namespace ge #endif // INC_GRAPH_OP_DESC_H_