/** * Copyright 2019-2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef INC_GRAPH_COMPUTE_GRAPH_H_ #define INC_GRAPH_COMPUTE_GRAPH_H_ #include #include #include #include #include #include #include "detail/attributes_holder.h" #include "graph/anchor.h" #include "graph/node.h" #include "graph/op_desc.h" #include "graph/range_vistor.h" namespace ge { class Node; using NodePtr = std::shared_ptr; class Edge; using EdgePtr = std::shared_ptr; class InDataAnchor; using InDataAnchorPtr = std::shared_ptr; class OutDataAnchor; using OutDataAnchorPtr = std::shared_ptr; class ControlAnchor; using ControlAnchorPtr = std::shared_ptr; class InControlAnchor; using InControlAnchorPtr = std::shared_ptr; class OutControlAnchor; using OutControlAnchorPtr = std::shared_ptr; class GeAttrValue; using AttrValuePtr = std::shared_ptr; using ConstComputeGraph = const ComputeGraph; class OperatorImpl; using OperatorImplPtr = std::shared_ptr; class ComputeGraph : public std::enable_shared_from_this, public AttrHolder { friend class GraphUtils; public: template using Vistor = RangeVistor>; explicit ComputeGraph(const std::string &name); virtual ~ComputeGraph(); std::string GetName() const; void SetName(const std::string &name); using AttrHolder::DelAttr; using AttrHolder::GetAttr; using AttrHolder::HasAttr; using AttrHolder::SetAttr; size_t GetAllNodesSize() const; Vistor GetAllNodes() const; size_t GetDirectNodesSize() const; Vistor GetDirectNode() const; Vistor GetInputNodes() const; Vistor GetOutputNodes() const; NodePtr FindNode(const std::string &name) const; // Add node NodePtr AddNode(NodePtr node); NodePtr AddNode(OpDescPtr op); NodePtr AddNodeFront(NodePtr node); NodePtr AddNodeFront(const OpDescPtr &op); NodePtr AddInputNode(NodePtr node); NodePtr AddOutputNode(NodePtr node); graphStatus RemoveNode(const NodePtr &node); graphStatus RemoveInputNode(const NodePtr &node); graphStatus RemoveOutputNode(const NodePtr &node); graphStatus RemoveConstInput(const NodePtr &node); std::shared_ptr AddSubGraph(std::shared_ptr sub_graph); graphStatus RemoveSubGraph(const std::shared_ptr &sub_graph); graphStatus TopologicalSorting(); bool IsValid() const; void Dump() const; graphStatus IsolateNode(const NodePtr &node); graphStatus Verify(); graphStatus InferShape(); graphStatus InferOriginFormat(); graphStatus InferShapeInNeed(); graphStatus InsertEventNodes(); bool operator==(const ComputeGraph &r_compute_graph) const; const std::map, std::vector> &GetShareParamLayer() const { return params_share_map_; } void SetShareParamLayer(const std::map, std::vector> params_share_map) { params_share_map_ = params_share_map; } void SetInputsOrder(const std::vector &inputs_order) { inputs_order_ = inputs_order; } void SetGraphOutNodes(std::map> out_nodes_map) { out_nodes_map_ = out_nodes_map; } void AppendGraphOutNodes(std::map> out_nodes_map) { for (auto &item : out_nodes_map) { (void)out_nodes_map_.emplace(item.first, item.second); } } const std::map> &GetGraphOutNodes() const { return out_nodes_map_; } void SetOrigGraph(ComputeGraphPtr orig_graph) { origGraph_ = orig_graph; } ComputeGraphPtr GetOrigGraph(void) { return origGraph_; } void SetOutputSize(uint32_t size) { output_size_ = size; } uint32_t GetOutputSize() const { return output_size_; } void SetInputSize(uint32_t size) { input_size_ = size; } uint32_t GetInputSize() const { return input_size_; } /// /// Set iteration needed. /// If set is true, it means this graph need run iteration some /// times(according variant "npu_runconfig/iterations_per_loop"). /// @param need_iteration is need iteration /// void SetNeedIteration(bool need_iteration) { need_iteration_ = need_iteration; } void SetUserDefOutput(const std::string &output_name); const std::string GetOutput(); /// /// Get need_iteration. /// @return is need iteration /// bool GetNeedIteration() const { return need_iteration_; } void SetGraphOpName(const std::map &op_name_map) { op_name_map_ = op_name_map; } const std::map &GetGraphOpName() const { return op_name_map_; } const std::map &GetAllNodesInfo() const; void SetAllNodesInfo(const std::map &nodes) { all_nodes_infos_ = nodes; } void SetGraphOutNodesInfo(std::vector> &out_nodes_info) { output_nodes_info_ = out_nodes_info; } void AppendGraphOutNodesInfo(std::vector> &out_nodes_info) { output_nodes_info_.insert(output_nodes_info_.end(), out_nodes_info.begin(), out_nodes_info.end()); } const std::vector> &GetGraphOutNodesInfo() const { return output_nodes_info_; } void SetGraphTargetNodesInfo(const std::vector &target_nodes_info) { target_nodes_info_ = target_nodes_info; } const std::vector &GetGraphTargetNodesInfo() const { return target_nodes_info_; } void SetSessionID(uint64_t session_id) { session_id_ = session_id; } uint64_t GetSessionID() const { return session_id_; } void SetGraphID(uint32_t graph_id) { graph_id_ = graph_id; } uint32_t GetGraphID() const { return graph_id_; } void SaveDataFormat(ge::Format data_format) { data_format_ = data_format; } ge::Format GetDataFormat() const { return data_format_; } bool IsSummaryGraph() const { return is_summary_graph_; } void SetSummaryFlag(bool is_summary_graph) { is_summary_graph_ = is_summary_graph; } // Graph Before BFE ComputeGraphPtr origGraph_; protected: ProtoAttrMapHelper MutableAttrMap() override; ConstProtoAttrMapHelper GetAttrMap() const override; private: graphStatus DFSTopologicalSorting(std::vector &node_vec, std::map &map_in_edge_num, std::vector &stack); graphStatus BFSTopologicalSorting(std::vector &node_vec, std::map &map_in_edge_num, std::deque &stack); graphStatus CollectBreadthOutNode(const NodePtr &node, std::map &map_in_edge_num, std::map &breadth_node_map); graphStatus SortNodes(std::vector &stack, std::map &mapInEdgeNum); size_t GetInEdgeSize(const NodePtr &node); size_t GetOutEdgeSize(const NodePtr &node); graphStatus RemoveExtraOutEdge(const NodePtr &node); bool GraphMembersAreEqual(const ComputeGraph &r_graph) const; bool GraphAttrsAreEqual(const ComputeGraph &r_graph) const; bool VectorInputNodePtrIsEqual(const std::vector &r_node_ptr_vector, const std::vector &l_node_ptr_vector) const; ProtoAttrMapHelper attrs_; friend class ModelSerializeImp; friend class GraphDebugImp; friend class OnnxUtils; std::vector nodes_; std::vector input_nodes_; std::vector> sub_graph_; std::string name_; bool is_valid_flag_; bool is_summary_graph_ = false; // Indicates whether it is need iteration bool need_iteration_ = false; std::map, std::vector> params_share_map_; std::map> out_nodes_map_; // TaskIdx -> op_name Map std::map op_name_map_; std::vector inputs_order_; uint32_t output_size_ = 1; uint32_t input_size_ = 1; std::map all_nodes_infos_; std::vector> output_nodes_info_; std::vector target_nodes_info_; uint64_t session_id_ = 0; uint32_t graph_id_ = 0; ge::Format data_format_ = ge::FORMAT_ND; }; } // namespace ge #endif // INC_GRAPH_COMPUTE_GRAPH_H_