You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
243 lines
9.5 KiB
243 lines
9.5 KiB
/**
|
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#ifndef GE_GRAPH_PASSES_SWITCH_TO_STREAM_SWITCH_PASS_H_
|
|
#define GE_GRAPH_PASSES_SWITCH_TO_STREAM_SWITCH_PASS_H_
|
|
|
|
#include "inc/graph_pass.h"
|
|
|
|
namespace ge {
|
|
/* Variable Initialize Flow, take as FrameworkOp
|
|
+-----------+
|
|
| Merge |
|
|
+-----------+
|
|
/ \
|
|
0/ \x
|
|
/ \
|
|
+-----------+ +-----------+
|
|
| Switch | | Switch |
|
|
+-----------+ +-----------+
|
|
| |F T| |
|
|
0| | | x|
|
|
| | | |
|
|
| +-----------------------+ |
|
|
| | IsVariableInitialized | |
|
|
| +-----------------------+ |
|
|
| | |
|
|
| | |
|
|
| | |
|
|
+-----------+ +-----------+
|
|
| Const | | VariableV2|
|
|
+-----------+ +-----------+
|
|
*/
|
|
|
|
/* Switch branch op optimize, Switches in same case merge to one StreamSwitch, update following nodes' input
|
|
|
|
+-----------+
|
|
/ | task2 | \
|
|
T/ +-----------+ \
|
|
+-----------+ +-----------+ / \ +-----------+ +-----------+
|
|
| task1 | --> | Switch | | task4 | --> | noop |
|
|
+-----------+ +-----------+ \ / +-----------+ +-----------+
|
|
F\ +-----------+ /
|
|
\ | task3 | /
|
|
+-----------+
|
|
|
|
cond(x < y, lambda: add(x, z), lambda: square(y))
|
|
|
|
+-----------+ +-----------+
|
|
| Merge | +------------|StreamMerge|----------+
|
|
+-----------+ | +-----------+ |
|
|
/ \ | | |
|
|
/ \ |c | |c
|
|
/ \ +----------+ ----------- +----------+
|
|
+-----------+ +-----------+ | Active_f | / \ | Active_t |
|
|
| Square | | Add | +----------+ / \ +----------+
|
|
+-----------+ +-----------+ \ / \ /
|
|
/ / \ \c / \ /c
|
|
y/ x/ \z +-----------+ +-----------+
|
|
/ / \ | Square | | Add |
|
|
+-----------+ +-----------+ +-----------+ +-----------+ +-----------+
|
|
| Switch | | Switch | | Switch | ====> / | / | \
|
|
+-----------+ +-----------+ +-----------+ / | / | \
|
|
y| |F T| |x T| |z +--------+ | +--------+ | +--------+
|
|
| | | | | | | y/read | | | x/read | | | z/read |
|
|
| +-----------+ | | | +--------+ | +--------+ | +--------+
|
|
| | Less |-------------------+ | |c |c
|
|
| +-----------+ | | +----------------+ +----------------+
|
|
| | | | StreamSwitch_f | | StreamSwitch_t |
|
|
| | | +----------------+ +----------------+
|
|
+-----------+ +-----------+ +-----------+ | |
|
|
| y/read | | x/read | | z/read | | +-----------+ |
|
|
+-----------+ +-----------+ +-----------+ +-----| Less |----+
|
|
+-----------+
|
|
*/
|
|
class SwitchToStreamSwitchPass : public GraphPass {
|
|
public:
|
|
Status Run(ComputeGraphPtr graph);
|
|
|
|
///
|
|
/// @brief Clear Status, used for subgraph pass
|
|
/// @return
|
|
///
|
|
Status ClearStatus() override;
|
|
|
|
private:
|
|
///
|
|
/// @brief Check cyclic dependence
|
|
/// @param [in] graph
|
|
/// @return Status
|
|
///
|
|
Status CheckCycleDependence(const ComputeGraphPtr &graph);
|
|
|
|
///
|
|
/// @brief Mark cyclic dependence
|
|
/// @param [in] graph
|
|
/// @param [in] cond_switch_map
|
|
/// @return void
|
|
///
|
|
void MarkCycleDependence(const std::unordered_map<NodePtr, std::vector<NodePtr>> &cond_switch_map);
|
|
|
|
///
|
|
/// @brief Replace Switch Op
|
|
/// @param [in] graph
|
|
/// @param [in] switch_node
|
|
/// @return Status
|
|
///
|
|
Status ReplaceSwitchNode(const ComputeGraphPtr &graph, const NodePtr &switch_node);
|
|
|
|
///
|
|
/// @brief Bypass Switch Node
|
|
/// @param [in] switch_node
|
|
/// @param [out] peer_data_anchor
|
|
/// @param [out] peer_cond_anchor
|
|
/// @return Status
|
|
///
|
|
Status BypassSwitchNode(const NodePtr &switch_node, OutDataAnchorPtr &peer_data_anchor,
|
|
OutDataAnchorPtr &peer_cond_anchor);
|
|
|
|
///
|
|
/// @brief Find Switch cond input
|
|
/// @param [in] pass_switch_flag
|
|
/// @param [out] peer_cond_anchor
|
|
/// @return Status
|
|
///
|
|
Status FindSwitchCondInput(bool pass_switch_flag, OutDataAnchorPtr &peer_cond_anchor);
|
|
|
|
///
|
|
/// @brief Create StreamSwitch Node
|
|
/// @param [in] graph
|
|
/// @param [in] switch_node
|
|
/// @param [in] suffix
|
|
/// @param [in] peer_cond_anchor
|
|
/// @return ge::NodePtr
|
|
///
|
|
NodePtr CreateStreamSwitchNode(const ComputeGraphPtr &graph, const NodePtr &switch_node, const std::string &suffix,
|
|
const OutDataAnchorPtr &peer_cond_anchor);
|
|
|
|
///
|
|
/// @brief Mark Switch Branch
|
|
/// @param [in] peer_cond_anchor
|
|
/// @param [in] stream_switch
|
|
/// @param [in] true_branch_flag
|
|
/// @return Status
|
|
///
|
|
Status MarkBranches(const OutDataAnchorPtr &peer_cond_anchor, const NodePtr &stream_switch_node,
|
|
bool true_branch_flag);
|
|
|
|
///
|
|
/// @brief Get group_id for switch_node
|
|
/// @param [in] node
|
|
/// @return group_id
|
|
///
|
|
int64_t GetGroupId(const NodePtr &node);
|
|
|
|
///
|
|
/// @brief Combine switch nodes link to same cond
|
|
/// @param [in] graph
|
|
/// @return Status
|
|
///
|
|
Status CombineSwitchNode(const ComputeGraphPtr &graph);
|
|
|
|
///
|
|
/// @brief Create cast node
|
|
/// @param [in] graph
|
|
/// @param [in] peer_cond_anchor
|
|
/// @return NodePtr
|
|
///
|
|
NodePtr CreateCastOp(const ComputeGraphPtr &graph, const OutDataAnchorPtr &peer_cond_anchor);
|
|
|
|
///
|
|
/// @brief Create Active Op
|
|
/// @param [in] graph
|
|
/// @param [in] cond_node
|
|
/// @return ge::NodePtr
|
|
///
|
|
NodePtr CreateActiveNode(const ComputeGraphPtr &graph, const NodePtr &node);
|
|
|
|
///
|
|
/// @brief Add const node as switch input1
|
|
/// @param [in] graph
|
|
/// @param [in] stream_switch
|
|
/// @return Status
|
|
///
|
|
Status AddConstNode(const ComputeGraphPtr &graph, const NodePtr &stream_switch_node);
|
|
|
|
///
|
|
/// @brief Modify in ctl edge for switch_node
|
|
/// @param [in] switch_node
|
|
/// @param [in] cast_node
|
|
/// @param [in] same_cond_switch
|
|
/// @return Status
|
|
///
|
|
Status ModifySwitchInCtlEdges(const NodePtr &switch_node, const NodePtr &cast_node,
|
|
const std::set<NodePtr> &same_cond_switch);
|
|
|
|
///
|
|
/// @brief Modify out ctl edge for switch_node
|
|
/// @param [in] switch_node
|
|
/// @param [in] stream_switch
|
|
/// @param [in] active_node
|
|
/// @return Status
|
|
///
|
|
Status ModifySwitchOutCtlEdges(const NodePtr &switch_node, const NodePtr &stream_switch, const NodePtr &active_node);
|
|
|
|
///
|
|
/// @brief Check duplicate node_name
|
|
/// @param [in] node_name
|
|
/// @return std::string
|
|
///
|
|
std::string CheckDuplicateName(const std::string &node_name);
|
|
|
|
///
|
|
/// @brief Move Control Edges
|
|
/// @param [in] old_node
|
|
/// @param [in] new_node
|
|
/// @return void
|
|
///
|
|
void MoveCtrlEdges(const NodePtr &old_node, const NodePtr &new_node);
|
|
|
|
std::vector<NodePtr> switch_nodes_;
|
|
std::unordered_map<NodePtr, std::set<std::string>> switch_cyclic_map_;
|
|
std::set<NodePtr> bypass_nodes_;
|
|
std::vector<NodePtr> stream_switch_nodes_;
|
|
std::unordered_map<OutDataAnchorPtr, std::map<int64_t, std::vector<std::list<NodePtr>>>> cond_node_map_;
|
|
std::unordered_map<NodePtr, std::set<std::string>> switch_node_map_;
|
|
std::unordered_map<std::string, uint32_t> node_num_map_;
|
|
};
|
|
} // namespace ge
|
|
#endif // GE_GRAPH_PASSES_SWITCH_TO_STREAM_SWITCH_PASS_H_
|