You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
graphengine/inc/graph/compute_graph.h

243 lines
8.8 KiB

/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef INC_GRAPH_COMPUTE_GRAPH_H_
#define INC_GRAPH_COMPUTE_GRAPH_H_
#include <deque>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "detail/attributes_holder.h"
#include "graph/anchor.h"
#include "graph/node.h"
#include "graph/op_desc.h"
#include "graph/range_vistor.h"
namespace ge {
// Forward declarations and shared_ptr aliases used by this header. Some of these
// may also be provided by the included headers; redeclaring a class, or an alias
// that names the same type, is legal C++ and keeps this header self-contained.
class Node;
using NodePtr = std::shared_ptr<Node>;
class Edge;
using EdgePtr = std::shared_ptr<Edge>;
class InDataAnchor;
using InDataAnchorPtr = std::shared_ptr<InDataAnchor>;
class OutDataAnchor;
using OutDataAnchorPtr = std::shared_ptr<OutDataAnchor>;
class ControlAnchor;
using ControlAnchorPtr = std::shared_ptr<ControlAnchor>;
class InControlAnchor;
using InControlAnchorPtr = std::shared_ptr<InControlAnchor>;
class OutControlAnchor;
using OutControlAnchorPtr = std::shared_ptr<OutControlAnchor>;
class GeAttrValue;
using AttrValuePtr = std::shared_ptr<GeAttrValue>;
// ComputeGraph must be declared before the aliases below can name it; previously
// this relied on some other header having declared it first.
class ComputeGraph;
using ComputeGraphPtr = std::shared_ptr<ComputeGraph>;
// Read-only view of a graph; used as the owner type of ComputeGraph::Vistor.
using ConstComputeGraph = const ComputeGraph;
class OperatorImpl;
using OperatorImplPtr = std::shared_ptr<OperatorImpl>;
/// A compute graph.
///
/// It holds:
/// - the node list and the designated input nodes;
/// - nested subgraphs;
/// - graph-level attributes (through AttrHolder);
/// - bookkeeping metadata: session/graph ids, output-node maps, data format, etc.
///
/// Derives from std::enable_shared_from_this, so instances are meant to be owned by a
/// std::shared_ptr (ComputeGraphPtr).
class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public AttrHolder {
  friend class GraphUtils;

 public:
  /// Range view over graph members. It is parameterised on a shared_ptr to the const
  /// graph, presumably so the view keeps its graph alive (see graph/range_vistor.h).
  template <class T>
  using Vistor = RangeVistor<T, std::shared_ptr<ConstComputeGraph>>;

  explicit ComputeGraph(const std::string &name);
  virtual ~ComputeGraph();

  std::string GetName() const;
  void SetName(const std::string &name);

  // Re-expose the AttrHolder attribute accessors on ComputeGraph.
  using AttrHolder::DelAttr;
  using AttrHolder::GetAttr;
  using AttrHolder::HasAttr;
  using AttrHolder::SetAttr;

  // ---- Node queries ---------------------------------------------------------
  size_t GetAllNodesSize() const;
  Vistor<NodePtr> GetAllNodes() const;
  size_t GetDirectNodesSize() const;
  Vistor<NodePtr> GetDirectNode() const;
  Vistor<NodePtr> GetInputNodes() const;
  Vistor<NodePtr> GetOutputNodes() const;
  NodePtr FindNode(const std::string &name) const;

  // ---- Node insertion / removal ---------------------------------------------
  NodePtr AddNode(NodePtr node);
  NodePtr AddNode(OpDescPtr op);
  NodePtr AddNodeFront(NodePtr node);
  NodePtr AddNodeFront(const OpDescPtr &op);
  NodePtr AddInputNode(NodePtr node);
  NodePtr AddOutputNode(NodePtr node);
  graphStatus RemoveNode(const NodePtr &node);
  graphStatus RemoveInputNode(const NodePtr &node);
  graphStatus RemoveOutputNode(const NodePtr &node);
  graphStatus RemoveConstInput(const NodePtr &node);

  // ---- Subgraphs ------------------------------------------------------------
  std::shared_ptr<ComputeGraph> AddSubGraph(std::shared_ptr<ComputeGraph> sub_graph);
  graphStatus RemoveSubGraph(const std::shared_ptr<ComputeGraph> &sub_graph);

  // ---- Whole-graph operations -----------------------------------------------
  graphStatus TopologicalSorting();
  bool IsValid() const;
  void Dump() const;
  graphStatus IsolateNode(const NodePtr &node);
  graphStatus Verify();
  graphStatus InferShape();
  graphStatus InferOriginFormat();
  graphStatus InferShapeInNeed();
  graphStatus InsertEventNodes();
  bool operator==(const ComputeGraph &r_compute_graph) const;

  // ---- Shared-parameter layers ----------------------------------------------
  const std::map<std::vector<std::string>, std::vector<std::string>> &GetShareParamLayer() const {
    return params_share_map_;
  }
  /// Replaces the shared-parameter layer map. Taken by const reference: the previous
  /// by-value `const` parameter forced an extra deep copy of the whole map.
  void SetShareParamLayer(const std::map<std::vector<std::string>, std::vector<std::string>> &params_share_map) {
    params_share_map_ = params_share_map;
  }

  void SetInputsOrder(const std::vector<std::string> &inputs_order) { inputs_order_ = inputs_order; }

  // ---- Graph output nodes ---------------------------------------------------
  // Map from a string key (presumably the node name) to int32_t values
  // (presumably output indices) -- TODO confirm against callers.

  /// Replaces the whole output-node map. The parameter is a sink, so it is moved into
  /// place rather than copied a second time.
  void SetGraphOutNodes(std::map<std::string, std::vector<int32_t>> out_nodes_map) {
    out_nodes_map_ = std::move(out_nodes_map);
  }
  /// Merges entries into the output-node map. A key that is already present keeps its
  /// existing value (std::map::insert never overwrites).
  void AppendGraphOutNodes(const std::map<std::string, std::vector<int32_t>> &out_nodes_map) {
    out_nodes_map_.insert(out_nodes_map.begin(), out_nodes_map.end());
  }
  const std::map<std::string, std::vector<int32_t>> &GetGraphOutNodes() const { return out_nodes_map_; }

  // ---- Original graph (the graph before BFE) --------------------------------
  void SetOrigGraph(ComputeGraphPtr orig_graph) { origGraph_ = std::move(orig_graph); }
  ComputeGraphPtr GetOrigGraph() const { return origGraph_; }

  // ---- Input / output counts (both default to 1) ----------------------------
  void SetOutputSize(uint32_t size) { output_size_ = size; }
  uint32_t GetOutputSize() const { return output_size_; }
  void SetInputSize(uint32_t size) { input_size_ = size; }
  uint32_t GetInputSize() const { return input_size_; }

  ///
  /// Set whether the graph needs to run iteratively.
  /// If true, the graph is run repeatedly; the number of iterations is given by the
  /// "npu_runconfig/iterations_per_loop" setting.
  /// @param need_iteration whether iteration is needed
  ///
  void SetNeedIteration(bool need_iteration) { need_iteration_ = need_iteration; }
  ///
  /// Get need_iteration.
  /// @return whether the graph needs to run iteratively (defaults to false)
  ///
  bool GetNeedIteration() const { return need_iteration_; }

  void SetUserDefOutput(const std::string &output_name);
  // NOTE(review): returning `const std::string` by value inhibits moves, and the method
  // could be const. It is left unchanged because the out-of-line definition must match.
  const std::string GetOutput();

  // TaskIdx -> op_name map.
  void SetGraphOpName(const std::map<uint32_t, std::string> &op_name_map) { op_name_map_ = op_name_map; }
  const std::map<uint32_t, std::string> &GetGraphOpName() const { return op_name_map_; }

  const std::map<OperatorImplPtr, NodePtr> &GetAllNodesInfo() const;
  void SetAllNodesInfo(const std::map<OperatorImplPtr, NodePtr> &nodes) { all_nodes_infos_ = nodes; }

  /// Replaces the graph output (node, int32_t) list; the int32_t is presumably an output index.
  /// The argument is only read, so it is taken by const reference. This also lets callers
  /// pass temporaries.
  void SetGraphOutNodesInfo(const std::vector<std::pair<NodePtr, int32_t>> &out_nodes_info) {
    output_nodes_info_ = out_nodes_info;
  }
  /// Appends (node, int32_t) pairs to the graph output list. No de-duplication is performed.
  void AppendGraphOutNodesInfo(const std::vector<std::pair<NodePtr, int32_t>> &out_nodes_info) {
    output_nodes_info_.insert(output_nodes_info_.end(), out_nodes_info.begin(), out_nodes_info.end());
  }
  const std::vector<std::pair<NodePtr, int32_t>> &GetGraphOutNodesInfo() const { return output_nodes_info_; }

  void SetGraphTargetNodesInfo(const std::vector<NodePtr> &target_nodes_info) {
    target_nodes_info_ = target_nodes_info;
  }
  const std::vector<NodePtr> &GetGraphTargetNodesInfo() const { return target_nodes_info_; }

  // ---- Identity / misc metadata ---------------------------------------------
  void SetSessionID(uint64_t session_id) { session_id_ = session_id; }
  uint64_t GetSessionID() const { return session_id_; }
  void SetGraphID(uint32_t graph_id) { graph_id_ = graph_id; }
  uint32_t GetGraphID() const { return graph_id_; }
  void SaveDataFormat(ge::Format data_format) { data_format_ = data_format; }
  ge::Format GetDataFormat() const { return data_format_; }
  bool IsSummaryGraph() const { return is_summary_graph_; }
  void SetSummaryFlag(bool is_summary_graph) { is_summary_graph_ = is_summary_graph; }

  // Graph before BFE.
  // NOTE(review): this member stays public for backward compatibility; new code should use
  // SetOrigGraph()/GetOrigGraph(). Data-member declaration order is intentionally unchanged.
  ComputeGraphPtr origGraph_;

 protected:
  ProtoAttrMapHelper MutableAttrMap() override;
  ConstProtoAttrMapHelper GetAttrMap() const override;

 private:
  graphStatus DFSTopologicalSorting(std::vector<NodePtr> &node_vec, std::map<NodePtr, uint32_t> &map_in_edge_num,
                                    std::vector<NodePtr> &stack);
  graphStatus BFSTopologicalSorting(std::vector<NodePtr> &node_vec, std::map<NodePtr, uint32_t> &map_in_edge_num,
                                    std::deque<NodePtr> &stack);
  // Fully qualified std::string: the previous unqualified `string` relied on a `using`
  // leaking in from another header.
  graphStatus CollectBreadthOutNode(const NodePtr &node, std::map<NodePtr, uint32_t> &map_in_edge_num,
                                    std::map<std::string, NodePtr> &breadth_node_map);
  graphStatus SortNodes(std::vector<NodePtr> &stack, std::map<NodePtr, uint32_t> &mapInEdgeNum);
  size_t GetInEdgeSize(const NodePtr &node);
  size_t GetOutEdgeSize(const NodePtr &node);
  graphStatus RemoveExtraOutEdge(const NodePtr &node);
  bool GraphMembersAreEqual(const ComputeGraph &r_graph) const;
  bool GraphAttrsAreEqual(const ComputeGraph &r_graph) const;
  bool VectorInputNodePtrIsEqual(const std::vector<NodePtr> &r_node_ptr_vector,
                                 const std::vector<NodePtr> &l_node_ptr_vector) const;

  ProtoAttrMapHelper attrs_;

  friend class ModelSerializeImp;
  friend class GraphDebugImp;
  friend class OnnxUtils;

  std::vector<NodePtr> nodes_;
  std::vector<NodePtr> input_nodes_;
  std::vector<std::shared_ptr<ComputeGraph>> sub_graph_;
  std::string name_;
  // In-class default, like the other flags, so the value is never read uninitialised.
  // A constructor mem-initializer still takes precedence.
  bool is_valid_flag_ = false;
  bool is_summary_graph_ = false;
  // Indicates whether the graph needs to run iteratively.
  bool need_iteration_ = false;
  std::map<std::vector<std::string>, std::vector<std::string>> params_share_map_;
  std::map<std::string, std::vector<int32_t>> out_nodes_map_;
  // TaskIdx -> op_name map
  std::map<uint32_t, std::string> op_name_map_;
  std::vector<std::string> inputs_order_;
  uint32_t output_size_ = 1;
  uint32_t input_size_ = 1;
  std::map<OperatorImplPtr, NodePtr> all_nodes_infos_;
  std::vector<std::pair<NodePtr, int32_t>> output_nodes_info_;
  std::vector<NodePtr> target_nodes_info_;
  uint64_t session_id_ = 0;
  uint32_t graph_id_ = 0;
  ge::Format data_format_ = ge::FORMAT_ND;
};
} // namespace ge
#endif // INC_GRAPH_COMPUTE_GRAPH_H_