You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
graphengine/inc/graph/compute_graph.h

306 lines
12 KiB

/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef INC_GRAPH_COMPUTE_GRAPH_H_
#define INC_GRAPH_COMPUTE_GRAPH_H_
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <deque>
#include "detail/attributes_holder.h"
#include "graph/anchor.h"
#include "graph/node.h"
#include "graph/op_desc.h"
#include "graph/range_vistor.h"
namespace ge {
class Node;
using NodePtr = std::shared_ptr<Node>;
class Edge;
using EdgePtr = std::shared_ptr<Edge>;
class InDataAnchor;
using InDataAnchorPtr = std::shared_ptr<InDataAnchor>;
class OutDataAnchor;
using OutDataAnchorPtr = std::shared_ptr<OutDataAnchor>;
class ControlAnchor;
using ControlAnchorPtr = std::shared_ptr<ControlAnchor>;
class InControlAnchor;
using InControlAnchorPtr = std::shared_ptr<InControlAnchor>;
class OutControlAnchor;
using OutControlAnchorPtr = std::shared_ptr<OutControlAnchor>;
class GeAttrValue;
using AttrValuePtr = std::shared_ptr<GeAttrValue>;
using ConstComputeGraph = const ComputeGraph;
class OperatorImpl;
using OperatorImplPtr = std::shared_ptr<OperatorImpl>;
class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public AttrHolder {
friend class GraphUtils;
public:
template <class T>
using Vistor = RangeVistor<T, std::shared_ptr<ConstComputeGraph>>;
explicit ComputeGraph(const std::string &name);
~ComputeGraph() override;
std::string GetName() const;
void SetName(const std::string &name);
using AttrHolder::DelAttr;
using AttrHolder::GetAttr;
using AttrHolder::HasAttr;
using AttrHolder::SetAttr;
size_t GetAllNodesSize() const;
Vistor<NodePtr> GetAllNodes() const;
// is_unknown_shape: false, same with GetAllNodes func
// is_unknown_shape: true, same with GetDirectNodes func
Vistor<NodePtr> GetNodes(bool is_unknown_shape) const;
size_t GetDirectNodesSize() const;
Vistor<NodePtr> GetDirectNode() const;
Vistor<NodePtr> GetInputNodes() const;
Vistor<NodePtr> GetOutputNodes() const;
NodePtr FindNode(const std::string &name) const;
NodePtr FindFirstNodeMatchType(const std::string &name) const;
/*lint -e504*/
// AddNode with NodePtr
NodePtr AddNode(NodePtr node);
NodePtr AddNode(OpDescPtr op);
NodePtr AddNode(OpDescPtr op, int64_t id); // for unserialize
NodePtr AddNodeFront(NodePtr node);
NodePtr AddNodeFront(const OpDescPtr &op);
NodePtr AddInputNode(NodePtr node);
NodePtr AddOutputNode(NodePtr node);
// insert node with specific pre_node
NodePtr AddNodeAfter(OpDescPtr &op, const NodePtr &pre_node);
NodePtr AddNodeAfter(NodePtr node, const NodePtr &pre_node);
graphStatus RemoveNode(const NodePtr &node);
graphStatus RemoveInputNode(const NodePtr &node);
graphStatus RemoveOutputNode(const NodePtr &node);
graphStatus RemoveConstInput(const NodePtr &node);
/// Add a subgraph to this graph. The subgraph must has a parent graph and parent node,
/// which means the member functions `SetParentGraph` and `SetParentNode` of the subgraph
/// must be called before add it to the root graph. and subgraph->GetParentNode()->GetOwnerGraph()
/// must equal to subgraph->GetOwnerGraph().
/// The subgraphs can only be added to a *root graph*. A root graph is a graph without any parent graph.
/// The subgraph's name SHOULD(not must) be the same as the parameter `name`
graphStatus AddSubgraph(const std::string &name, const std::shared_ptr<ComputeGraph> &subgraph);
graphStatus AddSubgraph(const std::shared_ptr<ComputeGraph> &subgraph);
void RemoveSubgraph(const std::string &name);
void RemoveSubgraph(const std::shared_ptr<ComputeGraph> &subgraph);
std::shared_ptr<ComputeGraph> GetSubgraph(const std::string &name) const;
std::vector<std::shared_ptr<ComputeGraph>> GetAllSubgraphs() const;
// obsolete
std::shared_ptr<ComputeGraph> AddSubGraph(std::shared_ptr<ComputeGraph> sub_graph);
// obsolete
graphStatus RemoveSubGraph(const std::shared_ptr<ComputeGraph> &sub_graph);
///
/// @brief Update input-mapping
/// @param [in] input_mapping : index_of_cur_graph_node_input -> index_of_new_graph_node_input
/// @return graphStatus
///
graphStatus UpdateInputMapping(const std::map<uint32_t, uint32_t> &input_mapping);
///
/// @brief Update output-mapping
/// @param [in] output_mapping : index_of_cur_graph_node_output -> index_of_new_graph_node_output
/// @return graphStatus
///
graphStatus UpdateOutputMapping(const std::map<uint32_t, uint32_t> &output_mapping);
graphStatus TopologicalSorting();
bool IsValid() const;
void Dump() const;
void Swap(ComputeGraph &graph);
graphStatus IsolateNode(const NodePtr &node);
graphStatus Verify();
graphStatus InferShape();
graphStatus InferOriginFormat();
graphStatus InferShapeInNeed();
graphStatus InsertEventNodes();
bool operator==(const ComputeGraph &r_compute_graph) const;
/*lint +e504*/
const std::map<std::vector<std::string>, std::vector<std::string>> &GetShareParamLayer() const {
return params_share_map_;
}
void SetShareParamLayer(const std::map<std::vector<std::string>, std::vector<std::string>> params_share_map) {
params_share_map_ = params_share_map;
}
void SetInputsOrder(const std::vector<std::string> &inputs_order) { inputs_order_ = inputs_order; }
void SetGraphOutNodes(std::map<std::string, std::vector<int32_t>> out_nodes_map) { out_nodes_map_ = out_nodes_map; }
void AppendGraphOutNodes(std::map<std::string, std::vector<int32_t>> out_nodes_map) {
for (auto &item : out_nodes_map) {
(void)out_nodes_map_.emplace(item.first, item.second);
}
}
shared_ptr<ComputeGraph> GetParentGraph();
void SetParentGraph(const shared_ptr<ComputeGraph> &parent);
shared_ptr<Node> GetParentNode();
void SetParentNode(const shared_ptr<Node> &parent);
const std::map<std::string, std::vector<int32_t>> &GetGraphOutNodes() const { return out_nodes_map_; }
void SetOrigGraph(ComputeGraphPtr orig_graph) { origGraph_ = orig_graph; }
ComputeGraphPtr GetOrigGraph(void) { return origGraph_; }
void SetOutputSize(uint32_t size) { output_size_ = size; }
uint32_t GetOutputSize() const { return output_size_; }
void SetInputSize(uint32_t size) { input_size_ = size; }
uint32_t GetInputSize() const { return input_size_; }
// false: known shape true: unknow shape
bool GetGraphUnknownFlag() const { return is_unknown_shape_graph_; }
void SetGraphUnknownFlag(bool flag) { is_unknown_shape_graph_ = flag; }
///
/// Set is need train iteration.
/// If set true, it means this graph need to be run iteration some
/// times(according variant "npu_runconfig/iterations_per_loop").
/// @param need_iteration is need iteration
///
void SetNeedIteration(bool need_iteration) { need_iteration_ = need_iteration; }
void SetUserDefOutput(const std::string &output_name);
const std::string GetOutput();
///
/// Get is need train iteration.
/// @return is need iteration
///
bool GetNeedIteration() const { return need_iteration_; }
void SetGraphOpName(const std::map<uint32_t, std::string> &op_name_map) { op_name_map_ = op_name_map; }
const std::map<uint32_t, std::string> &GetGraphOpName() const { return op_name_map_; }
const std::map<OperatorImplPtr, NodePtr> &GetAllNodesInfo() const;
void SetAllNodesInfo(const std::map<OperatorImplPtr, NodePtr> &nodes) { all_nodes_infos_ = nodes; }
void SetGraphOutNodesInfo(std::vector<std::pair<NodePtr, int32_t>> &out_nodes_info) {
output_nodes_info_ = out_nodes_info;
}
void AppendGraphOutNodesInfo(std::vector<std::pair<NodePtr, int32_t>> &out_nodes_info) {
output_nodes_info_.insert(output_nodes_info_.end(), out_nodes_info.begin(), out_nodes_info.end());
}
const std::vector<std::pair<NodePtr, int32_t>> &GetGraphOutNodesInfo() const { return output_nodes_info_; }
void SetGraphTargetNodesInfo(const std::vector<NodePtr> &target_nodes_info) {
target_nodes_info_ = target_nodes_info;
}
const std::vector<NodePtr> &GetGraphTargetNodesInfo() const { return target_nodes_info_; }
void SetSessionID(uint64_t session_id) { session_id_ = session_id; }
uint64_t GetSessionID() const { return session_id_; }
void SetGraphID(uint32_t graph_id) { graph_id_ = graph_id; }
uint32_t GetGraphID() const { return graph_id_; }
void SaveDataFormat(ge::Format data_format) { data_format_ = data_format; }
ge::Format GetDataFormat() const { return data_format_; }
bool IsSummaryGraph() const { return is_summary_graph_; }
void SetSummaryFlag(bool is_summary_graph) { is_summary_graph_ = is_summary_graph; }
// Graph Before BFE
ComputeGraphPtr origGraph_;
protected:
ProtoAttrMapHelper MutableAttrMap() override;
ConstProtoAttrMapHelper GetAttrMap() const override;
private:
graphStatus DFSTopologicalSorting(std::vector<NodePtr> &node_vec, std::map<NodePtr, uint32_t> &map_in_edge_num,
std::vector<NodePtr> &stack);
graphStatus BFSTopologicalSorting(std::vector<NodePtr> &node_vec, std::map<NodePtr, uint32_t> &map_in_edge_num,
std::deque<NodePtr> &stack);
graphStatus CollectBreadthOutNode(const NodePtr &node, std::map<NodePtr, uint32_t> &map_in_edge_num,
std::map<string, NodePtr> &breadth_node_map);
graphStatus TopologicalSortingGraph();
graphStatus SortNodes(std::vector<NodePtr> &stack, std::map<NodePtr, uint32_t> &mapInEdgeNum);
Vistor<NodePtr> AllGraphNodes(std::vector<std::shared_ptr<ComputeGraph>> &subgraphs) const;
size_t GetInEdgeSize(const NodePtr &node);
size_t GetOutEdgeSize(const NodePtr &node);
graphStatus RemoveExtraOutEdge(const NodePtr &node);
bool GraphMembersAreEqual(const ComputeGraph &r_graph) const;
bool GraphAttrsAreEqual(const ComputeGraph &r_graph) const;
bool VectorInputNodePtrIsEqual(const std::vector<NodePtr> &r_node_ptr_vector,
const std::vector<NodePtr> &l_node_ptr_vector) const;
void SetNodesOwner();
friend class ModelSerializeImp;
friend class GraphDebugImp;
friend class OnnxUtils;
std::string name_;
uint32_t graph_id_ = 0;
ProtoAttrMapHelper attrs_;
std::vector<NodePtr> nodes_;
std::map<OperatorImplPtr, NodePtr> all_nodes_infos_;
std::vector<NodePtr> target_nodes_info_;
std::vector<NodePtr> input_nodes_;
std::vector<std::string> inputs_order_;
uint32_t input_size_ = 1;
std::map<std::string, std::vector<int32_t>> out_nodes_map_;
uint32_t output_size_ = 1;
std::vector<std::pair<NodePtr, int32_t>> output_nodes_info_;
std::vector<std::shared_ptr<ComputeGraph>> sub_graph_;
std::map<std::string, std::shared_ptr<ComputeGraph>> names_to_subgraph_;
std::weak_ptr<ComputeGraph> parent_graph_;
std::weak_ptr<Node> parent_node_;
// the members followed should not in the ComputeGraph class
bool is_valid_flag_;
bool is_summary_graph_ = false;
// Indicates whether it is need iteration
bool need_iteration_ = false;
std::map<std::vector<std::string>, std::vector<std::string>> params_share_map_;
// TaskIdx -> op_name Map
std::map<uint32_t, std::string> op_name_map_;
uint64_t session_id_ = 0;
ge::Format data_format_ = ge::FORMAT_ND;
// unknown graph indicator, default is false, mean known shape
bool is_unknown_shape_graph_ = false;
};
} // namespace ge
#endif // INC_GRAPH_COMPUTE_GRAPH_H_