/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "graph/passes/next_iteration_pass.h"

#include "common/ge/ge_util.h"
#include "graph/common/omg_util.h"

namespace ge {
///
/// @brief Pass entry. Groups Enter/RefEnter nodes by frame name and batch label, locates the
///        Merge/NextIteration/Switch/LoopCond nodes of each group, verifies the groups and then
///        rewires each group through StreamActive nodes.
/// @param [in] graph
/// @return Status
///
Status NextIterationPass::Run(ComputeGraphPtr graph) {
  GELOGD("NextIterationPass Enter");
  // The loop below dereferences graph immediately.
  GE_CHECK_NOTNULL(graph);
  /// Enter-----------+
  ///                 +-> Merge -> Switch <- LoopCond <- Cond
  /// NextIteration---+
  for (auto &node : graph->GetDirectNode()) {
    const std::string type = node->GetType();
    if ((type != ENTER) && (type != REFENTER)) {
      continue;
    }
    if (GroupEnterNode(node) != SUCCESS) {
      GELOGE(INTERNAL_ERROR, "Group enter_node %s failed.", node->GetName().c_str());
      return INTERNAL_ERROR;
    }
  }

  if (GroupWithNoBatch(graph) != SUCCESS) {
    GELOGE(INTERNAL_ERROR, "Group enter_nodes failed without batch_label attr.");
    return INTERNAL_ERROR;
  }

  if (FindWhileGroups() != SUCCESS) {
    GELOGE(INTERNAL_ERROR, "Find while groups failed.");
    return INTERNAL_ERROR;
  }

  if (!VerifyWhileGroup()) {
    GELOGE(INTERNAL_ERROR, "Verify while groups failed.");
    return INTERNAL_ERROR;
  }

  if (HandleWhileGroup(graph) != SUCCESS) {
    GELOGE(FAILED, "Handle while groups failed.");
    return FAILED;
  }

  GELOGD("NextIterationPass Leave");
  return SUCCESS;
}

///
/// @brief Group Enter node.
///        Enter nodes without ATTR_NAME_BATCH_LABEL are collected per frame in frame_enter_map_
///        (distributed later by GroupWithNoBatch); the others are appended to
///        loop_group_map_[frame_name][batch_label], creating that group on first use.
/// @param [in] enter_node
/// @return Status: FAILED if the frame name attr is missing/empty or allocation fails
///
Status NextIterationPass::GroupEnterNode(const NodePtr &enter_node) {
  OpDescPtr enter_desc = enter_node->GetOpDesc();
  GE_CHECK_NOTNULL(enter_desc);
  std::string frame_name;
  if (!ge::AttrUtils::GetStr(enter_desc, ENTER_ATTR_FRAME_NAME, frame_name) || frame_name.empty()) {
    GELOGE(FAILED, "Get attr ENTER_ATTR_FRAME_NAME failed, node: %s", enter_desc->GetName().c_str());
    return FAILED;
  }

  std::string batch_label;
  (void)ge::AttrUtils::GetStr(enter_desc, ATTR_NAME_BATCH_LABEL, batch_label);
  if (batch_label.empty()) {
    frame_enter_map_[frame_name].emplace_back(enter_node);
    return SUCCESS;
  }

  // Fast path: the (frame_name, batch_label) group already exists.
  auto group_iter = loop_group_map_.find(frame_name);
  if (group_iter != loop_group_map_.end()) {
    auto batch_iter = group_iter->second.find(batch_label);
    if (batch_iter != group_iter->second.end()) {
      batch_iter->second->enter_nodes.emplace_back(enter_node);
      return SUCCESS;
    }
  }

  // First Enter node of this (frame_name, batch_label): create the group. Nothing is inserted into
  // loop_group_map_ if allocation fails.
  LoopCondGroupPtr loop_group = MakeShared<LoopCondGroup>();
  if (loop_group == nullptr) {
    GELOGE(FAILED, "MakeShared for LoopCondGroup failed.");
    return FAILED;
  }
  loop_group->enter_nodes.emplace_back(enter_node);
  loop_group_map_[frame_name][batch_label] = loop_group;
  return SUCCESS;
}

///
/// @brief Group Enter nodes without batch_label attr.
///        For a frame that has no batched group, a single group keyed by the empty batch label is
///        created; otherwise the un-batched Enter nodes are appended to every batch group of that frame.
/// @param [in] compute_graph
/// @return Status
///
Status NextIterationPass::GroupWithNoBatch(const ComputeGraphPtr &graph) {
  if (frame_enter_map_.empty()) {
    GELOGI("All enter nodes in graph %s has batch_label attr.", graph->GetName().c_str());
    return SUCCESS;
  }
  for (const auto &item : frame_enter_map_) {
    const std::string &frame_name = item.first;
    const auto &no_batch_enters = item.second;
    auto iter = loop_group_map_.find(frame_name);
    if (iter == loop_group_map_.end()) {
      LoopCondGroupPtr loop_group = MakeShared<LoopCondGroup>();
      if (loop_group == nullptr) {
        GELOGE(FAILED, "MakeShared for LoopCondGroup failed.");
        return FAILED;
      }
      loop_group->enter_nodes = no_batch_enters;
      loop_group_map_[frame_name][""] = loop_group;
      continue;
    }
    for (auto &batch_item : iter->second) {
      auto &enter_nodes = batch_item.second->enter_nodes;
      enter_nodes.insert(enter_nodes.end(), no_batch_enters.begin(), no_batch_enters.end());
    }
  }
  return SUCCESS;
}

///
/// @brief Find while groups.
///        For every Enter node of every group, each Merge/RefMerge successor with a compatible batch
///        label is paired with its NextIteration input; the Switch it feeds (if any) is followed to
///        its LoopCond input, which must be the same node for the whole group.
/// @return Status
///
Status NextIterationPass::FindWhileGroups() {
  for (const auto &loop_group_iter : loop_group_map_) {
    const std::string &frame_name = loop_group_iter.first;
    for (const auto &batch_iter : loop_group_iter.second) {
      const std::string &batch_label = batch_iter.first;
      const auto &loop_group = batch_iter.second;
      for (const auto &enter_node : loop_group->enter_nodes) {
        GELOGI("Find while_group for enter_node %s, frame_name:%s, batch_label:%s.", enter_node->GetName().c_str(),
               frame_name.c_str(), batch_label.c_str());
        for (const auto &out_node : enter_node->GetOutAllNodes()) {
          if ((out_node->GetType() != MERGE) && (out_node->GetType() != REFMERGE)) {
            continue;
          }
          std::string tmp_label;
          GE_CHECK_NOTNULL(out_node->GetOpDesc());
          (void)AttrUtils::GetStr(out_node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, tmp_label);
          // An empty label on either side matches any batch.
          bool need_skip = !(batch_label.empty() || tmp_label.empty() || (batch_label == tmp_label));
          if (need_skip) {
            continue;
          }

          NodePtr next_node = nullptr;
          if (FindTargetNode(out_node, NEXTITERATION, true, batch_label, next_node) != SUCCESS) {
            GELOGE(INTERNAL_ERROR, "Get NextIteration node failed.");
            return INTERNAL_ERROR;
          }
          loop_group->merge_next_pairs.emplace_back(std::make_pair(out_node, next_node));

          NodePtr switch_node = nullptr;
          if (FindTargetNode(out_node, SWITCH, false, batch_label, switch_node) != SUCCESS) {
            GELOGE(INTERNAL_ERROR, "Get Switch node failed.");
            return INTERNAL_ERROR;
          }
          // A Merge without a Switch consumer contributes no LoopCond.
          if (switch_node == nullptr) {
            continue;
          }

          NodePtr loop_cond = nullptr;
          if (FindTargetNode(switch_node, LOOPCOND, true, batch_label, loop_cond) != SUCCESS) {
            GELOGE(INTERNAL_ERROR, "Get LoopCond node failed.");
            return INTERNAL_ERROR;
          }
          if (loop_group->loop_cond == nullptr) {
            loop_group->loop_cond = loop_cond;
          } else if (loop_group->loop_cond != loop_cond) {
            GELOGE(FAILED, "Multi LoopCond nodes exist.");
            return FAILED;
          }
        }
      }
    }
  }

  return SUCCESS;
}

///
/// @brief Verify if valid: every frame name is non-empty, every group has a LoopCond, and every
///        (Merge, NextIteration) pair is fully populated.
/// @return bool
///
bool NextIterationPass::VerifyWhileGroup() {
  // loop_group_map_: frame_name -> (batch_label -> LoopCondGroupPtr)
  for (const auto &loop_group_iter : loop_group_map_) {
    const std::string &frame_name = loop_group_iter.first;
    if (frame_name.empty()) {
      GELOGE(INTERNAL_ERROR, "Verify while group failed, frame_name is empty.");
      return false;
    }
    for (const auto &batch_iter : loop_group_iter.second) {
      if (batch_iter.second->loop_cond == nullptr) {
        GELOGE(INTERNAL_ERROR, "Verify while group failed, LoopCond is null, frame_name: %s.", frame_name.c_str());
        return false;
      }

      for (const auto &pair_iter : batch_iter.second->merge_next_pairs) {
        if ((pair_iter.first == nullptr) || (pair_iter.second == nullptr)) {
          GELOGE(INTERNAL_ERROR, "Verify while group failed, merge_node/next_node is null, frame_name: %s.",
                 frame_name.c_str());
          return false;
        }
      }
    }
  }

  return true;
}

///
/// @brief Handle while group.
///        For each group, creates two StreamActive nodes and adds control edges
///        Enter -> enter_active -> Merge and NextIteration -> next_active, removes the data edge
///        NextIteration -> Merge, and sets ACTIVE_LABEL_LIST of both actives to the LoopCond name.
///        Assumes VerifyWhileGroup() has passed (loop_cond and pair members are non-null).
/// @param [in] graph
/// @return Status
///
Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) {
  for (const auto &loop_cond_iter : loop_group_map_) {
    for (const auto &batch_iter : loop_cond_iter.second) {
      const auto &loop_group = batch_iter.second;
      const std::string &cond_name = loop_group->loop_cond->GetName();
      GELOGI("Handle while group, LoopCond node: %s.", cond_name.c_str());

      // Create Active node, Enter->Active->Merge, NextIteration->Active->Merge
      NodePtr enter_active = CreateActiveNode(graph, cond_name + "_Enter_" + STREAMACTIVE);
      NodePtr next_active = CreateActiveNode(graph, cond_name + "_Next_" + STREAMACTIVE);
      if ((enter_active == nullptr) || (next_active == nullptr)) {
        GELOGE(INTERNAL_ERROR, "Create active node failed, cond_name: %s.", cond_name.c_str());
        return INTERNAL_ERROR;
      }

      for (const auto &enter_node : loop_group->enter_nodes) {
        // Enter --> Active
        if (GraphUtils::AddEdge(enter_node->GetOutControlAnchor(), enter_active->GetInControlAnchor()) !=
            GRAPH_SUCCESS) {
          GELOGE(INTERNAL_ERROR, "Add control edge failed.");
          return INTERNAL_ERROR;
        }
      }

      for (const auto &pair : loop_group->merge_next_pairs) {
        NodePtr merge_node = pair.first;
        NodePtr next_node = pair.second;
        // Active --> Merge
        if (GraphUtils::AddEdge(enter_active->GetOutControlAnchor(), merge_node->GetInControlAnchor()) !=
            GRAPH_SUCCESS) {
          GELOGE(INTERNAL_ERROR, "Add control edge failed.");
          return INTERNAL_ERROR;
        }

        // NextIteration --> Active
        if (GraphUtils::AddEdge(next_node->GetOutControlAnchor(), next_active->GetInControlAnchor()) !=
            GRAPH_SUCCESS) {
          GELOGE(INTERNAL_ERROR, "Add control edge failed.");
          return INTERNAL_ERROR;
        }

        // break link between NextIteration and Merge
        if (BreakNextIteration(next_node, merge_node) != SUCCESS) {
          GELOGE(INTERNAL_ERROR, "Break NextIteration failed");
          return INTERNAL_ERROR;
        }
      }

      if ((SetActiveLabelList(enter_active, {cond_name}) != SUCCESS) ||
          (SetActiveLabelList(next_active, {cond_name}) != SUCCESS)) {
        GELOGE(INTERNAL_ERROR, "Set attr ACTIVE_LABEL_LIST failed.");
        return INTERNAL_ERROR;
      }
    }
  }

  return SUCCESS;
}

///
/// @brief Create Active Node: adds a STREAMACTIVE node named `name` to `graph` and sets its
///        SWITCH_BRANCH_NODE_LABEL attr to `name`.
/// @param [in] graph
/// @param [in] name
/// @return ge::NodePtr, nullptr on any failure
///
NodePtr NextIterationPass::CreateActiveNode(ComputeGraphPtr &graph, const std::string &name) {
  OpDescPtr op_desc = MakeShared<OpDesc>(name, STREAMACTIVE);
  if (op_desc == nullptr) {
    GELOGE(INTERNAL_ERROR, "Create op_desc for node[%s] failed.", name.c_str());
    return nullptr;
  }

  GELOGI("Create StreamActive op:%s.", op_desc->GetName().c_str());
  NodePtr active_node = graph->AddNode(op_desc);
  if (active_node == nullptr) {
    GELOGE(INTERNAL_ERROR, "Create node[%s] failed.", name.c_str());
    return nullptr;
  }

  if (SetSwitchBranchNodeLabel(active_node, name) != SUCCESS) {
    GELOGE(INTERNAL_ERROR, "Set attr SWITCH_BRANCH_NODE_LABEL for node: %s failed.", active_node->GetName().c_str());
    return nullptr;
  }

  return active_node;
}

///
/// @brief Break NextIteration Link & add name to merge attr.
///        Removes every data edge from next_node into merge_node and, for each one removed, records
///        next_node's name on merge_node via SetNextIteration.
/// @param [in] next_node
/// @param [in] merge_node
/// @return Status
///
Status NextIterationPass::BreakNextIteration(const NodePtr &next_node, NodePtr &merge_node) {
  if ((merge_node == nullptr) || (next_node == nullptr)) {
    GELOGE(PARAM_INVALID, "merge node or next node is null.");
    return PARAM_INVALID;
  }
  for (const auto &in_anchor : merge_node->GetAllInDataAnchors()) {
    OutDataAnchorPtr out_anchor = in_anchor->GetPeerOutAnchor();
    if ((out_anchor == nullptr) || (out_anchor->GetOwnerNode() != next_node)) {
      continue;
    }
    // Compare against GRAPH_SUCCESS, like every other GraphUtils call in this file.
    if (GraphUtils::RemoveEdge(out_anchor, in_anchor) != GRAPH_SUCCESS) {
      GELOGE(INTERNAL_ERROR, "Remove data edge failed, %s->%s.", next_node->GetName().c_str(),
             merge_node->GetName().c_str());
      return INTERNAL_ERROR;
    }
    if (SetNextIteration(merge_node, next_node->GetName()) != SUCCESS) {
      GELOGE(INTERNAL_ERROR, "Set attr NEXT_ITERATION for node %s failed.", merge_node->GetName().c_str());
      return INTERNAL_ERROR;
    }
  }
  return SUCCESS;
}

///
/// @brief find target node.
///        Scans the data inputs (is_input) or data outputs of `node` for the first node whose type is
///        `target_type` or "Ref" + `target_type` and whose batch label is compatible with `batch_label`
///        (an empty label on either side matches).
/// @param [in] node
/// @param [in] target_type
/// @param [in] is_input
/// @param [in] batch_label
/// @param [out] target_node: the match, or nullptr when none is found
/// @return Status: INTERNAL_ERROR when nothing is found, except for SWITCH which is optional
///
Status NextIterationPass::FindTargetNode(const NodePtr &node, const std::string &target_type, bool is_input,
                                         const std::string &batch_label, NodePtr &target_node) {
  if (node == nullptr) {
    GELOGE(PARAM_INVALID, "node is null.");
    return PARAM_INVALID;
  }
  // Start from a known state so the "not found" check below never sees a stale caller value.
  target_node = nullptr;

  std::vector<NodePtr> nodes;
  if (is_input) {
    for (const auto &tmp_node : node->GetInDataNodes()) {
      nodes.emplace_back(tmp_node);
    }
  } else {
    for (const auto &tmp_node : node->GetOutDataNodes()) {
      nodes.emplace_back(tmp_node);
    }
  }

  const std::string ref_type = "Ref" + target_type;
  for (const auto &tmp_node : nodes) {
    std::string tmp_label;
    (void)AttrUtils::GetStr(tmp_node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, tmp_label);
    bool need_skip = !(batch_label.empty() || tmp_label.empty() || (batch_label == tmp_label));
    if (need_skip) {
      continue;
    }
    const std::string type = tmp_node->GetType();
    if ((type == target_type) || (type == ref_type)) {
      target_node = tmp_node;
      break;
    }
  }

  if ((target_type != SWITCH) && (target_node == nullptr)) {
    GELOGE(INTERNAL_ERROR, "Find node %s failed.", target_type.c_str());
    return INTERNAL_ERROR;
  }
  return SUCCESS;
}

///
/// @brief Clear Status, used for subgraph pass: empties frame_enter_map_ and loop_group_map_.
/// @return SUCCESS
///
Status NextIterationPass::ClearStatus() {
  frame_enter_map_.clear();
  loop_group_map_.clear();
  return SUCCESS;
}
}  // namespace ge