Modify the maximum number of HCCL ops per stream

Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
pull/11954/head
zhoufeng 4 years ago
parent f1009cb21b
commit 409c50ae94

@ -62,9 +62,31 @@ bool HasRefNodes(const vector<CNodePtr> &moved_backward_cnodes) {
}
return false;
}
// Classifies a StreamActive node from its own stream id and the stream ids on either side of it.
//
// cur_stream_id:  stream id of the StreamActive node itself.
// pre_stream_id:  UINT32_MAX means no node activates the current StreamActive.
// next_stream_id: UINT32_MAX means the current StreamActive activates no node.
//
// Returns kMiddle when cur matches both neighbours, kTail when it matches only pre_stream_id,
// kHead when it matches only next_stream_id, and kInvalid otherwise (including when either
// neighbour id is the UINT32_MAX sentinel).
StreamActiveKind GetStreamKind(uint32_t cur_stream_id, uint32_t pre_stream_id, uint32_t next_stream_id) {
  const bool has_both_neighbours = (pre_stream_id != UINT32_MAX) && (next_stream_id != UINT32_MAX);
  if (!has_both_neighbours) {
    return kInvalid;
  }

  const bool same_as_pre = (cur_stream_id == pre_stream_id);
  const bool same_as_next = (cur_stream_id == next_stream_id);
  if (same_as_pre) {
    // Matching the preceding stream: middle if the following one matches too, otherwise tail.
    return same_as_next ? kMiddle : kTail;
  }
  // Not matching the preceding stream: head only if the following stream matches.
  return same_as_next ? kHead : kInvalid;
}
} // namespace
const uint32_t kHcomMaxTask = 5;
const uint32_t kHcomMaxTask = 4;
const uint32_t kCommonMaxTask = 350;
void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr) {
@ -899,7 +921,7 @@ void AscendStreamAssign::InsertStreamActiveForIndependent(const NotNull<KernelGr
std::vector<CNodePtr> update_cnode_list;
auto exe_orders = graph_ptr->execution_order();
// first independent is been actived, active other independent stream
// the first independent stream has already been activated; activate the other independent streams
std::vector<uint32_t> streams;
std::copy(independent_streams.begin(), independent_streams.end(), std::back_inserter(streams));
std::sort(streams.begin(), streams.end());
@ -999,10 +1021,10 @@ void AscendStreamAssign::UpdateStreamSwitch(const NotNull<KernelGraphPtr> &graph
auto kind = AnfAlgo::GetNodeAttr<uint32_t>(switch_ptr, kAttrStreamSwitchKind);
if (kind == kIndependentStreamSwitch) {
bool independent_empty = independent_stream_map_.empty();
// if indepdent empty: delete independent streamswitch
// if there are no independent streams, delete the independent StreamSwitch
if (!independent_empty) {
for (const auto &item : independent_stream_map_) {
// first independetn stream id is minimum and order by std map;
// std::map iterates in key order, so the first entry holds the minimum independent stream id
auto first_independent_stream = item.first;
AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(first_independent_stream), switch_ptr);
orders->emplace_back(switch_ptr);
@ -1028,7 +1050,7 @@ void AscendStreamAssign::UpdateStreamSwitch(const NotNull<KernelGraphPtr> &graph
return;
}
auto true_stream_id = AnfAlgo::GetNodeAttr<uint32_t>(switch_ptr, kAttrTrueBranchStream);
MS_LOG(INFO) << "Swtich stream id:" << AnfAlgo::GetStreamId(switch_ptr) << "; active stream id:" << true_stream_id;
MS_LOG(INFO) << "Switch stream id:" << AnfAlgo::GetStreamId(switch_ptr) << "; active stream id:" << true_stream_id;
CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr);
AnfAlgo::SetStreamId(true_stream_id, active_ptr.get());
vector<uint32_t> active_ids;
@ -1328,7 +1350,7 @@ void AscendStreamAssign::InsertEventHcomDependHcom(const NotNull<KernelGraphPtr>
}
for (const auto &group : groups) {
auto cnode_ptr_list = graph_ptr->execution_order();
std::vector<std::pair<uint32_t, vector<size_t>>> stream_indexs;
std::vector<std::pair<uint32_t, vector<size_t>>> stream_indices;
for (size_t i = 0; i < cnode_ptr_list.size(); i++) {
auto cur_cnode = cnode_ptr_list[i];
if (!IsHcom(cur_cnode)) {
@ -1346,11 +1368,11 @@ void AscendStreamAssign::InsertEventHcomDependHcom(const NotNull<KernelGraphPtr>
continue;
}
if (stream_indexs.empty()) {
stream_indexs.emplace_back(std::make_pair(cur_stream_id, std::vector<size_t>{i}));
if (stream_indices.empty()) {
stream_indices.emplace_back(std::make_pair(cur_stream_id, std::vector<size_t>{i}));
} else {
bool exit = false;
for (auto &item : stream_indexs) {
for (auto &item : stream_indices) {
if (item.first == cur_stream_id) {
item.second.emplace_back(i);
exit = true;
@ -1358,17 +1380,17 @@ void AscendStreamAssign::InsertEventHcomDependHcom(const NotNull<KernelGraphPtr>
}
}
if (!exit) {
stream_indexs.emplace_back(std::make_pair(cur_stream_id, std::vector<size_t>{i}));
stream_indices.emplace_back(std::make_pair(cur_stream_id, std::vector<size_t>{i}));
}
}
}
if (stream_indexs.size() < 2) {
if (stream_indices.size() < 2) {
MS_LOG(INFO) << "Group:" << group
<< "; different stream hcom size is less than 2, no need insert event between them";
continue;
}
InsertEventBetweenHcom(graph_ptr, stream_indexs);
InsertEventBetweenHcom(graph_ptr, stream_indices);
}
}
@ -1474,7 +1496,7 @@ void AscendStreamAssign::InsertEventForIndependentParallel(const NotNull<KernelG
auto target = FindTargetOp(it, cnodes.end(), *(it - 1), false);
if (target == cnodes.end()) {
MS_LOG(DEBUG) << "Independ node[" << (*(it - 1))->fullname_with_scope()
MS_LOG(DEBUG) << "Independent node[" << (*(it - 1))->fullname_with_scope()
<< "] can't find target for insert recv op, no insert send/recv";
it = cnodes.erase(it);
continue;
@ -1558,16 +1580,16 @@ uint32_t AscendStreamAssign::GetMaxIndexTarget(const NotNull<KernelGraphPtr> &gr
return UINT32_MAX;
}
std::set<uint32_t> indexs;
std::set<uint32_t> indices;
for (const auto &key : independent_targets_) {
auto index = GetIndexByKey(graph_ptr, key);
if (index == UINT32_MAX) {
MS_LOG(EXCEPTION) << "graph has no correspond key";
}
indexs.emplace(index);
indices.emplace(index);
}
return *(std::max_element(indexs.begin(), indexs.end()));
return *(std::max_element(indices.begin(), indices.end()));
}
uint32_t AscendStreamAssign::GetIndependentStreamSwitchStreamId(const NotNull<KernelGraphPtr> &graph_ptr) {
@ -1623,7 +1645,7 @@ void AscendStreamAssign::GetNeedActiveStreams(const NotNull<KernelGraphPtr> &gra
CNodePtr cur_cnode_ptr = nullptr;
auto cnode_ptr_list = graph_ptr->execution_order();
// 1)stream witch kStreamNeedActivedFirst attr should be actived;
// 1) streams with the kStreamNeedActivedFirst attr should be activated;
for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
cur_cnode_ptr = cnode_ptr_list[i];
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
@ -1634,7 +1656,7 @@ void AscendStreamAssign::GetNeedActiveStreams(const NotNull<KernelGraphPtr> &gra
auto need_active = AnfAlgo::GetNodeAttr<bool>(cur_cnode_ptr, kStreamNeedActivedFirst);
if (need_active) {
auto stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr);
MS_LOG(INFO) << "Stream id:" << stream_id << " is need actived at first";
MS_LOG(INFO) << "Stream id:" << stream_id << " is need activated at first";
need_first_active_streams_.push_back(stream_id);
}
}
@ -1659,7 +1681,7 @@ void AscendStreamAssign::GetNeedActiveStreams(const NotNull<KernelGraphPtr> &gra
}
}
// 4)first stream 0 should be actived first;
// 4)first stream 0 should be activated first;
auto it = std::find(need_first_active_streams_.begin(), need_first_active_streams_.end(), 0);
if (it == need_first_active_streams_.end()) {
need_first_active_streams_.emplace_back(0);
@ -2025,7 +2047,7 @@ void AscendStreamAssign::GetStreamActiveStreamRelation(const NotNull<KernelGraph
for (const auto &item : active_list) {
if (item <= active_current_node) {
MS_LOG(WARNING) << "Actived stream is less than activing stream";
MS_LOG(WARNING) << "Activated stream is less than activing stream";
continue;
}
auto it =
@ -2054,7 +2076,7 @@ void AscendStreamAssign::GetStreamActiveStreamRelation(const NotNull<KernelGraph
} else {
for (const auto &stream : active_list) {
if (stream <= cur_stream_id) {
MS_LOG(WARNING) << "Actived stream is less than activing stream";
MS_LOG(WARNING) << "Activated stream is less than activing stream";
continue;
}
auto iter = std::find(stream_relations_[cur_stream_id].begin(), stream_relations_[cur_stream_id].end(), stream);
@ -2131,25 +2153,7 @@ StreamActiveKind AscendStreamAssign::GetStreamActiveKind(const NotNull<KernelGra
break;
}
// pre_stream_id = UINT32_MAX:means no node active current StreamActive
// next_stream_id = UINT32_MAX:means current StreamActive active no node
if (pre_stream_id == UINT32_MAX || next_stream_id == UINT32_MAX) {
return kInvalid;
}
if (cur_stream_id == pre_stream_id && cur_stream_id == next_stream_id) {
return kMiddle;
}
if (cur_stream_id == pre_stream_id) {
return kTail;
}
if (cur_stream_id == next_stream_id) {
return kHead;
}
return kInvalid;
return GetStreamKind(cur_stream_id, pre_stream_id, next_stream_id);
}
uint32_t AscendStreamAssign::GetStreamByActivedStream(uint32_t actived_stream_id) {
@ -2172,7 +2176,7 @@ void AscendStreamAssign::PrintStreamRelations() {
for (const auto &item : stream_relations_) {
MS_LOG(INFO) << "Stream:" << item.first;
for (const auto &stream : item.second) {
MS_LOG(INFO) << "--actived stream id:" << stream;
MS_LOG(INFO) << "--activated stream id:" << stream;
}
}
}

Loading…
Cancel
Save