You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
226 lines
7.6 KiB
226 lines
7.6 KiB
/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
|
|
|
|
#ifndef GE_GRAPH_PASSES_NET_OUTPUT_PASS_H_
|
|
#define GE_GRAPH_PASSES_NET_OUTPUT_PASS_H_
|
|
|
|
#include <map>
|
|
#include <set>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
#include "graph/types.h"
|
|
#include "inc/graph_pass.h"
|
|
|
|
namespace ge {
|
|
struct RetvalInfo {
|
|
NodePtr output_node;
|
|
int32_t node_output_index;
|
|
int parent_node_index;
|
|
};
|
|
|
|
class NetOutputPass : public GraphPass {
|
|
public:
|
|
///
|
|
/// Entry of the NetOutputPass optimizer
|
|
/// @param [in] graph: Input ComputeGraph
|
|
/// @return SUCCESS: Execution succeed
|
|
/// @return OTHERS: Execution failed
|
|
/// @author
|
|
///
|
|
Status Run(ge::ComputeGraphPtr graph) override;
|
|
|
|
private:
|
|
///
|
|
/// The graph of identifies the network output with
|
|
/// the _Retval node, we determine if the input node is a network output here.
|
|
/// @param [in] node: Input node
|
|
/// @param [in/out] retval_node_index_map: Obtained output node <NodePtr, index> pair
|
|
/// @return SUCCESS: Execution succeed
|
|
/// @return OTHERS: Execution failed
|
|
/// @author
|
|
///
|
|
Status GetRetvalOutputInfo(const ge::NodePtr &node, std::map<int32_t, RetvalInfo> &retval_node_index_map);
|
|
|
|
///
|
|
/// Get the output node of the graph
|
|
/// @param [in] graph: Input ComputeGraph
|
|
/// @param [in/out] output_nodes_info: Obtained output node <NodePtr, index> pair
|
|
/// @return SUCCESS: Execution succeed
|
|
/// @return OTHERS: Execution failed
|
|
/// @author
|
|
///
|
|
Status GetOutputNode(const ge::ComputeGraphPtr &graph, std::vector<RetvalInfo> &output_nodes_info);
|
|
|
|
///
|
|
/// Get the output node of the graph
|
|
/// @param [in] graph: Input ComputeGraph
|
|
/// @param [in/out] net_output_desc: output netoutput node <NodePtr, index> pair
|
|
/// @return SUCCESS: Execution succeed
|
|
/// @return OTHERS: Execution failed
|
|
/// @author
|
|
///
|
|
Status CreateNetOutputNode(OpDescPtr &net_output_desc, const ge::ComputeGraphPtr &graph);
|
|
|
|
///
|
|
/// Check if the network output node is legal
|
|
/// @param [in] graph: Input ComputeGraph
|
|
/// @param [in] outputs: Output node information of graph
|
|
/// @return SUCCESS: Execution succeed
|
|
/// @return OTHERS: Execution failed
|
|
/// @author
|
|
///
|
|
Status CheckOutputNodeInfo(const ComputeGraphPtr &graph, const std::vector<RetvalInfo> &outputs);
|
|
|
|
///
|
|
/// Set input and output for the NetOutput node
|
|
/// @param [in] graph: Input ComputeGraph
|
|
/// @param [in] net_output_desc: OpDesc of the NetOutput node
|
|
/// @param [in] output_nodes_info: RetvalInfos of the NetOutput
|
|
/// @return void
|
|
/// @author
|
|
///
|
|
void AddInOutForNetOutputOp(const ComputeGraphPtr &graph, OpDescPtr &net_output_desc,
|
|
vector<RetvalInfo> &output_nodes_info);
|
|
|
|
///
|
|
/// Delete unwanted _Retval/Save/Summary nodes
|
|
/// @param [in] graph: Input ComputeGraph
|
|
/// @return SUCCESS: Execution succeed
|
|
/// @return OTHERS: Execution failed
|
|
/// @author
|
|
///
|
|
Status RemoveUnusedNode(const ge::ComputeGraphPtr &graph);
|
|
|
|
///
|
|
/// Update the output/input tensor description of the NetOutput node
|
|
/// @param [in] net_output: The netOutput node
|
|
/// @return SUCCESS: Execution succeed
|
|
/// @return OTHERS: Execution failed
|
|
/// @author
|
|
///
|
|
Status UpdateNetOutputDesc(const ge::NodePtr &net_output);
|
|
|
|
///
|
|
/// Add ctrl edge from target node to netoutput node
|
|
/// @param [in] net_output: The netOutput node
|
|
/// @return SUCCESS: Execution succeed
|
|
/// @return OTHERS: Execution failed
|
|
/// @author
|
|
///
|
|
Status AddCtrlEdgeForTargets(const ge::NodePtr &net_out_node);
|
|
|
|
///
|
|
/// Remove invalid node and duplicated node of user set targets
|
|
/// @param [in] : compute graph
|
|
/// @return SUCCESS: Execution succeed
|
|
/// @return OTHERS: Execution failed
|
|
/// @author
|
|
///
|
|
void SaveAndRemoveTargets(const ge::ComputeGraphPtr &graph);
|
|
|
|
///
|
|
/// Add edges for the NetOutput node
|
|
/// @param [in] graph: Input ComputeGraph
|
|
/// @param [in] net_out_node: The netOutput node
|
|
/// @param [in] output_nodes_info: Output node <NodePtr, index> pair
|
|
/// @return SUCCESS: Execution succeed
|
|
/// @return OTHERS: Execution failed
|
|
/// @author
|
|
///
|
|
Status AddEdgesForNetOutput(const ge::ComputeGraphPtr &graph, const ge::NodePtr &net_out_node,
|
|
const std::vector<RetvalInfo> &output_nodes_info);
|
|
///
|
|
/// Add ctrl edges for leaf node
|
|
/// @param [in] graph: Input ComputeGraph
|
|
/// @param [in] net_out_node: The netOutput node
|
|
/// @return SUCCESS: Execution succeed
|
|
/// @return OTHERS: Execution failed
|
|
/// @author
|
|
///
|
|
Status AddCtrlEdgesBetweenLeafAndNetOutput(const ge::ComputeGraphPtr &graph, const ge::NodePtr &net_out_node);
|
|
///
|
|
/// Unlink all connections between target nodes and netoutput node
|
|
/// @param [in] graph: ComputeGraph
|
|
/// @param [in] net_out_node: The netOutput node
|
|
/// @return SUCCESS: Execution succeed
|
|
/// @return OTHERS: Execution failed
|
|
/// @author
|
|
///
|
|
Status UnLink(const ge::ComputeGraphPtr &graph, const ge::NodePtr &net_out_node);
|
|
///
|
|
/// Unlink data connections between target nodes and netoutput node
|
|
/// @param [in] graph: ComputeGraph
|
|
/// @param [in] net_out_node: The netOutput node
|
|
/// @return SUCCESS: Execution succeed
|
|
/// @return OTHERS: Execution failed
|
|
/// @author
|
|
///
|
|
Status UnLinkDataAnchorOfNetoutput(const ge::ComputeGraphPtr &graph, const ge::NodePtr &net_out_node);
|
|
///
|
|
/// Unlink control connections between target nodes and netoutput node
|
|
/// @param [in] graph: ComputeGraph
|
|
/// @param [in] net_out_node: The netOutput node
|
|
/// @return SUCCESS: Execution succeed
|
|
/// @return OTHERS: Execution failed
|
|
/// @author
|
|
///
|
|
Status UnLinkControlAnchorOfNetoutput(const ge::ComputeGraphPtr &graph, const ge::NodePtr &net_out_node);
|
|
///
|
|
/// if user have set netoutput node , do relative process
|
|
/// @param [in] graph: ComputeGraph
|
|
/// @param [in] net_out_node: The netOutput node
|
|
/// @return SUCCESS: Execution succeed
|
|
/// @return OTHERS: Execution failed
|
|
/// @author
|
|
///
|
|
Status ProcessWithNetoutput(const ge::ComputeGraphPtr &graph, const ge::NodePtr &output_node);
|
|
///
|
|
/// check node wether exist in user-set output nodes
|
|
/// @param [in] graph: ComputeGraph
|
|
/// @param [in] net_out_node: The netOutput node
|
|
/// @return SUCCESS: Execution succeed
|
|
/// @return OTHERS: Execution failed
|
|
/// @author
|
|
///
|
|
bool CheckNodeIsInOutputNodes(const ge::ComputeGraphPtr &graph, const ge::NodePtr &node);
|
|
|
|
///
|
|
/// Add netoutput node to graph with output node infos
|
|
/// @param [in] graph: ComputeGraph
|
|
/// @param [in] output_node: shared_ptr to netoutput node
|
|
/// @return SUCCESS: Execution succeed
|
|
/// @return OTHERS: Execution failed
|
|
/// @author
|
|
///
|
|
Status AddNetOutputNodeToGraph(const ge::ComputeGraphPtr &graph, NodePtr &output_node);
|
|
|
|
///
|
|
/// Add user_def_dtype & format for netoutput node
|
|
/// @param [in] output_node: The netOutput node
|
|
/// @return SUCCESS: Execution succeed
|
|
/// @return OTHERS: Execution failed
|
|
/// @author
|
|
///
|
|
Status SetUserDefDTypeAndFormatFromAtcParams(const ge::NodePtr &output_node);
|
|
|
|
bool is_include_special_node_ = false;
|
|
std::set<NodePtr> targets_;
|
|
friend class ReUpdateNetOutputPass;
|
|
};
|
|
} // namespace ge
|
|
#endif // GE_GRAPH_PASSES_NET_OUTPUT_PASS_H_
|