You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
graphengine/ge/graph/passes/net_output_pass.h

226 lines
7.6 KiB

/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef GE_GRAPH_PASSES_NET_OUTPUT_PASS_H_
#define GE_GRAPH_PASSES_NET_OUTPUT_PASS_H_
#include <map>
#include <set>
#include <utility>
#include <vector>
#include "graph/types.h"
#include "inc/graph_pass.h"
namespace ge {
///
/// Describes one output entry handled by NetOutputPass: a node together with
/// the indices that locate the output on it.
///
struct RetvalInfo {
// Node associated with this output entry.
NodePtr output_node;
// Index on output_node for this entry (presumably its output index -- confirm in the .cc).
int32_t node_output_index;
// NOTE(review): presumably the index on the parent node when the graph is a subgraph -- confirm in the .cc.
int parent_node_index;
};
class NetOutputPass : public GraphPass {
public:
///
/// Entry of the NetOutputPass optimizer
/// @param [in] graph: Input ComputeGraph
/// @return SUCCESS: Execution succeed
/// @return OTHERS: Execution failed
/// @author
///
Status Run(ge::ComputeGraphPtr graph) override;
private:
///
/// The graph of identifies the network output with
/// the _Retval node, we determine if the input node is a network output here.
/// @param [in] node: Input node
/// @param [in/out] retval_node_index_map: Obtained output node <NodePtr, index> pair
/// @return SUCCESS: Execution succeed
/// @return OTHERS: Execution failed
/// @author
///
Status GetRetvalOutputInfo(const ge::NodePtr &node, std::map<int32_t, RetvalInfo> &retval_node_index_map);
///
/// Get the output node of the graph
/// @param [in] graph: Input ComputeGraph
/// @param [in/out] output_nodes_info: Obtained output node <NodePtr, index> pair
/// @return SUCCESS: Execution succeed
/// @return OTHERS: Execution failed
/// @author
///
Status GetOutputNode(const ge::ComputeGraphPtr &graph, std::vector<RetvalInfo> &output_nodes_info);
///
/// Get the output node of the graph
/// @param [in] graph: Input ComputeGraph
/// @param [in/out] net_output_desc: output netoutput node <NodePtr, index> pair
/// @return SUCCESS: Execution succeed
/// @return OTHERS: Execution failed
/// @author
///
Status CreateNetOutputNode(OpDescPtr &net_output_desc, const ge::ComputeGraphPtr &graph);
///
/// Check if the network output node is legal
/// @param [in] graph: Input ComputeGraph
/// @param [in] outputs: Output node information of graph
/// @return SUCCESS: Execution succeed
/// @return OTHERS: Execution failed
/// @author
///
Status CheckOutputNodeInfo(const ComputeGraphPtr &graph, const std::vector<RetvalInfo> &outputs);
///
/// Set input and output for the NetOutput node
/// @param [in] graph: Input ComputeGraph
/// @param [in] net_output_desc: OpDesc of the NetOutput node
/// @param [in] output_nodes_info: RetvalInfos of the NetOutput
/// @return void
/// @author
///
void AddInOutForNetOutputOp(const ComputeGraphPtr &graph, OpDescPtr &net_output_desc,
vector<RetvalInfo> &output_nodes_info);
///
/// Delete unwanted _Retval/Save/Summary nodes
/// @param [in] graph: Input ComputeGraph
/// @return SUCCESS: Execution succeed
/// @return OTHERS: Execution failed
/// @author
///
Status RemoveUnusedNode(const ge::ComputeGraphPtr &graph);
///
/// Update the output/input tensor description of the NetOutput node
/// @param [in] net_output: The netOutput node
/// @return SUCCESS: Execution succeed
/// @return OTHERS: Execution failed
/// @author
///
Status UpdateNetOutputDesc(const ge::NodePtr &net_output);
///
/// Add ctrl edge from target node to netoutput node
/// @param [in] net_output: The netOutput node
/// @return SUCCESS: Execution succeed
/// @return OTHERS: Execution failed
/// @author
///
Status AddCtrlEdgeForTargets(const ge::NodePtr &net_out_node);
///
/// Remove invalid node and duplicated node of user set targets
/// @param [in] : compute graph
/// @return SUCCESS: Execution succeed
/// @return OTHERS: Execution failed
/// @author
///
void SaveAndRemoveTargets(const ge::ComputeGraphPtr &graph);
///
/// Add edges for the NetOutput node
/// @param [in] graph: Input ComputeGraph
/// @param [in] net_out_node: The netOutput node
/// @param [in] output_nodes_info: Output node <NodePtr, index> pair
/// @return SUCCESS: Execution succeed
/// @return OTHERS: Execution failed
/// @author
///
Status AddEdgesForNetOutput(const ge::ComputeGraphPtr &graph, const ge::NodePtr &net_out_node,
const std::vector<RetvalInfo> &output_nodes_info);
///
/// Add ctrl edges for leaf node
/// @param [in] graph: Input ComputeGraph
/// @param [in] net_out_node: The netOutput node
/// @return SUCCESS: Execution succeed
/// @return OTHERS: Execution failed
/// @author
///
Status AddCtrlEdgesBetweenLeafAndNetOutput(const ge::ComputeGraphPtr &graph, const ge::NodePtr &net_out_node);
///
/// Unlink all connections between target nodes and netoutput node
/// @param [in] graph: ComputeGraph
/// @param [in] net_out_node: The netOutput node
/// @return SUCCESS: Execution succeed
/// @return OTHERS: Execution failed
/// @author
///
Status UnLink(const ge::ComputeGraphPtr &graph, const ge::NodePtr &net_out_node);
///
/// Unlink data connections between target nodes and netoutput node
/// @param [in] graph: ComputeGraph
/// @param [in] net_out_node: The netOutput node
/// @return SUCCESS: Execution succeed
/// @return OTHERS: Execution failed
/// @author
///
Status UnLinkDataAnchorOfNetoutput(const ge::ComputeGraphPtr &graph, const ge::NodePtr &net_out_node);
///
/// Unlink control connections between target nodes and netoutput node
/// @param [in] graph: ComputeGraph
/// @param [in] net_out_node: The netOutput node
/// @return SUCCESS: Execution succeed
/// @return OTHERS: Execution failed
/// @author
///
Status UnLinkControlAnchorOfNetoutput(const ge::ComputeGraphPtr &graph, const ge::NodePtr &net_out_node);
///
/// if user have set netoutput node , do relative process
/// @param [in] graph: ComputeGraph
/// @param [in] net_out_node: The netOutput node
/// @return SUCCESS: Execution succeed
/// @return OTHERS: Execution failed
/// @author
///
Status ProcessWithNetoutput(const ge::ComputeGraphPtr &graph, const ge::NodePtr &output_node);
///
/// check node wether exist in user-set output nodes
/// @param [in] graph: ComputeGraph
/// @param [in] net_out_node: The netOutput node
/// @return SUCCESS: Execution succeed
/// @return OTHERS: Execution failed
/// @author
///
bool CheckNodeIsInOutputNodes(const ge::ComputeGraphPtr &graph, const ge::NodePtr &node);
///
/// Add netoutput node to graph with output node infos
/// @param [in] graph: ComputeGraph
/// @param [in] output_node: shared_ptr to netoutput node
/// @return SUCCESS: Execution succeed
/// @return OTHERS: Execution failed
/// @author
///
Status AddNetOutputNodeToGraph(const ge::ComputeGraphPtr &graph, NodePtr &output_node);
///
/// Add user_def_dtype & format for netoutput node
/// @param [in] output_node: The netOutput node
/// @return SUCCESS: Execution succeed
/// @return OTHERS: Execution failed
/// @author
///
Status SetUserDefDTypeAndFormatFromAtcParams(const ge::NodePtr &output_node);
bool is_include_special_node_ = false;
std::set<NodePtr> targets_;
friend class ReUpdateNetOutputPass;
};
} // namespace ge
#endif // GE_GRAPH_PASSES_NET_OUTPUT_PASS_H_