!1476 Bugfix: Missing HCCL execution dependency due to wrong attribute type of _parallel_group

From: @xchu42
Reviewed-by: @ji_chen,@wqtshg
Signed-off-by: @ji_chen
pull/1476/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 4c8e5f73c6

@ -318,6 +318,20 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s
}
}
if (is_hccl_op) {
for (const auto &src_node : ge_node->GetInControlNodes()) {
auto src_node_item = MutableNodeItem(src_node);
GE_CHECK_NOTNULL(src_node_item);
GELOGD("[%s](%s) Add input control dependent node [%s](%s)",
ge_node->GetName().c_str(),
ge_node->GetType().c_str(),
src_node->GetName().c_str(),
src_node->GetType().c_str());
src_node_item->has_observer = true;
dependent_for_execution.emplace(src_node);
}
}
// cond or branch needs to be prepared before the execution of IF, STATELESSIF or CASE
if (node_item.node_type == IF || node_item.node_type == STATELESSIF || node_item.node_type == CASE) {
auto src_node = NodeUtils::GetInDataNodeByIndex(*ge_node, 0); // cond input
@ -2055,8 +2069,9 @@ Status HybridModelBuilder::CollectParallelGroups(NodeItem *node_item) {
const auto &node = node_item->node;
auto executor_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*node);
if (executor_type == NodeExecutorManager::ExecutorType::HCCL) {
std::string parallel_group;
if (AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, parallel_group)) {
int64_t parallel_group_val = -1;
if (AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, parallel_group_val)) {
std::string parallel_group = std::to_string(parallel_group_val);
GELOGD("[%s] Got parallel group = [%s]", node_item->NodeName().c_str(), parallel_group.c_str());
parallel_group_to_nodes_[parallel_group].emplace(node_item);
std::set<std::string> group{parallel_group};
@ -2072,8 +2087,9 @@ Status HybridModelBuilder::CollectParallelGroups(NodeItem *node_item) {
auto subgraph = hybrid_model_.root_graph_->GetSubgraph(subgraph_name);
GE_CHECK_NOTNULL(subgraph);
for (const auto &sub_node : subgraph->GetAllNodes()) {
std::string parallel_group;
if (AttrUtils::GetStr(sub_node->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, parallel_group)) {
int64_t parallel_group_val = -1;
if (AttrUtils::GetInt(sub_node->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, parallel_group_val)) {
std::string parallel_group = std::to_string(parallel_group_val);
GELOGD("[%s::%s] Got parallel group = %s",
subgraph_name.c_str(),
sub_node->GetName().c_str(),

Loading…
Cancel
Save