/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef GE_GRAPH_PASSES_SWITCH_TO_STREAM_SWITCH_PASS_H_ #define GE_GRAPH_PASSES_SWITCH_TO_STREAM_SWITCH_PASS_H_ #include "inc/graph_pass.h" namespace ge { /* Variable Initialize Flow, take as FrameworkOp +-----------+ | Merge | +-----------+ / \ 0/ \x / \ +-----------+ +-----------+ | Switch | | Switch | +-----------+ +-----------+ | |F T| | 0| | | x| | | | | | +-----------------------+ | | | IsVariableInitialized | | | +-----------------------+ | | | | | | | | | | +-----------+ +-----------+ | Const | | VariableV2| +-----------+ +-----------+ Switch branch op optimize, Switches in same case merge to one StreamSwitch, update following nodes' input +-----------+ / | task2 | \ T/ +-----------+ \ +-----------+ +-----------+ / \ +-----------+ +-----------+ | task1 | --> | Switch | | task4 | --> | noop | +-----------+ +-----------+ \ / +-----------+ +-----------+ F\ +-----------+ / \ | task3 | / +-----------+ cond(x < y, lambda: add(x, z), lambda: square(y)) +-----------+ +-----------+ | Merge | +------------|StreamMerge|----------+ +-----------+ | +-----------+ | / \ | | | / \ |c | |c / \ +----------+ ----------- +----------+ +-----------+ +-----------+ | Active_f | / \ | Active_t | | Square | | Add | +----------+ / \ +----------+ +-----------+ +-----------+ \ / \ / / / \ \c / \ /c y/ x/ \z +-----------+ +-----------+ / / \ | Square | | Add | +-----------+ +-----------+ +-----------+ +-----------+ +-----------+ | Switch | | Switch | | Switch | ====> / | / | \ +-----------+ +-----------+ +-----------+ / | / | \ y| |F T| |x T| |z +--------+ | +--------+ | +--------+ | | | | | | | y/read | | | x/read | | | z/read | | +-----------+ | | | +--------+ | +--------+ | +--------+ | | Less |-------------------+ | |c |c | +-----------+ | | +----------------+ +----------------+ | | | | StreamSwitch_f | | StreamSwitch_t | | | | +----------------+ +----------------+ +-----------+ +-----------+ +-----------+ | | | y/read | | x/read | | z/read | | +-----------+ | +-----------+ +-----------+ +-----------+ +-----| Less |----+ +-----------+ */ class SwitchToStreamSwitchPass : public GraphPass { public: Status Run(ComputeGraphPtr graph); /// /// @brief Clear Status, used for subgraph pass /// @return /// Status ClearStatus() override; private: /// /// @brief Check cyclic dependence /// @param [in] graph /// @return Status /// Status CheckCycleDependence(const ComputeGraphPtr &graph); /// /// @brief Mark cyclic dependence /// @param [in] graph /// @param [in] cond_switch_map /// @return void /// void MarkCycleDependence(const std::unordered_map> &cond_switch_map); /// /// @brief Replace Switch Op /// @param [in] graph /// @param [in] switch_node /// @return Status /// Status ReplaceSwitchNode(const ComputeGraphPtr &graph, const NodePtr &switch_node); /// /// @brief Bypass Switch Node /// @param [in] switch_node /// @param [out] peer_data_anchor /// @param [out] peer_cond_anchor /// @return Status /// Status BypassSwitchNode(const NodePtr &switch_node, OutDataAnchorPtr &peer_data_anchor, OutDataAnchorPtr &peer_cond_anchor); /// /// @brief Find Switch cond input /// @param [out] peer_cond_anchor /// @return Status /// Status FindSwitchCondInput(OutDataAnchorPtr &peer_cond_anchor); /// /// @brief Create StreamSwitch Node /// @param [in] graph /// @param [in] switch_node /// @param [in] suffix /// @param [in] peer_cond_anchor /// @return ge::NodePtr /// NodePtr CreateStreamSwitchNode(const ComputeGraphPtr &graph, const NodePtr &switch_node, const std::string &suffix, const OutDataAnchorPtr &peer_cond_anchor); /// /// @brief Mark Switch Branch /// @param [in] peer_cond_anchor /// @param [in] stream_switch /// @param [in] true_branch_flag /// @return Status /// Status MarkBranches(const OutDataAnchorPtr &peer_cond_anchor, const NodePtr &stream_switch_node, bool true_branch_flag); /// /// @brief Get group_id for switch_node /// @param [in] node /// @return group_id /// int64_t GetGroupId(const NodePtr &node); /// /// @brief Combine switch nodes link to same cond /// @param [in] graph /// @return Status /// Status CombineSwitchNode(const ComputeGraphPtr &graph); /// /// @brief Create cast node /// @param [in] graph /// @param [in] peer_cond_anchor /// @return NodePtr /// NodePtr CreateCastOp(const ComputeGraphPtr &graph, const OutDataAnchorPtr &peer_cond_anchor); /// /// @brief Create Active Op /// @param [in] graph /// @param [in] cond_node /// @return ge::NodePtr /// NodePtr CreateActiveNode(const ComputeGraphPtr &graph, const NodePtr &node); /// /// @brief Add const node as switch input1 /// @param [in] graph /// @param [in] stream_switch /// @return Status /// Status AddConstNode(const ComputeGraphPtr &graph, const NodePtr &stream_switch_node); /// /// @brief Modify in ctl edge for switch_node /// @param [in] switch_node /// @param [in] cast_node /// @param [in] same_cond_switch /// @return Status /// Status ModifySwitchInCtlEdges(const NodePtr &switch_node, const NodePtr &cast_node, const std::set &same_cond_switch); /// /// @brief Modify out ctl edge for switch_node /// @param [in] switch_node /// @param [in] stream_switch /// @param [in] active_node /// @return Status /// Status ModifySwitchOutCtlEdges(const NodePtr &switch_node, const NodePtr &stream_switch, const NodePtr &active_node); /// /// @brief Check duplicate node_name /// @param [in] node_name /// @return std::string /// std::string CheckDuplicateName(const std::string &node_name); /// /// @brief Move Control Edges /// @param [in] old_node /// @param [in] new_node /// @return void /// void MoveCtrlEdges(const NodePtr &old_node, const NodePtr &new_node); std::vector switch_nodes_; std::unordered_map> switch_cyclic_map_; std::set bypass_nodes_; std::vector stream_switch_nodes_; std::unordered_map>>> cond_node_map_; std::unordered_map> switch_node_map_; std::map node_num_map_; }; } // namespace ge #endif // GE_GRAPH_PASSES_SWITCH_TO_STREAM_SWITCH_PASS_H_