/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "graph/passes/control_trigger_pass.h"

#include <set>
#include <stack>
#include <string>
#include <utility>

#include "common/ge/ge_util.h"
#include "graph/common/omg_util.h"
#include "graph/utils/type_utils.h"

namespace ge {
///
/// @brief Rewire the input control edges of every ControlTrigger node directly owned by graph
/// @param [in] graph
/// @return Status: SUCCESS, PARAM_INVALID when graph is null, FAILED when an edge cannot be handled
///
Status ControlTriggerPass::Run(ComputeGraphPtr graph) {
  GE_CHECK_NOTNULL(graph);
  GELOGD("ControlTriggerPass Enter");

  for (NodePtr &node : graph->GetDirectNode()) {
    if (node->GetType() != CONTROLTRIGGER) {
      continue;
    }
    // Held in a local (by value, not by reference) because HandleDynamicCtrlEdges removes
    // in_ctrl_node -> node control edges while this loop is running.
    auto in_ctrl_nodes = node->GetInControlNodes();
    for (NodePtr &in_ctrl_node : in_ctrl_nodes) {
      if (HandleDynamicCtrlEdges(graph, node, in_ctrl_node) != SUCCESS) {
        GELOGE(FAILED, "HandleDynamicCtrlEdges for %s->%s fail.", in_ctrl_node->GetName().c_str(),
               node->GetName().c_str());
        return FAILED;
      }
    }
  }

  GELOGD("ControlTriggerPass Leave");
  return SUCCESS;
}

///
/// @brief Handle input ctrl edges for ControlTrigger node.
///        If in_ctrl_node is (transitively) fed by a conditional Switch, the edge in_ctrl_node -> node is
///        replaced: either by an edge to an already-inserted Const (when a branch structure for the same
///        Switch predicate was built for this node before), or via InsertOppositeBranch.
/// @param [in] graph
/// @param [in] node ControlTrigger node
/// @param [in] in_ctrl_node source of one of node's input control edges
/// @return Status
///
Status ControlTriggerPass::HandleDynamicCtrlEdges(ComputeGraphPtr &graph, NodePtr &node, NodePtr &in_ctrl_node) {
  GE_CHECK_NOTNULL(node);
  GE_CHECK_NOTNULL(in_ctrl_node);
  GELOGI("HandleDynamicCtrlEdges: node=%s, in_ctrl_node=%s", node->GetName().c_str(),
         in_ctrl_node->GetName().c_str());
  NodePtr switch_node = nullptr;
  bool branch_flag = false;
  if (FindSwitchNode(in_ctrl_node, switch_node, branch_flag) != SUCCESS) {
    GELOGE(FAILED, "FindSwitchNode fail.");
    return FAILED;
  }

  if (switch_node == nullptr) {
    GELOGI("Not find valid switch node.");
    return SUCCESS;
  }

  auto iter1 = control_trigger_map_.find(node);
  if (iter1 != control_trigger_map_.end()) {
    // switch_cond_map_ already holds switch_node: FindSwitchNode only reports a Switch after
    // TransferNodeType -> FindPredInput recorded its predicate.
    auto iter2 = iter1->second.find(switch_cond_map_[switch_node]);
    if (iter2 != iter1->second.end()) {
      // The stored pair is (const_f, const_t); pick the Const matching this branch.
      NodePtr constant = (branch_flag ? iter2->second.second : iter2->second.first);
      if ((GraphUtils::RemoveEdge(in_ctrl_node->GetOutControlAnchor(), node->GetInControlAnchor()) != GRAPH_SUCCESS) ||
          (GraphUtils::AddEdge(in_ctrl_node->GetOutControlAnchor(), constant->GetInControlAnchor()) != GRAPH_SUCCESS)) {
        GELOGE(FAILED, "Replace ctrl edge fail, %s->%s, %s->%s.", in_ctrl_node->GetName().c_str(),
               node->GetName().c_str(), in_ctrl_node->GetName().c_str(), constant->GetName().c_str());
        return FAILED;
      }

      GELOGI("No need to insert new branch.");
      return SUCCESS;
    }
  }

  if (InsertOppositeBranch(graph, node, in_ctrl_node, switch_node, branch_flag) != SUCCESS) {
    GELOGE(FAILED, "InsertOppositeBranch fail.");
    return FAILED;
  }

  return SUCCESS;
}

///
/// @brief Find switch_node for ControlTrigger node.
///        Walks backwards from node over data and control inputs (depth-first, explicit stack).
///        num_pair.first counts conditional Merge nodes passed minus conditional Switch nodes passed;
///        the first conditional Switch reached while it is 0 is reported. Reaching the true output of
///        a loop Switch stops the search with SUCCESS and no switch_node.
/// @param [in] node
/// @param [out] switch_node left untouched (caller passes nullptr) when no matching Switch is found
/// @param [out] branch_flag true when switch_node was reached through SWITCH_TRUE_OUTPUT
/// @return Status
///
Status ControlTriggerPass::FindSwitchNode(const NodePtr &node, NodePtr &switch_node, bool &branch_flag) {
  // {node, out index} pairs already expanded
  std::set<std::pair<NodePtr, uint32_t>> handle_nodes;
  // {node, <out index, <cond-merge counter, loop-switch-false counter>>}
  std::stack<std::pair<NodePtr, std::pair<uint32_t, std::pair<uint32_t, uint32_t>>>> nodes;
  // UINT32_MAX marks "reached through a control edge" (same convention as GetInNodes).
  nodes.push(std::make_pair(node, std::make_pair(UINT32_MAX, std::make_pair(0U, 0U))));
  std::set<std::pair<NodePtr, uint32_t>> in_nodes;

  while (!nodes.empty()) {
    auto iter = nodes.top();
    NodePtr tmp_node = iter.first;
    GE_CHECK_NOTNULL(tmp_node);
    nodes.pop();
    uint32_t index = iter.second.first;
    auto num_pair = iter.second.second;
    if (handle_nodes.count(std::make_pair(tmp_node, index)) > 0) {
      continue;
    }
    switch (TransferNodeType(tmp_node, index)) {
      case kCondSwitch:
        if (num_pair.first == 0) {
          switch_node = tmp_node;
          branch_flag = (index == SWITCH_TRUE_OUTPUT);
          GELOGI("FindSwitchNode succ, switch_node=%s, idx=%u", switch_node->GetName().c_str(), index);
          return SUCCESS;
        }
        num_pair.first--;
        break;
      case kCondMerge:
        num_pair.first++;
        break;
      case kLoopSwitchT:
        GELOGI("in while_body, no need handle");
        return SUCCESS;
      case kLoopSwitchF:
        // NOTE(review): num_pair.second is maintained here and at kEnter but never consulted
        // for a decision in this function -- confirm whether that is intended.
        num_pair.second++;
        break;
      case kEnter:
        if (num_pair.second > 0) {
          num_pair.second--;
        }
        break;
      case kNotControlOp:
        break;
      default:
        GELOGE(FAILED, "invalid type");
        return FAILED;
    }

    GetInNodes(tmp_node, in_nodes);
    for (auto &node_idx : in_nodes) {
      nodes.push(std::make_pair(node_idx.first, std::make_pair(node_idx.second, num_pair)));
    }

    (void)handle_nodes.insert(std::make_pair(tmp_node, index));
  }

  return SUCCESS;
}

///
/// @brief Classify node, as reached through its output index, for the FindSwitchNode walk
/// @param [in] node
/// @param [in] index output index of node through which it was reached (UINT32_MAX for a control edge)
/// @return ControlNodeType: kCondSwitch / kLoopSwitchT / kLoopSwitchF for Switch true/false outputs
///         (loop variants when the predicate is a LoopCond), kCondMerge for a Merge without
///         ATTR_NAME_NEXT_ITERATION, kEnter for Enter, kInvalidType on error, kNotControlOp otherwise
///
ControlNodeType ControlTriggerPass::TransferNodeType(const NodePtr &node, uint32_t index) {
  const std::string type = node->GetType();
  if ((type == SWITCH) || (type == REFSWITCH)) {
    if ((index != SWITCH_TRUE_OUTPUT) && (index != SWITCH_FALSE_OUTPUT)) {
      GELOGI("TransferNodeType: neither true nor false branch.");
      return kNotControlOp;
    }

    if (FindPredInput(node) != SUCCESS) {
      GELOGE(INTERNAL_ERROR, "FindPredInput fail, switch_node: %s.", node->GetName().c_str());
      return kInvalidType;
    }

    // Set (and checked non-null) by FindPredInput just above.
    NodePtr pred_node = switch_cond_map_[node];
    bool branch_flag = (index == SWITCH_TRUE_OUTPUT);
    if (pred_node->GetType() != LOOPCOND) {
      GELOGI("TransferNodeType: kCondSwitch node=%s, idx=%u", node->GetName().c_str(), index);
      return kCondSwitch;
    } else {
      GELOGI("TransferNodeType: kLoopSwitch node=%s, idx=%u", node->GetName().c_str(), index);
      return branch_flag ? kLoopSwitchT : kLoopSwitchF;
    }
  } else if ((type == MERGE) || (type == REFMERGE)) {
    OpDescPtr merge_desc = node->GetOpDesc();
    if (merge_desc == nullptr) {
      GELOGE(INTERNAL_ERROR, "TransferNodeType fail, merge_desc is null, merge_node: %s.", node->GetName().c_str());
      return kInvalidType;
    }
    // Merges tagged with ATTR_NAME_NEXT_ITERATION fall through to kNotControlOp.
    if (!merge_desc->HasAttr(ATTR_NAME_NEXT_ITERATION)) {
      return kCondMerge;
    }
  } else if ((type == ENTER) || (type == REFENTER)) {
    return kEnter;
  }

  return kNotControlOp;
}

///
/// @brief Get in_node & idx pairs
/// @param [in] node
/// @param [out] in_nodes cleared, then filled with {control-input node, UINT32_MAX} and
///              {data-input owner node, peer output index} pairs
/// @return void
///
void ControlTriggerPass::GetInNodes(const NodePtr &node, std::set<std::pair<NodePtr, uint32_t>> &in_nodes) {
  in_nodes.clear();
  for (auto &in_ctrl_node : node->GetInControlNodes()) {
    (void)in_nodes.insert(std::make_pair(in_ctrl_node, UINT32_MAX));
  }

  for (InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) {
    if (in_data_anchor == nullptr) {
      continue;
    }
    OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
    if (peer_out_anchor == nullptr) {
      continue;
    }
    (void)in_nodes.insert(
        std::make_pair(peer_out_anchor->GetOwnerNode(), static_cast<uint32_t>(peer_out_anchor->GetIdx())));
  }
  return;
}

///
/// @brief Insert opposite branch for ControlTrigger.
///        Builds: Merge -ctrl-> node (replacing in_ctrl_node -ctrl-> node), Const_f/Const_t -> Merge,
///        in_ctrl_node -ctrl-> Const of its own branch, and switch_node(opposite output) -> Identity
///        -ctrl-> Const of the opposite branch. Presumably this lets node fire whichever way the Switch
///        goes -- TODO confirm against the pass's design notes.
/// @param [in] graph
/// @param [in] node ControlTrigger node
/// @param [in] in_ctrl_node
/// @param [in] switch_node
/// @param [in] branch_flag true when in_ctrl_node sits on switch_node's true output
/// @return Status
///
Status ControlTriggerPass::InsertOppositeBranch(ComputeGraphPtr &graph, NodePtr &node, NodePtr &in_ctrl_node,
                                                NodePtr &switch_node, bool branch_flag) {
  GE_CHECK_NOTNULL(node);
  GE_CHECK_NOTNULL(in_ctrl_node);
  GE_CHECK_NOTNULL(switch_node);
  OpDescPtr switch_desc = switch_node->GetOpDesc();
  GE_CHECK_NOTNULL(switch_desc);

  GeTensorDesc data_desc(GeShape(), FORMAT_NCHW, DT_INT32);

  NodePtr merge_node = InsertMergeNode(graph, node, in_ctrl_node, data_desc);
  if (merge_node == nullptr) {
    GELOGE(FAILED, "InsertMergeNode fail.");
    return FAILED;
  }

  NodePtr const_f = InsertConstNode(graph, merge_node, data_desc, false);
  NodePtr const_t = InsertConstNode(graph, merge_node, data_desc, true);
  if ((const_f == nullptr) || (const_t == nullptr)) {
    GELOGE(FAILED, "InsertConstNode fail.");
    return FAILED;
  }

  NodePtr orig_const = branch_flag ? const_t : const_f;
  NodePtr new_const = !branch_flag ? const_t : const_f;
  uint32_t new_idx = branch_flag ? SWITCH_FALSE_OUTPUT : SWITCH_TRUE_OUTPUT;

  // NOTE(review): this name depends only on the Switch; if the same Switch is handled for several
  // ControlTrigger nodes, several Identity nodes get the same name -- confirm that is acceptable.
  const std::string identity_name = switch_desc->GetName() + "_" + IDENTITY;
  NodePtr identity_node = InsertIdentityNode(graph, identity_name, switch_desc->GetOutputDesc(new_idx));
  if (identity_node == nullptr) {
    GELOGE(FAILED, "InsertIdentityNode fail.");
    return FAILED;
  }

  if (GraphUtils::AddEdge(in_ctrl_node->GetOutControlAnchor(), orig_const->GetInControlAnchor()) != GRAPH_SUCCESS) {
    GELOGE(FAILED, "Add in ctrl edge fail, %s->%s.", in_ctrl_node->GetName().c_str(), orig_const->GetName().c_str());
    return FAILED;
  }
  if (GraphUtils::AddEdge(switch_node->GetOutDataAnchor(new_idx), identity_node->GetInDataAnchor(0)) !=
      GRAPH_SUCCESS) {
    GELOGE(FAILED, "Add in data edge fail, %s->%s.", switch_desc->GetName().c_str(), identity_node->GetName().c_str());
    return FAILED;
  }
  if (GraphUtils::AddEdge(identity_node->GetOutControlAnchor(), new_const->GetInControlAnchor()) != GRAPH_SUCCESS) {
    GELOGE(FAILED, "Add in ctrl edge fail, %s->%s.", identity_node->GetName().c_str(), new_const->GetName().c_str());
    return FAILED;
  }

  // Remember {predicate -> (const_f, const_t)} for node so later in-ctrl edges guarded by the same
  // predicate are attached to these Consts instead of building another branch.
  auto pred_const = std::make_pair(switch_cond_map_[switch_node], std::make_pair(const_f, const_t));
  auto iter = control_trigger_map_.find(node);
  if (iter == control_trigger_map_.end()) {
    control_trigger_map_[node] = {pred_const};
  } else {
    if (!iter->second.insert(pred_const).second) {
      GELOGE(FAILED, "control_trigger_map_ insert failed.");
      return FAILED;
    }
  }

  return SUCCESS;
}

///
/// @brief Insert Merge Node (2 inputs, 2 outputs of data_desc) and replace the control edge
///        in_ctrl_node -> node by merge_node -> node
/// @param [in] graph
/// @param [in] node
/// @param [in] in_ctrl_node
/// @param [in] data_desc
/// @return NodePtr, nullptr on failure
///
NodePtr ControlTriggerPass::InsertMergeNode(ComputeGraphPtr &graph, NodePtr &node, NodePtr &in_ctrl_node,
                                            const GeTensorDesc &data_desc) {
  // NOTE(review): one Merge is created per (node, predicate); with several predicates for the same
  // node these Merges (and their Consts) share a name -- confirm whether uniqueness is required.
  const std::string name = node->GetName() + "_" + MERGE;
  OpDescPtr op_desc = MakeShared<OpDesc>(name, MERGE);
  if (op_desc == nullptr) {
    GELOGE(FAILED, "Create Merge op %s: create op_desc fail.", name.c_str());
    return nullptr;
  }

  if ((op_desc->AddInputDesc(data_desc) != GRAPH_SUCCESS) || (op_desc->AddInputDesc(data_desc) != GRAPH_SUCCESS) ||
      (op_desc->AddOutputDesc(data_desc) != GRAPH_SUCCESS) || (op_desc->AddOutputDesc(data_desc) != GRAPH_SUCCESS)) {
    GELOGE(INTERNAL_ERROR, "Create Merge op %s: add input/output desc fail.", name.c_str());
    return nullptr;
  }

  GELOGI("Create Merge op:%s.", name.c_str());
  NodePtr merge_node = graph->AddNode(op_desc);
  if (merge_node == nullptr) {
    GELOGE(INTERNAL_ERROR, "Create Merge op %s fail.", name.c_str());
    return nullptr;
  }

  if ((GraphUtils::RemoveEdge(in_ctrl_node->GetOutControlAnchor(), node->GetInControlAnchor()) != GRAPH_SUCCESS) ||
      (GraphUtils::AddEdge(merge_node->GetOutControlAnchor(), node->GetInControlAnchor()) != GRAPH_SUCCESS)) {
    GELOGE(FAILED, "Replace ctrl edge fail, %s->%s, %s->%s", in_ctrl_node->GetName().c_str(),
           node->GetName().c_str(), merge_node->GetName().c_str(), node->GetName().c_str());
    return nullptr;
  }

  return merge_node;
}

///
/// @brief Insert Const Node holding an int32 0 and connect it to one input of merge_node
/// @param [in] graph
/// @param [in] merge_node
/// @param [in] data_desc
/// @param [in] flag true: "_t" const wired to merge input SWITCH_TRUE_OUTPUT;
///                  false: "_f" const wired to merge input SWITCH_FALSE_OUTPUT
/// @return NodePtr, nullptr on failure
///
NodePtr ControlTriggerPass::InsertConstNode(ComputeGraphPtr &graph, NodePtr &merge_node,
                                            const GeTensorDesc &data_desc, bool flag) {
  const std::string name = merge_node->GetName() + "_" + CONSTANT + (flag ? "_t" : "_f");
  OpDescPtr op_desc = MakeShared<OpDesc>(name, CONSTANT);
  if (op_desc == nullptr) {
    GELOGE(FAILED, "Create Const op %s: create op_desc fail.", name.c_str());
    return nullptr;
  }

  // Both consts carry the same value (0); only the edge they are wired to differs.
  int32_t value = 0;
  GeTensorPtr const_value = MakeShared<GeTensor>(data_desc, reinterpret_cast<uint8_t *>(&value), sizeof(int32_t));
  if (const_value == nullptr) {
    GELOGE(FAILED, "Create tensor fail.");
    return nullptr;
  }
  if (!AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, const_value)) {
    GELOGE(INTERNAL_ERROR, "Create Const op %s: set attr ATTR_NAME_WEIGHTS fail.", name.c_str());
    return nullptr;
  }

  if (op_desc->AddOutputDesc(data_desc) != GRAPH_SUCCESS) {
    GELOGE(INTERNAL_ERROR, "Create Const op %s: add output desc fail.", name.c_str());
    return nullptr;
  }

  GELOGI("Create Const op: %s", name.c_str());
  NodePtr const_node = graph->AddNode(op_desc);
  if (const_node == nullptr) {
    GELOGE(INTERNAL_ERROR, "Create Const op %s fail.", name.c_str());
    return nullptr;
  }

  // The Switch output index constants double as the Merge input index for each branch.
  uint32_t merge_in_idx = (flag ? SWITCH_TRUE_OUTPUT : SWITCH_FALSE_OUTPUT);
  if (GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), merge_node->GetInDataAnchor(merge_in_idx)) !=
      GRAPH_SUCCESS) {
    GELOGE(FAILED, "Add in data edge fail, %s->%s", const_node->GetName().c_str(), merge_node->GetName().c_str());
    return nullptr;
  }

  return const_node;
}

///
/// @brief Insert Identity Node with one input and one output of data_desc (no edges are added here)
/// @param [in] graph
/// @param [in] name
/// @param [in] data_desc
/// @return NodePtr, nullptr on failure
///
NodePtr ControlTriggerPass::InsertIdentityNode(ComputeGraphPtr &graph, const std::string &name,
                                               const GeTensorDesc &data_desc) {
  OpDescPtr op_desc = MakeShared<OpDesc>(name, IDENTITY);
  if (op_desc == nullptr) {
    GELOGE(FAILED, "Create Identity op %s: create op_desc fail.", name.c_str());
    return nullptr;
  }

  if ((op_desc->AddInputDesc(data_desc) != GRAPH_SUCCESS) || (op_desc->AddOutputDesc(data_desc) != GRAPH_SUCCESS)) {
    GELOGE(INTERNAL_ERROR, "Create Identity op %s: add input/output desc fail.", name.c_str());
    return nullptr;
  }

  GELOGI("Create Identity op:%s.", name.c_str());
  NodePtr identity_node = graph->AddNode(op_desc);
  if (identity_node == nullptr) {
    GELOGE(INTERNAL_ERROR, "Create Identity op %s fail.", name.c_str());
    return nullptr;
  }

  return identity_node;
}

///
/// @brief Find pred_input of switch_node and record it in switch_cond_map_[switch_node]
/// @param [in] switch_node
/// @return Status: SUCCESS, or INTERNAL_ERROR when the predicate input (anchor, peer or owner node) is missing
///
Status ControlTriggerPass::FindPredInput(const NodePtr &switch_node) {
  if (switch_node == nullptr) {
    GELOGE(INTERNAL_ERROR, "switch_node is null");
    return INTERNAL_ERROR;
  }

  InDataAnchorPtr in_cond_anchor = switch_node->GetInDataAnchor(SWITCH_PRED_INPUT);
  if (in_cond_anchor == nullptr) {
    GELOGE(INTERNAL_ERROR, "in_cond_anchor is nullptr, node: %s.", switch_node->GetName().c_str());
    return INTERNAL_ERROR;
  }
  OutDataAnchorPtr pred_cond_anchor = in_cond_anchor->GetPeerOutAnchor();
  if (pred_cond_anchor == nullptr) {
    GELOGE(INTERNAL_ERROR, "pred_cond_anchor is nullptr, node: %s.", switch_node->GetName().c_str());
    return INTERNAL_ERROR;
  }

  // TransferNodeType dereferences the stored node, so never record a null predicate.
  NodePtr pred_node = pred_cond_anchor->GetOwnerNode();
  if (pred_node == nullptr) {
    GELOGE(INTERNAL_ERROR, "pred_node is nullptr, node: %s.", switch_node->GetName().c_str());
    return INTERNAL_ERROR;
  }

  switch_cond_map_[switch_node] = pred_node;
  return SUCCESS;
}

///
/// @brief Clear Status, used for subgraph pass: drops the cached switch->predicate and
///        ControlTrigger->branch-const maps
/// @return SUCCESS
///
Status ControlTriggerPass::ClearStatus() {
  switch_cond_map_.clear();
  control_trigger_map_.clear();
  return SUCCESS;
}
}  // namespace ge