|
|
|
@ -27,15 +27,6 @@
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace parallel {
|
|
|
|
|
// Global device manager handle for the parallel namespace; it is null until assigned elsewhere.
DeviceManagerPtr g_device_manager = nullptr;
|
|
|
|
|
|
|
|
|
|
// Constructs a stage: stores 'devices', 'num' and 'rank' in the corresponding members
// (devices_, number_, rank_) and then resets the group manager.
Stage::Stage(const std::vector<mindspore::parallel::Device> &devices, int64_t num, int64_t rank)
    : devices_(devices), number_(num), rank_(rank) {
  // Overwrite gm_ with a freshly default-constructed GroupManager.
  gm_ = GroupManager();
}
|
|
|
|
|
|
|
|
|
|
// NOTE: '-1' indicates ERROR
|
|
|
|
|
int64_t Stage::global_rank(Group *g) const { return ((g == nullptr) ? rank_ : -1); }
|
|
|
|
|
|
|
|
|
|
bool InitDevice(int64_t device_num, int64_t global_rank, const std::string &backend,
|
|
|
|
|
const std::vector<int64_t> &stage) {
|
|
|
|
|
if (device_num <= 0) {
|
|
|
|
@ -143,36 +134,23 @@ std::shared_ptr<Device> GetListMemberByIndex(size_t index, const std::vector<std
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// E.g. devices = [4, 5, 2, 1, 7, 8, 10], stage_map = [4, 3],
|
|
|
|
|
// therefore the stage_devices_ = [[4, 5, 2, 1], [7, 8, 10]].
|
|
|
|
|
// E.g. devices = [0, 1, 2, 3, 4, 5, 6, 7], stage_map = [4, 4],
|
|
|
|
|
// therefore the stage_devices_ = [[0, 1, 2, 3], [4, 5, 6, 7]].
|
|
|
|
|
Status DeviceManager::Init(const RankList &devices, int64_t global_device_rank, const RankList &stage_map,
|
|
|
|
|
const std::string &backend) {
|
|
|
|
|
auto dev_it = devices.begin();
|
|
|
|
|
auto stage_it = stage_map.begin();
|
|
|
|
|
int64_t sum = 0;
|
|
|
|
|
|
|
|
|
|
if ((backend != HCCL_BACKEND) && (backend != NCCL_BACKEND) && (backend != UNDEFINED_BACKEND)) {
|
|
|
|
|
MS_LOG(ERROR) << "Invalid backend: " << backend;
|
|
|
|
|
return Status::FAILED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (; stage_it != stage_map.end(); ++stage_it) {
|
|
|
|
|
sum += (*stage_it);
|
|
|
|
|
}
|
|
|
|
|
if (LongToSize(sum) != devices.size()) {
|
|
|
|
|
MS_LOG(ERROR) << "The number of 'devices' in the list is not equal to the mentioned "
|
|
|
|
|
<< "size of 'stage_map'";
|
|
|
|
|
return Status::FAILED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (; dev_it != devices.end(); ++dev_it) {
|
|
|
|
|
std::shared_ptr<Device> one = std::make_shared<Device>(*dev_it);
|
|
|
|
|
for (auto &dev : devices) {
|
|
|
|
|
std::shared_ptr<Device> one = std::make_shared<Device>(dev);
|
|
|
|
|
devices_.push_back(one);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t global_index = 0;
|
|
|
|
|
for (stage_it = stage_map.begin(); stage_it != stage_map.end(); ++stage_it) {
|
|
|
|
|
int64_t num_device = *stage_it;
|
|
|
|
|
for (auto &stage : stage_map) {
|
|
|
|
|
int64_t num_device = stage;
|
|
|
|
|
if (num_device > MAX_DEVICE_NUM) {
|
|
|
|
|
MS_LOG(ERROR) << "The number of 'devices' in a stage must not be greater than " << MAX_DEVICE_NUM;
|
|
|
|
|
return Status::FAILED;
|
|
|
|
@ -189,29 +167,14 @@ Status DeviceManager::Init(const RankList &devices, int64_t global_device_rank,
|
|
|
|
|
stage_devices_.push_back(curr_dev_list);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
global_index = 0;
|
|
|
|
|
for (stage_it = stage_map.begin(); stage_it != stage_map.end(); ++stage_it) {
|
|
|
|
|
int64_t num_device = *stage_it;
|
|
|
|
|
if (num_device > MAX_DEVICE_NUM) {
|
|
|
|
|
MS_LOG(ERROR) << "The number of 'devices' in a stage must be less than " << MAX_DEVICE_NUM;
|
|
|
|
|
return Status::FAILED;
|
|
|
|
|
}
|
|
|
|
|
if (num_device <= 0) {
|
|
|
|
|
MS_LOG(ERROR) << "The number of 'devices' in a stage must be positive";
|
|
|
|
|
return Status::FAILED;
|
|
|
|
|
}
|
|
|
|
|
std::vector<Device> curr_dev_list;
|
|
|
|
|
for (int64_t i = 0; i < num_device; ++i) {
|
|
|
|
|
curr_dev_list.push_back(*GetListMemberByIndex(global_index, devices_));
|
|
|
|
|
global_index++;
|
|
|
|
|
}
|
|
|
|
|
std::shared_ptr<Stage> new_stage = std::make_shared<Stage>(curr_dev_list);
|
|
|
|
|
stages_.push_back(new_stage);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<Device> dev = std::make_shared<Device>(global_device_rank);
|
|
|
|
|
device_ = dev;
|
|
|
|
|
|
|
|
|
|
set_global_rank(global_device_rank);
|
|
|
|
|
set_stage_num(static_cast<const int64_t>(stage_map.size()));
|
|
|
|
|
int64_t stage_id = global_device_rank / static_cast<const int64_t>(devices.size() / stage_map.size());
|
|
|
|
|
set_stage_id(stage_id);
|
|
|
|
|
|
|
|
|
|
backend_ = backend;
|
|
|
|
|
|
|
|
|
|
if (backend == HCCL_BACKEND) {
|
|
|
|
@ -221,25 +184,11 @@ Status DeviceManager::Init(const RankList &devices, int64_t global_device_rank,
|
|
|
|
|
} else {
|
|
|
|
|
gm_.set_world_group(UNDEFINED_WORLD_GROUP);
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(INFO) << "The device num: " << devices.size() << "rank id: " << global_device_rank
|
|
|
|
|
<< "the backend: " << backend;
|
|
|
|
|
MS_LOG(INFO) << "The device num: " << devices.size() << ", rank id: " << global_device_rank
|
|
|
|
|
<< ", the backend: " << backend << ", the stage num: " << stage_num() << ", the stage id: " << stage_id;
|
|
|
|
|
return Status::SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<Stage> DeviceManager::GetStageById(int64_t stage_id) {
|
|
|
|
|
std::shared_ptr<Stage> res;
|
|
|
|
|
if (LongToSize(stage_id) >= stages_.size()) {
|
|
|
|
|
MS_LOG(ERROR) << "the 'stage_id': " << stage_id << ", is out of the scope of 'stage_devices_': " << stages_.size();
|
|
|
|
|
return res;
|
|
|
|
|
}
|
|
|
|
|
int64_t index = 0;
|
|
|
|
|
for (auto &stage : stages_) {
|
|
|
|
|
if (index == stage_id) return stage;
|
|
|
|
|
index++;
|
|
|
|
|
}
|
|
|
|
|
return res;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
RankList DeviceManager::GetDeviceListByStageId(int64_t stage_id) const {
|
|
|
|
|
if (LongToSize(stage_id) >= stage_devices_.size())
|
|
|
|
|
MS_LOG(ERROR) << "the 'stage_id': " << stage_id
|
|
|
|
|