|
|
|
@ -985,61 +985,56 @@ void AscendStreamAssign::InsertEventHcomDependCommon(const NotNull<KernelGraphPt
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AscendStreamAssign::InsertEventHcomDependHcom(const NotNull<KernelGraphPtr> &graph_ptr) {
|
|
|
|
|
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
|
|
|
|
|
auto cnode_ptr_list = graph_ptr->execution_order();
|
|
|
|
|
// key: group name, value: list of (stream id, indexes of that stream's hcom nodes in the execution order)
|
|
|
|
|
std::map<std::string, std::vector<std::pair<uint32_t, vector<size_t>>>> group_hcom_index;
|
|
|
|
|
for (size_t i = 0; i < cnode_ptr_list.size(); i++) {
|
|
|
|
|
auto cur_cnode = cnode_ptr_list[i];
|
|
|
|
|
if (!IsHcom(cur_cnode)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode);
|
|
|
|
|
if (!AnfAlgo::HasNodeAttr(kAttrGroup, cur_cnode)) {
|
|
|
|
|
MS_LOG_EXCEPTION << "hcom cnode " << cur_cnode->DebugString() << " has no group attr";
|
|
|
|
|
}
|
|
|
|
|
auto group_name = AnfAlgo::GetNodeAttr<std::string>(cur_cnode, kAttrGroup);
|
|
|
|
|
MS_LOG(INFO) << "Hcom node name:" << AnfAlgo::GetCNodeName(cur_cnode) << "; group:" << group_name
|
|
|
|
|
<< "; stream id:" << cur_stream_id;
|
|
|
|
|
auto iter = group_hcom_index.find(group_name);
|
|
|
|
|
if (iter == group_hcom_index.end()) {
|
|
|
|
|
std::vector<std::pair<uint32_t, vector<size_t>>> hcom_index;
|
|
|
|
|
hcom_index.emplace_back(std::make_pair(cur_stream_id, std::vector<size_t>{i}));
|
|
|
|
|
group_hcom_index[group_name] = hcom_index;
|
|
|
|
|
} else {
|
|
|
|
|
auto &hcom_index = iter->second;
|
|
|
|
|
bool exit = false;
|
|
|
|
|
for (auto &item : hcom_index) {
|
|
|
|
|
if (item.first == cur_stream_id) {
|
|
|
|
|
item.second.emplace_back(i);
|
|
|
|
|
exit = true;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
if (group_hcom_graph_map_.empty()) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
std::vector<string> groups;
|
|
|
|
|
for (const auto &item : group_hcom_graph_map_) {
|
|
|
|
|
groups.emplace_back(item.first);
|
|
|
|
|
}
|
|
|
|
|
for (const auto &group : groups) {
|
|
|
|
|
auto cnode_ptr_list = graph_ptr->execution_order();
|
|
|
|
|
std::vector<std::pair<uint32_t, vector<size_t>>> stream_indexs;
|
|
|
|
|
for (size_t i = 0; i < cnode_ptr_list.size(); i++) {
|
|
|
|
|
auto cur_cnode = cnode_ptr_list[i];
|
|
|
|
|
if (!IsHcom(cur_cnode)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (!exit) {
|
|
|
|
|
hcom_index.emplace_back(std::make_pair(cur_stream_id, std::vector<size_t>{i}));
|
|
|
|
|
|
|
|
|
|
uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode);
|
|
|
|
|
if (!AnfAlgo::HasNodeAttr(kAttrGroup, cur_cnode)) {
|
|
|
|
|
MS_LOG_EXCEPTION << "hcom cnode " << cur_cnode->DebugString() << " has no group attr";
|
|
|
|
|
}
|
|
|
|
|
auto group_name = AnfAlgo::GetNodeAttr<std::string>(cur_cnode, kAttrGroup);
|
|
|
|
|
MS_LOG(INFO) << "Hcom node name:" << AnfAlgo::GetCNodeName(cur_cnode) << "; group:" << group_name
|
|
|
|
|
<< "; stream id:" << cur_stream_id;
|
|
|
|
|
if (group_name != group) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (const auto &hcom_index : group_hcom_index) {
|
|
|
|
|
MS_LOG(DEBUG) << "Group:" << hcom_index.first;
|
|
|
|
|
for (const auto &item : hcom_index.second) {
|
|
|
|
|
MS_LOG(DEBUG) << "stream id:" << item.first;
|
|
|
|
|
for (const auto &index : item.second) {
|
|
|
|
|
MS_LOG(DEBUG) << "hcom index:" << index;
|
|
|
|
|
if (stream_indexs.empty()) {
|
|
|
|
|
stream_indexs.emplace_back(std::make_pair(cur_stream_id, std::vector<size_t>{i}));
|
|
|
|
|
} else {
|
|
|
|
|
bool exit = false;
|
|
|
|
|
for (auto &item : stream_indexs) {
|
|
|
|
|
if (item.first == cur_stream_id) {
|
|
|
|
|
item.second.emplace_back(i);
|
|
|
|
|
exit = true;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (!exit) {
|
|
|
|
|
stream_indexs.emplace_back(std::make_pair(cur_stream_id, std::vector<size_t>{i}));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (const auto &hcom_index : group_hcom_index) {
|
|
|
|
|
if (hcom_index.second.size() < 2) {
|
|
|
|
|
MS_LOG(INFO) << "Group:" << hcom_index.first
|
|
|
|
|
if (stream_indexs.size() < 2) {
|
|
|
|
|
MS_LOG(INFO) << "Group:" << group
|
|
|
|
|
<< "; different stream hcom size is less than 2, no need insert event between them";
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
InsertEventBetweenHcom(graph_ptr, hcom_index.second);
|
|
|
|
|
MS_LOG(INFO) << "After hcom depend hcom, total event nums:" << resource_manager.get_cur_event_num();
|
|
|
|
|
InsertEventBetweenHcom(graph_ptr, stream_indexs);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|