|
|
|
@ -22,6 +22,8 @@
|
|
|
|
|
#include "ir/manager.h"
|
|
|
|
|
#include "utils/ms_context.h"
|
|
|
|
|
#include "utils/ms_utils.h"
|
|
|
|
|
#include "frontend/parallel/context.h"
|
|
|
|
|
#include "frontend/parallel/device_manager.h"
|
|
|
|
|
#include "backend/session/anf_runtime_algorithm.h"
|
|
|
|
|
#include "runtime/device/kernel_adjust.h"
|
|
|
|
|
#include "backend/optimizer/common/helper.h"
|
|
|
|
@ -36,6 +38,79 @@ namespace mindspore {
|
|
|
|
|
namespace device {
|
|
|
|
|
namespace ascend {
|
|
|
|
|
namespace {
|
|
|
|
|
constexpr uint32_t kDeviceNumOfServer = 8;
|
|
|
|
|
constexpr uint32_t kDeviceNumThreshold = 1024;
|
|
|
|
|
|
|
|
|
|
constexpr uint32_t kMaxStreamNum = 1024;
|
|
|
|
|
constexpr uint32_t kHcomSecondaryStreamNum = 3;
|
|
|
|
|
|
|
|
|
|
constexpr uint32_t kMaxTaskNumPerStream = 1010;
|
|
|
|
|
constexpr uint32_t kMaxCommonNodeNumPerStream = 350;
|
|
|
|
|
|
|
|
|
|
constexpr uint32_t kTaskNumPerHcomNode = 200;
|
|
|
|
|
constexpr uint32_t kTaskNumPerWorldHcomNode = 250;
|
|
|
|
|
constexpr uint32_t kTaskNumPerSameServerHcomNode = 125;
|
|
|
|
|
constexpr uint32_t kTaskNumPerHcomSendRecvNode = 15;
|
|
|
|
|
|
|
|
|
|
bool IsSameServer(const std::vector<uint32_t> &rank_ids) {
|
|
|
|
|
auto min_iter = min_element(rank_ids.begin(), rank_ids.end());
|
|
|
|
|
uint32_t min = (min_iter != rank_ids.end()) ? *min_iter : 0;
|
|
|
|
|
auto max_iter = max_element(rank_ids.begin(), rank_ids.end());
|
|
|
|
|
uint32_t max = (max_iter != rank_ids.end()) ? *max_iter : 0;
|
|
|
|
|
return ((max - min < kDeviceNumOfServer) && (min / kDeviceNumOfServer == max / kDeviceNumOfServer));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
string GetHcomGroup(const CNodePtr &cnode) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
if (!AnfAlgo::HasNodeAttr(kAttrGroup, cnode)) {
|
|
|
|
|
MS_LOG_EXCEPTION << "Hcom node " << cnode->fullname_with_scope() << " has no group attribute.";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return AnfAlgo::GetNodeAttr<std::string>(cnode, kAttrGroup);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
uint32_t GetHcomTaskNum(const CNodePtr &cnode) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
if (!AnfAlgo::HasNodeAttr(kAttrGroup, cnode)) {
|
|
|
|
|
MS_LOG_EXCEPTION << "Hcom node " << cnode->fullname_with_scope() << " has no group attribute.";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (parallel::g_device_manager == nullptr) {
|
|
|
|
|
MS_LOG(INFO) << "Device manager is nullptr.";
|
|
|
|
|
return kTaskNumPerHcomNode;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto node_name = AnfAlgo::GetCNodeName(cnode);
|
|
|
|
|
if (node_name == kHcomSendOpName || node_name == kReceiveOpName) {
|
|
|
|
|
return kTaskNumPerHcomSendRecvNode;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(parallel::ParallelContext::GetInstance());
|
|
|
|
|
auto device_num = parallel::ParallelContext::GetInstance()->device_num();
|
|
|
|
|
auto group_name = AnfAlgo::GetNodeAttr<std::string>(cnode, kAttrGroup);
|
|
|
|
|
auto group_info = parallel::g_device_manager->group_info();
|
|
|
|
|
for (const auto &info : group_info) {
|
|
|
|
|
if (info.first != group_name) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
const auto &rank_ids = info.second;
|
|
|
|
|
if (IsSameServer(rank_ids)) {
|
|
|
|
|
return kTaskNumPerSameServerHcomNode;
|
|
|
|
|
} else if (rank_ids.size() == static_cast<size_t>(device_num) && device_num >= kDeviceNumThreshold) {
|
|
|
|
|
return kTaskNumPerWorldHcomNode;
|
|
|
|
|
} else {
|
|
|
|
|
return kTaskNumPerHcomNode;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// world group is not in group_info.
|
|
|
|
|
if (device_num >= kDeviceNumThreshold) {
|
|
|
|
|
return kTaskNumPerWorldHcomNode;
|
|
|
|
|
} else {
|
|
|
|
|
return kTaskNumPerHcomNode;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
CNodePtr GetHcomAndOverflowMarker(const NotNull<KernelGraphPtr> &graph_ptr, vector<CNodePtr> *hcom_nodes) {
|
|
|
|
|
auto cnode_ptr_list = graph_ptr->execution_order();
|
|
|
|
|
CNodePtr overflow_marker = nullptr;
|
|
|
|
@ -90,9 +165,6 @@ StreamActiveKind GetStreamKind(uint32_t cur_stream_id, uint32_t pre_stream_id, u
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
const uint32_t kHcomMaxTask = 4;
|
|
|
|
|
const uint32_t kCommonMaxTask = 350;
|
|
|
|
|
|
|
|
|
|
void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr) {
|
|
|
|
|
if (IsTaskSink() && !graph_ptr->is_dynamic_shape()) {
|
|
|
|
|
Reset();
|
|
|
|
@ -110,6 +182,10 @@ void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr)
|
|
|
|
|
AdjustAtomicAddrCleanOrder(graph_ptr);
|
|
|
|
|
|
|
|
|
|
GetNeedActiveStreams(graph_ptr);
|
|
|
|
|
|
|
|
|
|
MS_LOG(INFO) << "Before check resource assign";
|
|
|
|
|
graph_ptr->PrintGraphExecuteOrder();
|
|
|
|
|
|
|
|
|
|
CheckResourceAssign(graph_ptr);
|
|
|
|
|
MS_LOG(INFO) << "After finish stream assign";
|
|
|
|
|
#ifdef ENABLE_DUMP_IR
|
|
|
|
@ -478,15 +554,26 @@ void AscendStreamAssign::AssignAllNodesStream(const NotNull<KernelGraphPtr> &gra
|
|
|
|
|
|
|
|
|
|
AssignCommonStreamId(cur_cnode_ptr);
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(INFO) << "Common start from 0, common stream nums:" << resource_manager.get_cur_stream_num();
|
|
|
|
|
|
|
|
|
|
auto common_stream_num = resource_manager.get_cur_stream_num();
|
|
|
|
|
|
|
|
|
|
if (exit_hcom) {
|
|
|
|
|
AssignHcom(graph_ptr);
|
|
|
|
|
}
|
|
|
|
|
auto hcom_stream_num = resource_manager.get_cur_stream_num() - common_stream_num;
|
|
|
|
|
|
|
|
|
|
if (exit_independent) {
|
|
|
|
|
AssignIndependent(graph_ptr);
|
|
|
|
|
}
|
|
|
|
|
auto independent_stream_num = resource_manager.get_cur_stream_num() - common_stream_num - hcom_stream_num;
|
|
|
|
|
auto total_stream_num = resource_manager.get_cur_stream_num() + hcom_stream_num * kHcomSecondaryStreamNum;
|
|
|
|
|
MS_LOG(INFO) << "Total stream number: " << total_stream_num << ", common stream number: " << common_stream_num
|
|
|
|
|
<< ", hcom stream number: " << hcom_stream_num << "*" << kHcomSecondaryStreamNum + 1
|
|
|
|
|
<< ", independent stream number: " << independent_stream_num << ".";
|
|
|
|
|
|
|
|
|
|
if (total_stream_num > kMaxStreamNum) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Total stream number " << total_stream_num << " exceeds the limit of " << kMaxStreamNum << ".";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
MS_LOG(INFO) << "After stream assign, total stream nums:" << resource_manager.get_cur_stream_num();
|
|
|
|
|
}
|
|
|
|
@ -507,7 +594,7 @@ void AscendStreamAssign::AssignCommonStreamId(const CNodePtr &cur_cnode_ptr) {
|
|
|
|
|
AnfAlgo::SetStreamId(cur_common_stream_id, cur_cnode_ptr.get());
|
|
|
|
|
common_stream_map_.insert(std::make_pair(cur_common_stream_id, 1));
|
|
|
|
|
} else {
|
|
|
|
|
if (it->second < kCommonMaxTask) {
|
|
|
|
|
if (it->second < kMaxCommonNodeNumPerStream) {
|
|
|
|
|
AnfAlgo::SetStreamId(it->first, cur_cnode_ptr.get());
|
|
|
|
|
it->second++;
|
|
|
|
|
} else {
|
|
|
|
@ -529,10 +616,7 @@ void AscendStreamAssign::AssignHcom(const NotNull<KernelGraphPtr> &graph_ptr) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (IsHcom(cur_cnode_ptr)) {
|
|
|
|
|
if (!AnfAlgo::HasNodeAttr(kAttrGroup, cur_cnode_ptr)) {
|
|
|
|
|
MS_LOG_EXCEPTION << "hcom cnode " << cur_cnode_ptr->DebugString() << " has no group attr";
|
|
|
|
|
}
|
|
|
|
|
auto group_name = AnfAlgo::GetNodeAttr<std::string>(cur_cnode_ptr, kAttrGroup);
|
|
|
|
|
auto group_name = GetHcomGroup(cur_cnode_ptr);
|
|
|
|
|
auto hcom_graph_id = AnfAlgo::GetGraphId(cur_cnode_ptr.get());
|
|
|
|
|
auto iter = group_graph_nodes_map.find(group_name);
|
|
|
|
|
if (iter == group_graph_nodes_map.end()) {
|
|
|
|
@ -576,6 +660,8 @@ void AscendStreamAssign::AssignHcom(const NotNull<KernelGraphPtr> &graph_ptr) {
|
|
|
|
|
uint32_t AscendStreamAssign::AssignHcomStreamId(const CNodePtr &cur_cnode_ptr, bool new_graph) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
|
|
|
|
|
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
|
|
|
|
|
auto task_num = GetHcomTaskNum(cur_cnode_ptr);
|
|
|
|
|
|
|
|
|
|
uint32_t cur_hcom_stream_id;
|
|
|
|
|
if (new_graph) {
|
|
|
|
|
cur_hcom_stream_id = resource_manager.ApplyNewStream();
|
|
|
|
@ -585,15 +671,15 @@ uint32_t AscendStreamAssign::AssignHcomStreamId(const CNodePtr &cur_cnode_ptr, b
|
|
|
|
|
auto it = hcom_stream_map_.find(cur_hcom_stream_id);
|
|
|
|
|
if (it == hcom_stream_map_.end()) {
|
|
|
|
|
AnfAlgo::SetStreamId(cur_hcom_stream_id, cur_cnode_ptr.get());
|
|
|
|
|
hcom_stream_map_.insert(std::make_pair(cur_hcom_stream_id, 1));
|
|
|
|
|
hcom_stream_map_.emplace(cur_hcom_stream_id, task_num);
|
|
|
|
|
} else {
|
|
|
|
|
if (it->second < kHcomMaxTask) {
|
|
|
|
|
if (it->second <= kMaxTaskNumPerStream - task_num) {
|
|
|
|
|
AnfAlgo::SetStreamId(it->first, cur_cnode_ptr.get());
|
|
|
|
|
it->second++;
|
|
|
|
|
it->second += task_num;
|
|
|
|
|
} else {
|
|
|
|
|
cur_hcom_stream_id = resource_manager.ApplyNewStream();
|
|
|
|
|
AnfAlgo::SetStreamId(cur_hcom_stream_id, cur_cnode_ptr.get());
|
|
|
|
|
hcom_stream_map_.insert(std::make_pair(cur_hcom_stream_id, 1));
|
|
|
|
|
hcom_stream_map_.emplace(cur_hcom_stream_id, task_num);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return cur_hcom_stream_id;
|
|
|
|
@ -646,7 +732,7 @@ uint32_t AscendStreamAssign::AssignIndependentStreamId(const CNodePtr &cur_cnode
|
|
|
|
|
AnfAlgo::SetStreamId(cur_independent_stream_id, cur_cnode_ptr.get());
|
|
|
|
|
independent_stream_map_.insert(std::make_pair(cur_independent_stream_id, 1));
|
|
|
|
|
} else {
|
|
|
|
|
if (it->second < kCommonMaxTask) {
|
|
|
|
|
if (it->second < kMaxCommonNodeNumPerStream) {
|
|
|
|
|
AnfAlgo::SetStreamId(it->first, cur_cnode_ptr.get());
|
|
|
|
|
it->second++;
|
|
|
|
|
} else {
|
|
|
|
@ -956,7 +1042,7 @@ void AscendStreamAssign::InsertStreamActiveForIndependent(const NotNull<KernelGr
|
|
|
|
|
std::copy(exe_orders.begin() + i + 1, exe_orders.end(), std::back_inserter(update_cnode_list));
|
|
|
|
|
break;
|
|
|
|
|
} else {
|
|
|
|
|
if (node_num == kCommonMaxTask) {
|
|
|
|
|
if (node_num == kMaxCommonNodeNumPerStream) {
|
|
|
|
|
CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr);
|
|
|
|
|
// 1.set stream id
|
|
|
|
|
AnfAlgo::SetStreamId(cur_stream_id, active_ptr.get());
|
|
|
|
@ -1226,7 +1312,7 @@ void AscendStreamAssign::InsertEventHcomDependCommonBak(const NotNull<KernelGrap
|
|
|
|
|
vector<CNodePtr> AscendStreamAssign::GetLastInputCnode(const NotNull<KernelGraphPtr> &graph_ptr,
|
|
|
|
|
const CNodePtr &cur_cnode_ptr) {
|
|
|
|
|
auto cnode_ptr_list = graph_ptr->execution_order();
|
|
|
|
|
auto group_name = AnfAlgo::GetNodeAttr<std::string>(cur_cnode_ptr, kAttrGroup);
|
|
|
|
|
auto group_name = GetHcomGroup(cur_cnode_ptr);
|
|
|
|
|
auto input_cnodes = GetInputKernels(cur_cnode_ptr);
|
|
|
|
|
if (input_cnodes.empty()) {
|
|
|
|
|
return {};
|
|
|
|
@ -1256,7 +1342,7 @@ vector<CNodePtr> AscendStreamAssign::GetLastInputCnode(const NotNull<KernelGraph
|
|
|
|
|
CNodePtr max_common_cnode = nullptr;
|
|
|
|
|
for (const auto &item : result) {
|
|
|
|
|
if (IsHcom(item.second.first)) {
|
|
|
|
|
auto cur_group = AnfAlgo::GetNodeAttr<std::string>(item.second.first, kAttrGroup);
|
|
|
|
|
auto cur_group = GetHcomGroup(item.second.first);
|
|
|
|
|
if (cur_group == group_name) {
|
|
|
|
|
continue;
|
|
|
|
|
} else {
|
|
|
|
@ -1368,10 +1454,7 @@ void AscendStreamAssign::InsertEventHcomDependHcom(const NotNull<KernelGraphPtr>
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode);
|
|
|
|
|
if (!AnfAlgo::HasNodeAttr(kAttrGroup, cur_cnode)) {
|
|
|
|
|
MS_LOG_EXCEPTION << "hcom cnode " << cur_cnode->DebugString() << " has no group attr";
|
|
|
|
|
}
|
|
|
|
|
auto group_name = AnfAlgo::GetNodeAttr<std::string>(cur_cnode, kAttrGroup);
|
|
|
|
|
auto group_name = GetHcomGroup(cur_cnode);
|
|
|
|
|
MS_LOG(INFO) << "Hcom node name:" << AnfAlgo::GetCNodeName(cur_cnode) << "; group:" << group_name
|
|
|
|
|
<< "; stream id:" << cur_stream_id;
|
|
|
|
|
if (group_name != group) {
|
|
|
|
|