/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef GE_GRAPH_PASSES_FOR_PASS_H #define GE_GRAPH_PASSES_FOR_PASS_H #include "graph/passes/base_pass.h" struct ForInfo { ForInfo() : for_node(nullptr), start(nullptr), limit(nullptr), delta(nullptr), for_body(nullptr) {} ge::NodePtr for_node; ge::OutDataAnchorPtr start; ge::OutDataAnchorPtr limit; ge::OutDataAnchorPtr delta; std::string body_name; ge::ComputeGraphPtr for_body; std::vector data_inputs; std::vector> data_outputs; std::vector ctrl_inputs; std::vector ctrl_outputs; }; struct WhileInfo { WhileInfo() : while_node(nullptr), sub_graph_node(nullptr), i(nullptr), abs_delta(nullptr), range(nullptr), start(nullptr), delta(nullptr), for_body(nullptr), while_cond(nullptr), while_body(nullptr) {} ge::NodePtr while_node; ge::NodePtr sub_graph_node; ge::OutDataAnchorPtr i; ge::OutDataAnchorPtr abs_delta; ge::OutDataAnchorPtr range; ge::OutDataAnchorPtr start; ge::OutDataAnchorPtr delta; std::string for_body_name; ge::ComputeGraphPtr for_body; ge::ComputeGraphPtr while_cond; ge::ComputeGraphPtr while_body; std::vector data_inputs; std::vector> data_outputs; std::vector ctrl_inputs; std::vector ctrl_outputs; }; namespace ge { class ForPass : public BaseNodePass { public: Status Run(NodePtr &node) override; private: /// /// @brief Build for_info /// @param [in] root_graph /// @param [in] node /// @param [out] for_info /// @return Status /// static Status BuildForInfo(const ComputeGraphPtr &root_graph, const NodePtr &node, ForInfo &for_info); /// /// @brief Transfer while_info from for_info /// @param [in] graph /// @param [in] for_info /// @param [out] while_info /// @return Status /// Status TranWhileInfo(const ComputeGraphPtr &graph, const ForInfo &for_info, WhileInfo &while_info); /// /// @brief Build cond_graph for while_node /// @param [in&out] while_info /// @return ComputeGraphPtr /// static ComputeGraphPtr BuildCondGraph(WhileInfo &while_info); /// /// @brief Build body_graph for while_node /// @param [in&out] while_info /// @return ComputeGraphPtr /// static ComputeGraphPtr BuildBodyGraph(WhileInfo &while_info); /// /// @brief Update InputMapping for for-body-graph /// @param [in] while_info /// @return Status /// static Status UpdateForBodyInputMapping(const WhileInfo &while_info); /// /// @brief Find input with index for For node /// @param [in] node /// @param [in] index /// @return OutDataAnchorPtr /// static OutDataAnchorPtr FindInputWithIndex(const NodePtr &node, uint32_t index); /// /// @brief Find inputs / outputs for for node /// @param [in] node /// @param [out] data_inputs /// @param [out] data_outputs /// @param [out] ctrl_inputs /// @param [out] ctrl_outputs /// @return Status /// static Status FindInputsAndOutputs(const NodePtr &node, std::vector &data_inputs, std::vector> &data_outputs, std::vector &ctrl_inputs, std::vector &ctrl_outputs); /// /// @brief Create const op_desc /// @param [in] name /// @param [in] value /// @return OpDescPtr /// static OpDescPtr CreateConstDesc(const std::string &name, int32_t value); /// /// @brief Create loop input /// @param [in] graph /// @param [in] for_info /// @param [out] range_input /// @param [out] abs_delta_input /// @return Status /// Status CreateLoopInput(const ComputeGraphPtr &graph, const ForInfo &for_info, OutDataAnchorPtr &range_input, OutDataAnchorPtr &abs_delta_input); /// /// @brief Create op_desc /// @param [in] name /// @param [in] type /// @param [in] io_equal_flag /// @return OpDescPtr /// static OpDescPtr CreateOpDesc(const std::string &name, const std::string &type, bool io_equal_flag); /// /// @brief Build while-info /// @param [in] for_info /// @param [in] i_input /// @param [in] range_input /// @param [in] abs_delta_input /// @param [out] while_info /// @return void /// static void BuildWhileInfo(const ForInfo &for_info, const OutDataAnchorPtr &i_input, const OutDataAnchorPtr &range_input, const OutDataAnchorPtr &abs_delta_input, WhileInfo &while_info); /// /// @brief Insert while_node /// @param [in] graph /// @param [in] name /// @param [in] while_info /// @return Status /// Status InsertWhileNode(const ComputeGraphPtr &graph, const std::string &name, WhileInfo &while_info); /// /// @brief Build while link-edge /// @param [in] while_info /// @return Status /// static Status BuildWhileLink(const WhileInfo &while_info); /// /// @brief Create op_desc for subgraph node /// @param [in] name /// @param [in] input_num /// @param [in] output_num /// @return OpDescPtr /// static OpDescPtr CreateSubgraphOpDesc(const std::string &name, uint32_t input_num, uint32_t output_num); }; } // namespace ge #endif //GE_GRAPH_PASSES_FOR_PASS_H