/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "graph/passes/parallel_group_pass.h"

#include <algorithm>
#include <cstddef>
#include <deque>
#include <map>
#include <set>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "framework/common/debug/ge_log.h"
#include "common/ge/ge_util.h"
#include "framework/common/ge_inner_error_codes.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/node_utils.h"

namespace ge {
namespace {
// Subgraph nesting depth at which ProcessGraphGroupNodes stops recursing and fails.
const int32_t kMaxRecursionDepth = 10;
// ATTR_NAME_STREAM_SWITCH_TYPE value that IsWhileStreamSwitch treats as a loop switch.
const int64_t kLoopType = 1;
}  // namespace

/// @brief Entry point of the pass: chains nodes that share the same ATTR_NAME_PARALLEL_GROUP value
///        with control edges so they execute one after another.
/// Only root graphs are processed; subgraphs are reached recursively from the root.
/// @param graph graph to process.
/// @return SUCCESS on success or when `graph` is a subgraph; PARAM_INVALID for a null graph;
///         FAILED if topological sorting fails; INTERNAL_ERROR if group processing fails.
Status ParallelGroupPass::Run(ComputeGraphPtr graph) {
  GELOGD("ParallelGroupPass running");
  if (graph == nullptr) {
    GELOGE(PARAM_INVALID, "[Check][Graph]Input param graph is null, skip ParallelGroupPass.");
    REPORT_INNER_ERROR("E19999", "Input param graph is null, skip ParallelGroupPass.");
    return PARAM_INVALID;
  }

  if (graph->GetParentGraph() != nullptr) {
    GELOGD("Current graph %s is a subgraph, this pass only support root graph.", graph->GetName().c_str());
    return SUCCESS;
  }

  // Node ids are compared later (sorting merges, ring avoidance); presumably this sort is what
  // assigns them a topological meaning — NOTE(review): confirm against ComputeGraph::TopologicalSorting.
  if (graph->TopologicalSorting() != GRAPH_SUCCESS) {
    GELOGE(FAILED, "[TopoSort][Graph]Graph:%s topological sort failed.", graph->GetName().c_str());
    REPORT_CALL_ERROR("E19999", "Graph:%s topological sort failed when ParallelGroupPass run.",
                      graph->GetName().c_str());
    return FAILED;
  }

  std::unordered_set<std::string> parallel_groups;
  int32_t depth = 0;
  if (ProcessGraphGroupNodes(graph, depth, parallel_groups) != SUCCESS) {
    GELOGE(INTERNAL_ERROR, "[Process][Graph]Process group nodes of graph %s failed.", graph->GetName().c_str());
    return INTERNAL_ERROR;
  }

  // Re-sort so that the node order reflects the control edges added above.
  if (graph->TopologicalSorting() != GRAPH_SUCCESS) {
    GELOGE(FAILED, "[TopoSort][Graph]Graph:%s topological sort failed.", graph->GetName().c_str());
    REPORT_CALL_ERROR("E19999", "Graph:%s topological sort failed when ParallelGroupPass run.",
                      graph->GetName().c_str());
    return FAILED;
  }
  return SUCCESS;
}

/// @brief Collects the parallel-group members among the direct nodes of `graph` and links each group's
///        members, in visiting order, with control edges (via ReplaceWithSwitchAndMerge).
/// Known-shape subgraphs are processed recursively first; a node owning such a subgraph becomes a member
/// of every group found inside it, so ordering is propagated to the parent graph.
/// @param graph graph whose direct nodes are processed.
/// @param depth current subgraph nesting depth; FAILED once it reaches kMaxRecursionDepth.
/// @param parallel_groups [out] every group name found in `graph` and in its processed subgraphs is inserted.
/// @return SUCCESS, or FAILED on excessive nesting or any failing sub-step.
Status ParallelGroupPass::ProcessGraphGroupNodes(ComputeGraphPtr graph, int32_t depth,
                                                 std::unordered_set<std::string> &parallel_groups) {
  if (depth >= kMaxRecursionDepth) {
    GELOGE(FAILED, "[Process][SubGraph]There are too much subgraphs:%d >= %d(max subgraphs)", depth,
           kMaxRecursionDepth);
    REPORT_INNER_ERROR("E19999", "There are too much subgraphs:%d >= %d(max subgraphs)", depth,
                       kMaxRecursionDepth);
    return FAILED;
  }

  // group name -> member nodes, in the order they are visited below.
  std::map<std::string, std::vector<NodePtr>> group_nodes;
  auto candidates = graph->GetDirectNode();
  auto root_graph = GraphUtils::FindRootGraph(graph);
  for (const auto &node : candidates) {
    OpDescPtr op_desc = node->GetOpDesc();
    if (op_desc == nullptr) {
      continue;
    }

    std::string group_name;
    if (AttrUtils::GetStr(op_desc, ATTR_NAME_PARALLEL_GROUP, group_name)) {
      group_nodes[group_name].push_back(node);
      parallel_groups.insert(group_name);
      GELOGD("Find group node:%s, group_name:%s", node->GetName().c_str(), group_name.c_str());
    }

    const auto &subgraph_name = op_desc->GetSubgraphInstanceNames();
    GE_CHECK_NOTNULL(root_graph);
    // Subgraph instance names are walked back to front.
    for (auto name_iter = subgraph_name.rbegin(); name_iter != subgraph_name.rend(); ++name_iter) {
      const auto &sub_graph = root_graph->GetSubgraph(*name_iter);
      GE_CHECK_NOTNULL(sub_graph);
      // if the pass add control edge for known and unknown graph, then the known graph will become unknown graph
      // the order between known and unknown graph is guaranteed by dynamic shape executor
      // so the parallel group pass do nothing for unknown graph
      if (sub_graph->GetGraphUnknownFlag()) {
        continue;
      }
      std::unordered_set<std::string> sub_parallel_groups;
      auto ret = ProcessGraphGroupNodes(sub_graph, depth + 1, sub_parallel_groups);
      if (ret != SUCCESS) {
        GELOGE(FAILED, "[Process][SubGraph]Process sub graph %s failed.", sub_graph->GetName().c_str());
        return FAILED;
      }
      // The owning node stands in for the subgraph's group members at this level.
      for (const auto &sub_parallel_group : sub_parallel_groups) {
        parallel_groups.insert(sub_parallel_group);
        group_nodes[sub_parallel_group].emplace_back(node);
      }
    }
  }

  std::map<NodePtr, std::pair<std::set<NodePtr>, NodePtr>> node_2_switch_merge;
  if (ProcessGroupNodeInSwitch(graph, node_2_switch_merge) != SUCCESS) {
    GELOGE(FAILED, "[Process][Node]Process group node in switch failed, graph:%s.", graph->GetName().c_str());
    return FAILED;
  }

  for (const auto &itr : group_nodes) {
    const auto &nodes = itr.second;
    if (nodes.empty()) {
      continue;
    }
    NodePtr pre_node = nodes[0];
    NodePtr cur_node = nullptr;
    for (std::size_t i = 1; i < nodes.size(); i++) {
      cur_node = nodes[i];
      GELOGD("Original add ctrl anchor for node:%s->%s", pre_node->GetName().c_str(),
             cur_node->GetName().c_str());
      if (ReplaceWithSwitchAndMerge(pre_node, cur_node, node_2_switch_merge) != SUCCESS) {
        GELOGE(FAILED, "[Replace][Node]Replace switch and merges for nodes: %s and %s failed.",
               pre_node->GetName().c_str(), cur_node->GetName().c_str());
        return FAILED;
      }
      pre_node = cur_node;
    }
  }

  return SUCCESS;
}

/// @brief Adds a control edge pre_node -> cur_node unless the nodes are identical or pre_node is already
///        an input (data or control) of cur_node.
/// @return SUCCESS when nothing needs to be added, otherwise the result of GraphUtils::AddEdge.
Status ParallelGroupPass::AddCtrlEdge(NodePtr pre_node, NodePtr cur_node) {
  if (pre_node == cur_node) {
    GELOGD("Pre_node and cur_node are same, no need add anchor");
    return SUCCESS;
  }
  auto in_nodes = cur_node->GetInAllNodes();
  for (const auto &node : in_nodes) {
    if (pre_node == node) {
      GELOGD("Node:%s and %s already linked", pre_node->GetName().c_str(), cur_node->GetName().c_str());
      return SUCCESS;
    }
  }
  GELOGD("Finally add ctrl anchor for node:%s->%s", pre_node->GetName().c_str(), cur_node->GetName().c_str());
  return GraphUtils::AddEdge(pre_node->GetOutControlAnchor(), cur_node->GetInControlAnchor());
}

/// @brief For every StreamSwitch in `graph` (except those IsBigSmallLoopStreamSwitch / IsWhileStreamSwitch
///        reject), records which parallel-group nodes lie between it and its StreamMerge.
/// Each such group node is mapped to {set of input-0 nodes of the switches it is behind, merge node with the
/// largest op id found from that switch}.
/// @param node_2_switch_merge [out] group node -> (switch input nodes, merge node).
/// @return SUCCESS; FAILED when no merge is found, when group nodes are found together with more than one
///         stream label, or when mapping fails.
Status ParallelGroupPass::ProcessGroupNodeInSwitch(
    ComputeGraphPtr graph, std::map<NodePtr, std::pair<std::set<NodePtr>, NodePtr>> &node_2_switch_merge) {
  auto direct_nodes = graph->GetDirectNode();
  for (const auto &node : direct_nodes) {
    if (node->GetType() != STREAMSWITCH) {
      continue;
    }

    if (IsBigSmallLoopStreamSwitch(node->GetOpDesc()) || IsWhileStreamSwitch(node->GetOpDesc())) {
      continue;
    }

    std::vector<NodePtr> merge_nodes;
    std::set<NodePtr> group_nodes;
    std::set<std::string> stream_labels;
    FindGroupNodeAndMerge(node, group_nodes, merge_nodes, stream_labels);
    if (merge_nodes.empty() || (!group_nodes.empty() && stream_labels.size() > 1)) {
      GELOGE(FAILED, "[Process][Node]Cannot find merge node or exist switch nestification, switch node:%s,"
             "merge_vec size:%zu, stream_labels size:%zu, graph:%s.", node->GetName().c_str(),
             merge_nodes.size(), stream_labels.size(), graph->GetName().c_str());
      REPORT_INNER_ERROR("E19999", "Cannot find merge node or exist switch nest, switch node:%s,"
                         "merge_vec size: %zu, stream_labels size: %zu, graph:%s.", node->GetName().c_str(),
                         merge_nodes.size(), stream_labels.size(), graph->GetName().c_str());
      return FAILED;
    }

    // Ascending by op id, so merge_nodes.back() is the merge with the largest id.
    std::sort(merge_nodes.begin(), merge_nodes.end(), [](NodePtr a, NodePtr b) -> bool {
      return (a->GetOpDesc()->GetId() < b->GetOpDesc()->GetId());
    });

    NodePtr cast_node = NodeUtils::GetInDataNodeByIndex(*node, 0);
    GE_CHECK_NOTNULL(cast_node);
    if (MappingNodeToSwitchAndMerge(group_nodes, merge_nodes, cast_node, node, node_2_switch_merge) != SUCCESS) {
      GELOGE(FAILED, "[Mapping][Node]Mapping node to switch and merge failed, graph:%s.", graph->GetName().c_str());
      REPORT_CALL_ERROR("E19999", "[Mapping][Node]Mapping node to switch and merge failed, graph:%s.",
                        graph->GetName().c_str());
      return FAILED;
    }
  }

  return SUCCESS;
}

/// @brief Breadth-first walk over all output nodes (data and control) starting at `stream_switch_node`.
/// StreamMerge nodes are recorded in `merge_nodes` and not expanded (they may be recorded more than once).
/// Every reached node carrying ATTR_NAME_PARALLEL_GROUP goes to `group_nodes`; the ATTR_NAME_STREAM_LABEL of
/// every newly visited node goes to `stream_labels`.
void ParallelGroupPass::FindGroupNodeAndMerge(NodePtr stream_switch_node, std::set<NodePtr> &group_nodes,
                                              std::vector<NodePtr> &merge_nodes,
                                              std::set<std::string> &stream_labels) {
  std::deque<NodePtr> candidates;
  std::set<NodePtr> visited;

  candidates.push_back(stream_switch_node);
  while (!candidates.empty()) {
    NodePtr tmp_node = candidates.front();
    candidates.pop_front();
    for (const auto &out_node : tmp_node->GetOutAllNodes()) {
      if (out_node->GetType() == STREAMMERGE) {
        merge_nodes.emplace_back(out_node);
        continue;
      }
      const auto &op = out_node->GetOpDesc();
      if (op != nullptr && op->HasAttr(ATTR_NAME_PARALLEL_GROUP)) {
        group_nodes.emplace(out_node);
      }
      if (visited.count(out_node) > 0) {
        continue;
      }
      candidates.push_back(out_node);
      visited.insert(out_node);
      std::string stream_label;
      if (ge::AttrUtils::GetStr(out_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label)) {
        stream_labels.insert(stream_label);
      }
    }
  }
}

/// @brief Records, for each node in `group_nodes`, that it sits behind the switch fed by `cast_node` and
///        before `merge_nodes.back()`.
/// A group node already mapped gets `cast_node` added to its switch set, provided its recorded merge is the
/// same node; otherwise the graph structure is rejected.
/// @pre merge_nodes is non-empty (checked by the caller).
/// @return SUCCESS, or FAILED if a group node would map to two different merge nodes.
Status ParallelGroupPass::MappingNodeToSwitchAndMerge(
    const std::set<NodePtr> &group_nodes, const std::vector<NodePtr> &merge_nodes, const NodePtr &cast_node,
    const NodePtr &switch_node, std::map<NodePtr, std::pair<std::set<NodePtr>, NodePtr>> &node_2_switch_merge) {
  for (const auto &group_node : group_nodes) {
    auto itr = node_2_switch_merge.find(group_node);
    if (itr != node_2_switch_merge.end()) {
      auto &tmp = itr->second;
      auto &switch_set = tmp.first;
      const auto &merge_node = tmp.second;
      GELOGD("Find group node: %s in switch %s and merge %s.", group_node->GetName().c_str(),
             switch_node->GetName().c_str(), merge_node->GetName().c_str());
      if (merge_node != merge_nodes.back()) {
        GELOGE(FAILED, "[Mapping][Node]Has two different merge nodes: %s and %s, graph's structure is invalid",
               merge_node->GetName().c_str(), merge_nodes.back()->GetName().c_str());
        REPORT_INNER_ERROR("E19999", "Has two different merge nodes: %s and %s,"
                           "graph's structure is invalid",
                           merge_node->GetName().c_str(), merge_nodes.back()->GetName().c_str());
        return FAILED;
      }
      switch_set.insert(cast_node);
    } else {
      node_2_switch_merge.emplace(group_node,
                                  std::make_pair(std::set<NodePtr>{cast_node}, merge_nodes.back()));
    }
  }
  return SUCCESS;
}

/// @brief Orders `pre_node` before `cur_node`, redirecting the control edge to the enclosing switch/merge
///        region whenever either node was mapped in `node_2_switch_merge`.
/// - pre mapped, cur not: merge(pre) -> cur.
/// - both mapped: nothing if they share a switch input; otherwise merge(pre) -> each switch input of cur.
/// - cur mapped, pre not: if pre's op id is greater than any of cur's switch inputs, merge(cur) -> pre;
///   otherwise pre -> each switch input of cur. Exactly one direction is chosen for the whole switch set:
///   mixing them would add pre -> switch_k and merge -> pre while switch_k already reaches merge (a ring).
/// - neither mapped: pre -> cur.
/// @return SUCCESS, or FAILED / PARAM_INVALID when an edge cannot be added or an op desc is null.
Status ParallelGroupPass::ReplaceWithSwitchAndMerge(
    NodePtr pre_node, NodePtr cur_node,
    const std::map<NodePtr, std::pair<std::set<NodePtr>, NodePtr>> &node_2_switch_merge) {
  auto pre_itr = node_2_switch_merge.find(pre_node);
  auto cur_itr = node_2_switch_merge.find(cur_node);

  if (pre_itr != node_2_switch_merge.end()) {
    // pre_node is inside a switch/merge region: chain from that region's merge instead.
    const NodePtr &pre_merge = pre_itr->second.second;
    if (cur_itr == node_2_switch_merge.end()) {
      return AddCtrlEdge(pre_merge, cur_node);
    }
    const auto &cur_switches = cur_itr->second.first;
    if (HasSameSwitch(pre_itr->second.first, cur_switches)) {
      return SUCCESS;
    }
    for (const auto &switch_node : cur_switches) {
      if (AddCtrlEdge(pre_merge, switch_node) != SUCCESS) {
        GELOGE(FAILED, "[AddEdge][Node]Add edge for nodes: %s->%s failed.", pre_merge->GetName().c_str(),
               switch_node->GetName().c_str());
        REPORT_CALL_ERROR("E19999", "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
                          pre_merge->GetName().c_str(), switch_node->GetName().c_str());
        return FAILED;
      }
    }
    return SUCCESS;
  }

  if (cur_itr == node_2_switch_merge.end()) {
    return AddCtrlEdge(pre_node, cur_node);
  }

  const auto &cur_switches = cur_itr->second.first;
  const NodePtr &cur_merge = cur_itr->second.second;

  const auto pre_op_desc = pre_node->GetOpDesc();
  GE_CHECK_NOTNULL(pre_op_desc);
  const int64_t pre_id = pre_op_desc->GetId();

  // avoid ring: if pre_node comes after any of cur's switch inputs, place it after the merge instead.
  bool pre_after_switch = false;
  for (const auto &switch_node : cur_switches) {
    const auto switch_op_desc = switch_node->GetOpDesc();
    GE_CHECK_NOTNULL(switch_op_desc);
    if (pre_id > switch_op_desc->GetId()) {
      pre_after_switch = true;
      break;
    }
  }

  if (pre_after_switch) {
    if (AddCtrlEdge(cur_merge, pre_node) != SUCCESS) {
      GELOGE(FAILED, "[AddEdge][Node]Add edge for nodes: %s->%s failed.", cur_merge->GetName().c_str(),
             pre_node->GetName().c_str());
      REPORT_CALL_ERROR("E19999", "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
                        cur_merge->GetName().c_str(), pre_node->GetName().c_str());
      return FAILED;
    }
    return SUCCESS;
  }

  for (const auto &switch_node : cur_switches) {
    if (AddCtrlEdge(pre_node, switch_node) != SUCCESS) {
      GELOGE(FAILED, "[AddEdge][Node]Add edge for nodes: %s->%s failed.", pre_node->GetName().c_str(),
             switch_node->GetName().c_str());
      REPORT_CALL_ERROR("E19999", "[AddEdge][Node]Add edge for nodes: %s->%s failed.",
                        pre_node->GetName().c_str(), switch_node->GetName().c_str());
      return FAILED;
    }
  }
  return SUCCESS;
}

/// @brief Returns true when the two switch-input sets share at least one node.
bool ParallelGroupPass::HasSameSwitch(const std::set<NodePtr> &switch_set1, const std::set<NodePtr> &switch_set2) {
  return std::any_of(switch_set1.begin(), switch_set1.end(),
                     [&switch_set2](const NodePtr &node) { return switch_set2.count(node) > 0; });
}

/// @brief Returns true when the switch op lacks ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG.
/// NOTE(review): presumably only big/small-loop StreamSwitches omit that flag — confirm with the code
/// that creates StreamSwitch nodes.
bool ParallelGroupPass::IsBigSmallLoopStreamSwitch(OpDescPtr switch_op_desc) {
  return !AttrUtils::HasAttr(switch_op_desc, ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG);
}

/// @brief Returns true when ATTR_NAME_STREAM_SWITCH_TYPE is present and equals kLoopType.
bool ParallelGroupPass::IsWhileStreamSwitch(OpDescPtr switch_op_desc) {
  int64_t stream_switch_type = -1;
  return (AttrUtils::GetInt(switch_op_desc, ATTR_NAME_STREAM_SWITCH_TYPE, stream_switch_type) &&
          stream_switch_type == kLoopType);
}
}  // namespace ge