From 19d1f804c712112dbca0705e8c9956b71d9e0acc Mon Sep 17 00:00:00 2001 From: chuxing Date: Thu, 8 Apr 2021 11:26:59 +0800 Subject: [PATCH] Bugfix: keep hccl control dependency --- ge/hybrid/model/hybrid_model_builder.cc | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index 25dabd78..9e42a91c 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -315,6 +315,20 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s } } + if (is_hccl_op) { + for (const auto &src_node : ge_node->GetInControlNodes()) { + auto src_node_item = MutableNodeItem(src_node); + GE_CHECK_NOTNULL(src_node_item); + GELOGD("[%s](%s) Add input control dependent node [%s](%s)", + ge_node->GetName().c_str(), + ge_node->GetType().c_str(), + src_node->GetName().c_str(), + src_node->GetType().c_str()); + src_node_item->has_observer = true; + dependent_for_execution.emplace(src_node); + } + } + // cond or branch need to be prepared before the execution of IF or CASE if (node_item.node_type == IF || node_item.node_type == STATELESSIF || node_item.node_type == CASE) { auto src_node = NodeUtils::GetInDataNodeByIndex(*ge_node, 0); // cond input @@ -2030,8 +2044,9 @@ Status HybridModelBuilder::CollectParallelGroups(NodeItem *node_item) { const auto &node = node_item->node; auto executor_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*node); if (executor_type == NodeExecutorManager::ExecutorType::HCCL) { - std::string parallel_group; - if (AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, parallel_group)) { + int64_t parallel_group_val = -1; + if (AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, parallel_group_val)) { + std::string parallel_group = std::to_string(parallel_group_val); GELOGD("[%s] Got parallel group = [%s]", node_item->NodeName().c_str(), parallel_group.c_str()); parallel_group_to_nodes_[parallel_group].emplace(node_item); std::set group{parallel_group}; @@ -2047,8 +2062,9 @@ Status HybridModelBuilder::CollectParallelGroups(NodeItem *node_item) { auto subgraph = root_graph_->GetSubgraph(subgraph_name); GE_CHECK_NOTNULL(subgraph); for (const auto &sub_node : subgraph->GetAllNodes()) { - std::string parallel_group; - if (AttrUtils::GetStr(sub_node->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, parallel_group)) { + int64_t parallel_group_val = -1; + if (AttrUtils::GetInt(sub_node->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, parallel_group_val)) { + std::string parallel_group = std::to_string(parallel_group_val); GELOGD("[%s::%s] Got parallel group = %s", subgraph_name.c_str(), sub_node->GetName().c_str(),