/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef GE_GRAPH_PASSES_NET_OUTPUT_PASS_H_ #define GE_GRAPH_PASSES_NET_OUTPUT_PASS_H_ #include #include #include #include #include "graph/types.h" #include "inc/graph_pass.h" namespace ge { struct RetvalInfo { NodePtr output_node; int32_t node_output_index; int parent_node_index; }; class NetOutputPass : public GraphPass { public: /// /// Entry of the NetOutputPass optimizer /// @param [in] graph: Input ComputeGraph /// @return SUCCESS: Execution succeed /// @return OTHERS: Execution failed /// @author /// Status Run(ge::ComputeGraphPtr graph) override; private: /// /// The graph of identifies the network output with /// the _Retval node, we determine if the input node is a network output here. /// @param [in] node: Input node /// @param [in/out] retval_node_index_map: Obtained output node pair /// @return SUCCESS: Execution succeed /// @return OTHERS: Execution failed /// @author /// Status GetRetvalOutputInfo(const ge::NodePtr &node, std::map &retval_node_index_map); /// /// Get the output node of the graph /// @param [in] graph: Input ComputeGraph /// @param [in/out] output_nodes_info: Obtained output node pair /// @return SUCCESS: Execution succeed /// @return OTHERS: Execution failed /// @author /// Status GetOutputNode(const ge::ComputeGraphPtr &graph, std::vector &output_nodes_info); /// /// Get the output node of the graph /// @param [in] graph: Input ComputeGraph /// @param [in/out] net_output_desc: output netoutput node pair /// @return SUCCESS: Execution succeed /// @return OTHERS: Execution failed /// @author /// Status CreateNetOutputNode(OpDescPtr &net_output_desc, const ge::ComputeGraphPtr &graph); /// /// Check if the network output node is legal /// @param [in] graph: Input ComputeGraph /// @param [in] outputs: Output node information of graph /// @return SUCCESS: Execution succeed /// @return OTHERS: Execution failed /// @author /// Status CheckOutputNodeInfo(const ComputeGraphPtr &graph, const std::vector &outputs); /// /// Set input and output for the NetOutput node /// @param [in] graph: Input ComputeGraph /// @param [in] net_output_desc: OpDesc of the NetOutput node /// @param [in] output_nodes_info: RetvalInfos of the NetOutput /// @return void /// @author /// void AddInOutForNetOutputOp(const ComputeGraphPtr &graph, OpDescPtr &net_output_desc, vector &output_nodes_info); /// /// Delete unwanted _Retval/Save/Summary nodes /// @param [in] graph: Input ComputeGraph /// @return SUCCESS: Execution succeed /// @return OTHERS: Execution failed /// @author /// Status RemoveUnusedNode(const ge::ComputeGraphPtr &graph); /// /// Update the output/input tensor description of the NetOutput node /// @param [in] net_output: The netOutput node /// @return SUCCESS: Execution succeed /// @return OTHERS: Execution failed /// @author /// Status UpdateNetOutputDesc(const ge::NodePtr &net_output); /// /// Add ctrl edge from target node to netoutput node /// @param [in] net_output: The netOutput node /// @return SUCCESS: Execution succeed /// @return OTHERS: Execution failed /// @author /// Status AddCtrlEdgeForTargets(const ge::NodePtr &net_out_node); /// /// Remove invalid node and duplicated node of user set targets /// @param [in] : compute graph /// @return SUCCESS: Execution succeed /// @return OTHERS: Execution failed /// @author /// void SaveAndRemoveTargets(const ge::ComputeGraphPtr &graph); /// /// Add edges for the NetOutput node /// @param [in] graph: Input ComputeGraph /// @param [in] net_out_node: The netOutput node /// @param [in] output_nodes_info: Output node pair /// @return SUCCESS: Execution succeed /// @return OTHERS: Execution failed /// @author /// Status AddEdgesForNetOutput(const ge::ComputeGraphPtr &graph, const ge::NodePtr &net_out_node, const std::vector &output_nodes_info); /// /// Add ctrl edges for leaf node /// @param [in] graph: Input ComputeGraph /// @param [in] net_out_node: The netOutput node /// @return SUCCESS: Execution succeed /// @return OTHERS: Execution failed /// @author /// Status AddCtrlEdgesBetweenLeafAndNetOutput(const ge::ComputeGraphPtr &graph, const ge::NodePtr &net_out_node); /// /// Unlink all connections between target nodes and netoutput node /// @param [in] graph: ComputeGraph /// @param [in] net_out_node: The netOutput node /// @return SUCCESS: Execution succeed /// @return OTHERS: Execution failed /// @author /// Status UnLink(const ge::ComputeGraphPtr &graph, const ge::NodePtr &net_out_node); /// /// Unlink data connections between target nodes and netoutput node /// @param [in] graph: ComputeGraph /// @param [in] net_out_node: The netOutput node /// @return SUCCESS: Execution succeed /// @return OTHERS: Execution failed /// @author /// Status UnLinkDataAnchorOfNetoutput(const ge::ComputeGraphPtr &graph, const ge::NodePtr &net_out_node); /// /// Unlink control connections between target nodes and netoutput node /// @param [in] graph: ComputeGraph /// @param [in] net_out_node: The netOutput node /// @return SUCCESS: Execution succeed /// @return OTHERS: Execution failed /// @author /// Status UnLinkControlAnchorOfNetoutput(const ge::ComputeGraphPtr &graph, const ge::NodePtr &net_out_node); /// /// if user have set netoutput node , do relative process /// @param [in] graph: ComputeGraph /// @param [in] net_out_node: The netOutput node /// @return SUCCESS: Execution succeed /// @return OTHERS: Execution failed /// @author /// Status ProcessWithNetoutput(const ge::ComputeGraphPtr &graph, const ge::NodePtr &output_node); /// /// check node wether exist in user-set output nodes /// @param [in] graph: ComputeGraph /// @param [in] net_out_node: The netOutput node /// @return SUCCESS: Execution succeed /// @return OTHERS: Execution failed /// @author /// bool CheckNodeIsInOutputNodes(const ge::ComputeGraphPtr &graph, const ge::NodePtr &node); /// /// Add netoutput node to graph with output node infos /// @param [in] graph: ComputeGraph /// @param [in] output_node: shared_ptr to netoutput node /// @return SUCCESS: Execution succeed /// @return OTHERS: Execution failed /// @author /// Status AddNetOutputNodeToGraph(const ge::ComputeGraphPtr &graph, NodePtr &output_node); /// /// Add user_def_dtype & format for netoutput node /// @param [in] output_node: The netOutput node /// @return SUCCESS: Execution succeed /// @return OTHERS: Execution failed /// @author /// Status SetUserDefDTypeAndFormatFromAtcParams(const ge::NodePtr &output_node); bool is_include_special_node_ = false; std::set targets_; friend class ReUpdateNetOutputPass; bool is_user_define_ouput_nodes = false; }; } // namespace ge #endif // GE_GRAPH_PASSES_NET_OUTPUT_PASS_H_