|
|
|
@ -171,8 +171,9 @@ inline std::string GetHierarchicalExterNCCLVarName(size_t pos) {
|
|
|
|
|
return string::Sprintf("Hierarchical_exter_%s_%d", NCCL_ID_VARNAME,
|
|
|
|
|
static_cast<int>(pos));
|
|
|
|
|
}
|
|
|
|
|
inline std::string GetHierarchicalInterNCCLVarName() {
|
|
|
|
|
return string::Sprintf("Hierarchical_inter_%s", NCCL_ID_VARNAME);
|
|
|
|
|
inline std::string GetHierarchicalInterNCCLVarName(size_t pos) {
|
|
|
|
|
return string::Sprintf("Hierarchical_inter_%s_%d", NCCL_ID_VARNAME,
|
|
|
|
|
static_cast<int>(pos));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
class MultiNCCLContextMap {
|
|
|
|
@ -224,8 +225,8 @@ class MultiNCCLContextMap {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void InitHierarchicalCtxs(const std::vector<platform::Place> &places,
|
|
|
|
|
ncclUniqueId *inter_nccl_id,
|
|
|
|
|
const std::vector<ncclUniqueId *> &exter_nccl_id,
|
|
|
|
|
const std::vector<ncclUniqueId *> &inter_nccl_ids,
|
|
|
|
|
const std::vector<ncclUniqueId *> &exter_nccl_ids,
|
|
|
|
|
size_t trainers_num, size_t trainer_id,
|
|
|
|
|
size_t inter_trainers_num,
|
|
|
|
|
size_t exter_trainers_num) {
|
|
|
|
@ -238,11 +239,14 @@ class MultiNCCLContextMap {
|
|
|
|
|
inter_trainers_num);
|
|
|
|
|
|
|
|
|
|
int inter_trainer_id = trainer_id % inter_trainers_num;
|
|
|
|
|
VLOG(1) << "init inter_trainer_id:" << inter_trainer_id;
|
|
|
|
|
auto local = new NCCLContextMap(places, inter_nccl_id, inter_trainers_num,
|
|
|
|
|
inter_trainer_id);
|
|
|
|
|
for (size_t i = 0; i < inter_nccl_ids.size(); i++) {
|
|
|
|
|
VLOG(1) << "init inter_trainer_id:" << inter_trainer_id
|
|
|
|
|
<< ", comm no:" << i;
|
|
|
|
|
auto local = new NCCLContextMap(places, inter_nccl_ids[i],
|
|
|
|
|
inter_trainers_num, inter_trainer_id);
|
|
|
|
|
|
|
|
|
|
h_inter_ctxs_.emplace_back(local);
|
|
|
|
|
h_inter_ctxs_.emplace_back(local);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int exter_trainer_id = -1;
|
|
|
|
|
if (trainer_id % inter_trainers_num == 0) {
|
|
|
|
@ -250,8 +254,8 @@ class MultiNCCLContextMap {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (exter_trainer_id >= 0) {
|
|
|
|
|
for (size_t i = 0; i < exter_nccl_id.size(); i++) {
|
|
|
|
|
auto ex = new NCCLContextMap(places, exter_nccl_id[i],
|
|
|
|
|
for (size_t i = 0; i < exter_nccl_ids.size(); i++) {
|
|
|
|
|
auto ex = new NCCLContextMap(places, exter_nccl_ids[i],
|
|
|
|
|
exter_trainers_num, exter_trainer_id);
|
|
|
|
|
VLOG(1) << "init exter_trainer_id:" << exter_trainer_id
|
|
|
|
|
<< ", comm no:" << i;
|
|
|
|
|