|
|
|
@ -140,7 +140,12 @@ Status DeviceManager::Init(const RankList &devices, int64_t global_device_rank,
|
|
|
|
|
const std::string &backend) {
|
|
|
|
|
if ((backend != HCCL_BACKEND) && (backend != NCCL_BACKEND) && (backend != UNDEFINED_BACKEND)) {
|
|
|
|
|
MS_LOG(ERROR) << "Invalid backend: " << backend;
|
|
|
|
|
return Status::FAILED;
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (stage_map.empty() || devices.empty()) {
|
|
|
|
|
MS_LOG(ERROR) << "The size of stage_map and devices must be positive";
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (auto &dev : devices) {
|
|
|
|
@ -153,11 +158,11 @@ Status DeviceManager::Init(const RankList &devices, int64_t global_device_rank,
|
|
|
|
|
int64_t num_device = stage;
|
|
|
|
|
if (num_device > MAX_DEVICE_NUM) {
|
|
|
|
|
MS_LOG(ERROR) << "The number of 'devices' in a stage must not be greater than " << MAX_DEVICE_NUM;
|
|
|
|
|
return Status::FAILED;
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
if (num_device <= 0) {
|
|
|
|
|
MS_LOG(ERROR) << "The number of 'devices' in a stage must be positive";
|
|
|
|
|
return Status::FAILED;
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
RankList curr_dev_list;
|
|
|
|
|
for (int64_t i = 0; i < num_device; ++i) {
|
|
|
|
@ -170,10 +175,11 @@ Status DeviceManager::Init(const RankList &devices, int64_t global_device_rank,
|
|
|
|
|
std::shared_ptr<Device> dev = std::make_shared<Device>(global_device_rank);
|
|
|
|
|
device_ = dev;
|
|
|
|
|
|
|
|
|
|
set_global_rank(global_device_rank);
|
|
|
|
|
set_stage_num(static_cast<const int64_t>(stage_map.size()));
|
|
|
|
|
int64_t stage_id = global_device_rank / static_cast<const int64_t>(devices.size() / stage_map.size());
|
|
|
|
|
set_stage_id(stage_id);
|
|
|
|
|
global_rank_ = global_device_rank;
|
|
|
|
|
stage_num_ = static_cast<const int64_t>(stage_map.size());
|
|
|
|
|
stage_id_ = global_device_rank / static_cast<const int64_t>(devices.size() / stage_map.size());
|
|
|
|
|
rank_index_in_stage_ = global_rank_ - stage_id_ * (static_cast<const int64_t>(devices.size()) / stage_num_);
|
|
|
|
|
stage_device_num_ = static_cast<const int64_t>(devices.size()) / stage_num_;
|
|
|
|
|
|
|
|
|
|
backend_ = backend;
|
|
|
|
|
|
|
|
|
@ -185,10 +191,13 @@ Status DeviceManager::Init(const RankList &devices, int64_t global_device_rank,
|
|
|
|
|
gm_.set_world_group(UNDEFINED_WORLD_GROUP);
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(INFO) << "The device num: " << devices.size() << ", rank id: " << global_device_rank
|
|
|
|
|
<< ", the backend: " << backend << ", the stage num: " << stage_num() << ", the stage id: " << stage_id;
|
|
|
|
|
return Status::SUCCESS;
|
|
|
|
|
<< ", the backend: " << backend << ", the stage num: " << stage_num_ << ", the stage id: " << stage_id_
|
|
|
|
|
<< ", the rank index in stage is: " << rank_index_in_stage_;
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
RankList DeviceManager::GetDeviceListInThisStage() const { return GetDeviceListByStageId(stage_id_); }
|
|
|
|
|
|
|
|
|
|
RankList DeviceManager::GetDeviceListByStageId(int64_t stage_id) const {
|
|
|
|
|
if (LongToSize(stage_id) >= stage_devices_.size())
|
|
|
|
|
MS_LOG(ERROR) << "the 'stage_id': " << stage_id
|
|
|
|
@ -204,49 +213,6 @@ RankList DeviceManager::GetDeviceListByStageId(int64_t stage_id) const {
|
|
|
|
|
return res;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
RankList DeviceManager::global_device_list(int64_t stage_id, int64_t rank, int64_t split_num) const {
|
|
|
|
|
RankList res;
|
|
|
|
|
if (split_num <= 0) {
|
|
|
|
|
return res;
|
|
|
|
|
}
|
|
|
|
|
if (LongToSize(stage_id) >= stage_devices_.size()) {
|
|
|
|
|
MS_LOG(ERROR) << "the 'stage_id': " << stage_id
|
|
|
|
|
<< ", is out of the scope of 'stage_devices_': " << stage_devices_.size();
|
|
|
|
|
return res;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
RankList global_list = GetDeviceListByStageId(stage_id);
|
|
|
|
|
if (global_list.size() % LongToSize(split_num)) {
|
|
|
|
|
MS_LOG(ERROR) << "dev list size(" << global_list.size() << ") can not be divisible by split num: " << stage_id;
|
|
|
|
|
return res;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<int64_t> dev_list;
|
|
|
|
|
(void)std::copy(global_list.begin(), global_list.end(), std::back_inserter(dev_list));
|
|
|
|
|
|
|
|
|
|
size_t index = 0;
|
|
|
|
|
size_t slice_size = dev_list.size() / LongToSize(split_num);
|
|
|
|
|
for (int64_t i = 0; i < split_num; ++i) {
|
|
|
|
|
bool found = false;
|
|
|
|
|
index = slice_size * LongToSize(i);
|
|
|
|
|
for (size_t j = 0; j < slice_size; ++j) {
|
|
|
|
|
if (dev_list[index + j] == rank) {
|
|
|
|
|
found = true;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (found) {
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (size_t k = 0; k < slice_size; ++k) {
|
|
|
|
|
res.push_back(dev_list[index + k]);
|
|
|
|
|
}
|
|
|
|
|
return res;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Device DeviceManager::CreateNewDeviceByRank(int64_t rank) const { return Device(rank); }
|
|
|
|
|
|
|
|
|
|
std::vector<Device> DeviceManager::CreateDeviceListByRankList(RankList ranks) {
|
|
|
|
|