|
|
|
@ -62,9 +62,31 @@ bool HasRefNodes(const vector<CNodePtr> &moved_backward_cnodes) {
|
|
|
|
|
}
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Classifies a StreamActive node by comparing its own stream id with the
// stream id that precedes it and the stream id that follows it.
//
// cur_stream_id:  stream id of the current StreamActive.
// pre_stream_id:  UINT32_MAX means no node active current StreamActive.
// next_stream_id: UINT32_MAX means current StreamActive active no node.
//
// Returns kInvalid when either neighbour id is UINT32_MAX or when the current
// id matches neither neighbour; kMiddle when it matches both; kTail when it
// matches only pre_stream_id; kHead when it matches only next_stream_id.
StreamActiveKind GetStreamKind(uint32_t cur_stream_id, uint32_t pre_stream_id, uint32_t next_stream_id) {
  const bool has_pre = (pre_stream_id != UINT32_MAX);
  const bool has_next = (next_stream_id != UINT32_MAX);
  if (!(has_pre && has_next)) {
    return kInvalid;
  }

  const bool same_as_pre = (cur_stream_id == pre_stream_id);
  const bool same_as_next = (cur_stream_id == next_stream_id);
  if (same_as_pre) {
    // Matching both neighbours is the middle case; matching only the previous one is the tail.
    return same_as_next ? kMiddle : kTail;
  }
  // Matching only the next neighbour is the head; matching neither is invalid.
  return same_as_next ? kHead : kInvalid;
}
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
const uint32_t kHcomMaxTask = 5;
|
|
|
|
|
const uint32_t kHcomMaxTask = 4;
|
|
|
|
|
const uint32_t kCommonMaxTask = 350;
|
|
|
|
|
|
|
|
|
|
void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr) {
|
|
|
|
@ -899,7 +921,7 @@ void AscendStreamAssign::InsertStreamActiveForIndependent(const NotNull<KernelGr
|
|
|
|
|
std::vector<CNodePtr> update_cnode_list;
|
|
|
|
|
auto exe_orders = graph_ptr->execution_order();
|
|
|
|
|
|
|
|
|
|
// first independent is been actived, active other independent stream
|
|
|
|
|
// first independent stream has been activated; activate the other independent streams
|
|
|
|
|
std::vector<uint32_t> streams;
|
|
|
|
|
std::copy(independent_streams.begin(), independent_streams.end(), std::back_inserter(streams));
|
|
|
|
|
std::sort(streams.begin(), streams.end());
|
|
|
|
@ -999,10 +1021,10 @@ void AscendStreamAssign::UpdateStreamSwitch(const NotNull<KernelGraphPtr> &graph
|
|
|
|
|
auto kind = AnfAlgo::GetNodeAttr<uint32_t>(switch_ptr, kAttrStreamSwitchKind);
|
|
|
|
|
if (kind == kIndependentStreamSwitch) {
|
|
|
|
|
bool independent_empty = independent_stream_map_.empty();
|
|
|
|
|
// if indepdent empty: delete independent streamswitch
|
|
|
|
|
// if independent empty: delete independent streamswitch
|
|
|
|
|
if (!independent_empty) {
|
|
|
|
|
for (const auto &item : independent_stream_map_) {
|
|
|
|
|
// first independetn stream id is minimum and order by std map;
|
|
|
|
|
// first independent stream id is minimum and order by std map;
|
|
|
|
|
auto first_independent_stream = item.first;
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(first_independent_stream), switch_ptr);
|
|
|
|
|
orders->emplace_back(switch_ptr);
|
|
|
|
@ -1028,7 +1050,7 @@ void AscendStreamAssign::UpdateStreamSwitch(const NotNull<KernelGraphPtr> &graph
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
auto true_stream_id = AnfAlgo::GetNodeAttr<uint32_t>(switch_ptr, kAttrTrueBranchStream);
|
|
|
|
|
MS_LOG(INFO) << "Swtich stream id:" << AnfAlgo::GetStreamId(switch_ptr) << "; active stream id:" << true_stream_id;
|
|
|
|
|
MS_LOG(INFO) << "Switch stream id:" << AnfAlgo::GetStreamId(switch_ptr) << "; active stream id:" << true_stream_id;
|
|
|
|
|
CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr);
|
|
|
|
|
AnfAlgo::SetStreamId(true_stream_id, active_ptr.get());
|
|
|
|
|
vector<uint32_t> active_ids;
|
|
|
|
@ -1328,7 +1350,7 @@ void AscendStreamAssign::InsertEventHcomDependHcom(const NotNull<KernelGraphPtr>
|
|
|
|
|
}
|
|
|
|
|
for (const auto &group : groups) {
|
|
|
|
|
auto cnode_ptr_list = graph_ptr->execution_order();
|
|
|
|
|
std::vector<std::pair<uint32_t, vector<size_t>>> stream_indexs;
|
|
|
|
|
std::vector<std::pair<uint32_t, vector<size_t>>> stream_indices;
|
|
|
|
|
for (size_t i = 0; i < cnode_ptr_list.size(); i++) {
|
|
|
|
|
auto cur_cnode = cnode_ptr_list[i];
|
|
|
|
|
if (!IsHcom(cur_cnode)) {
|
|
|
|
@ -1346,11 +1368,11 @@ void AscendStreamAssign::InsertEventHcomDependHcom(const NotNull<KernelGraphPtr>
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (stream_indexs.empty()) {
|
|
|
|
|
stream_indexs.emplace_back(std::make_pair(cur_stream_id, std::vector<size_t>{i}));
|
|
|
|
|
if (stream_indices.empty()) {
|
|
|
|
|
stream_indices.emplace_back(std::make_pair(cur_stream_id, std::vector<size_t>{i}));
|
|
|
|
|
} else {
|
|
|
|
|
bool exit = false;
|
|
|
|
|
for (auto &item : stream_indexs) {
|
|
|
|
|
for (auto &item : stream_indices) {
|
|
|
|
|
if (item.first == cur_stream_id) {
|
|
|
|
|
item.second.emplace_back(i);
|
|
|
|
|
exit = true;
|
|
|
|
@ -1358,17 +1380,17 @@ void AscendStreamAssign::InsertEventHcomDependHcom(const NotNull<KernelGraphPtr>
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (!exit) {
|
|
|
|
|
stream_indexs.emplace_back(std::make_pair(cur_stream_id, std::vector<size_t>{i}));
|
|
|
|
|
stream_indices.emplace_back(std::make_pair(cur_stream_id, std::vector<size_t>{i}));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (stream_indexs.size() < 2) {
|
|
|
|
|
if (stream_indices.size() < 2) {
|
|
|
|
|
MS_LOG(INFO) << "Group:" << group
|
|
|
|
|
<< "; different stream hcom size is less than 2, no need insert event between them";
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
InsertEventBetweenHcom(graph_ptr, stream_indexs);
|
|
|
|
|
InsertEventBetweenHcom(graph_ptr, stream_indices);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -1474,7 +1496,7 @@ void AscendStreamAssign::InsertEventForIndependentParallel(const NotNull<KernelG
|
|
|
|
|
|
|
|
|
|
auto target = FindTargetOp(it, cnodes.end(), *(it - 1), false);
|
|
|
|
|
if (target == cnodes.end()) {
|
|
|
|
|
MS_LOG(DEBUG) << "Independ node[" << (*(it - 1))->fullname_with_scope()
|
|
|
|
|
MS_LOG(DEBUG) << "Independent node[" << (*(it - 1))->fullname_with_scope()
|
|
|
|
|
<< "] can't find target for insert recv op, no insert send/recv";
|
|
|
|
|
it = cnodes.erase(it);
|
|
|
|
|
continue;
|
|
|
|
@ -1558,16 +1580,16 @@ uint32_t AscendStreamAssign::GetMaxIndexTarget(const NotNull<KernelGraphPtr> &gr
|
|
|
|
|
return UINT32_MAX;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::set<uint32_t> indexs;
|
|
|
|
|
std::set<uint32_t> indices;
|
|
|
|
|
for (const auto &key : independent_targets_) {
|
|
|
|
|
auto index = GetIndexByKey(graph_ptr, key);
|
|
|
|
|
if (index == UINT32_MAX) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "graph has no correspond key";
|
|
|
|
|
}
|
|
|
|
|
indexs.emplace(index);
|
|
|
|
|
indices.emplace(index);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return *(std::max_element(indexs.begin(), indexs.end()));
|
|
|
|
|
return *(std::max_element(indices.begin(), indices.end()));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
uint32_t AscendStreamAssign::GetIndependentStreamSwitchStreamId(const NotNull<KernelGraphPtr> &graph_ptr) {
|
|
|
|
@ -1623,7 +1645,7 @@ void AscendStreamAssign::GetNeedActiveStreams(const NotNull<KernelGraphPtr> &gra
|
|
|
|
|
CNodePtr cur_cnode_ptr = nullptr;
|
|
|
|
|
auto cnode_ptr_list = graph_ptr->execution_order();
|
|
|
|
|
|
|
|
|
|
// 1)stream witch kStreamNeedActivedFirst attr should be actived;
|
|
|
|
|
// 1)stream with kStreamNeedActivedFirst attr should be activated;
|
|
|
|
|
for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
|
|
|
|
|
cur_cnode_ptr = cnode_ptr_list[i];
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
|
|
|
|
@ -1634,7 +1656,7 @@ void AscendStreamAssign::GetNeedActiveStreams(const NotNull<KernelGraphPtr> &gra
|
|
|
|
|
auto need_active = AnfAlgo::GetNodeAttr<bool>(cur_cnode_ptr, kStreamNeedActivedFirst);
|
|
|
|
|
if (need_active) {
|
|
|
|
|
auto stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr);
|
|
|
|
|
MS_LOG(INFO) << "Stream id:" << stream_id << " is need actived at first";
|
|
|
|
|
MS_LOG(INFO) << "Stream id:" << stream_id << " is need activated at first";
|
|
|
|
|
need_first_active_streams_.push_back(stream_id);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -1659,7 +1681,7 @@ void AscendStreamAssign::GetNeedActiveStreams(const NotNull<KernelGraphPtr> &gra
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 4)first stream 0 should be actived first;
|
|
|
|
|
// 4)first stream 0 should be activated first;
|
|
|
|
|
auto it = std::find(need_first_active_streams_.begin(), need_first_active_streams_.end(), 0);
|
|
|
|
|
if (it == need_first_active_streams_.end()) {
|
|
|
|
|
need_first_active_streams_.emplace_back(0);
|
|
|
|
@ -2025,7 +2047,7 @@ void AscendStreamAssign::GetStreamActiveStreamRelation(const NotNull<KernelGraph
|
|
|
|
|
|
|
|
|
|
for (const auto &item : active_list) {
|
|
|
|
|
if (item <= active_current_node) {
|
|
|
|
|
MS_LOG(WARNING) << "Actived stream is less than activing stream";
|
|
|
|
|
MS_LOG(WARNING) << "Activated stream is less than activing stream";
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto it =
|
|
|
|
@ -2054,7 +2076,7 @@ void AscendStreamAssign::GetStreamActiveStreamRelation(const NotNull<KernelGraph
|
|
|
|
|
} else {
|
|
|
|
|
for (const auto &stream : active_list) {
|
|
|
|
|
if (stream <= cur_stream_id) {
|
|
|
|
|
MS_LOG(WARNING) << "Actived stream is less than activing stream";
|
|
|
|
|
MS_LOG(WARNING) << "Activated stream is less than activing stream";
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto iter = std::find(stream_relations_[cur_stream_id].begin(), stream_relations_[cur_stream_id].end(), stream);
|
|
|
|
@ -2131,25 +2153,7 @@ StreamActiveKind AscendStreamAssign::GetStreamActiveKind(const NotNull<KernelGra
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// pre_stream_id = UINT32_MAX:means no node active current StreamActive
|
|
|
|
|
// next_stream_id = UINT32_MAX:means current StreamActive active no node
|
|
|
|
|
if (pre_stream_id == UINT32_MAX || next_stream_id == UINT32_MAX) {
|
|
|
|
|
return kInvalid;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (cur_stream_id == pre_stream_id && cur_stream_id == next_stream_id) {
|
|
|
|
|
return kMiddle;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (cur_stream_id == pre_stream_id) {
|
|
|
|
|
return kTail;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (cur_stream_id == next_stream_id) {
|
|
|
|
|
return kHead;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return kInvalid;
|
|
|
|
|
return GetStreamKind(cur_stream_id, pre_stream_id, next_stream_id);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
uint32_t AscendStreamAssign::GetStreamByActivedStream(uint32_t actived_stream_id) {
|
|
|
|
@ -2172,7 +2176,7 @@ void AscendStreamAssign::PrintStreamRelations() {
|
|
|
|
|
for (const auto &item : stream_relations_) {
|
|
|
|
|
MS_LOG(INFO) << "Stream:" << item.first;
|
|
|
|
|
for (const auto &stream : item.second) {
|
|
|
|
|
MS_LOG(INFO) << "--actived stream id:" << stream;
|
|
|
|
|
MS_LOG(INFO) << "--activated stream id:" << stream;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|