@@ -209,30 +209,9 @@ ParallelExecutor::ParallelExecutor(
 
 void ParallelExecutor::BCastParamsToDevices(
     const std::unordered_set<std::string> &vars) const {
-  // the initializing bcast, all vars would be bcast from device(0),
-  // otherwise
-  // bcast from the specified device.
-  bool initializing = member_->executor_ ? false : true;
+  // the initializing bcast, all vars would be bcast from device(0).
   for (auto &var : vars) {
-    int var_dev_id = -1;
-    if (member_->executor_) {
-      auto &sharded_var_device =
-          member_->executor_->Graph().Get<details::ShardedVarDevice>(
-              details::kShardedVarDevice);
-      if (sharded_var_device.find(var) != sharded_var_device.end()) {
-        var_dev_id = sharded_var_device.at(var);
-      }
-    }
-
-    if (!initializing && var_dev_id == -1) continue;
-
-    framework::Variable *main_var = nullptr;
-    if (initializing) {
-      main_var = member_->local_scopes_[0]->FindVar(var);
-    } else {
-      main_var = member_->local_scopes_[var_dev_id]->FindVar(var);
-    }
-
+    framework::Variable *main_var = member_->local_scopes_[0]->FindVar(var);
     if (main_var == nullptr || !main_var->IsType<LoDTensor>()) {
       continue;
     }
@@ -248,8 +227,7 @@ void ParallelExecutor::BCastParamsToDevices(
       auto place = member_->places_[i];
       void *buffer;
 
-      if ((initializing && i == 0) ||
-          (!initializing && static_cast<int>(i) == var_dev_id)) {
+      if (i == 0) {
         buffer = const_cast<void *>(main_tensor.data<void>());
       } else {
         auto local_scope = member_->local_scopes_[i];
@@ -266,29 +244,18 @@ void ParallelExecutor::BCastParamsToDevices(
       platform::NCCLGroupGuard guard;
       for (size_t i = 0; i < member_->places_.size(); ++i) {
         auto &nccl_ctx = member_->nccl_ctxs_->at(member_->places_[i]);
-        if (initializing) {
-          platform::dynload::ncclBcast(buffers[i], numel, data_type, 0,
-                                       nccl_ctx.comm_, nccl_ctx.stream());
-        } else {
-          if (var_dev_id >= 0) {
-            platform::dynload::ncclBcast(buffers[i], numel, data_type,
-                                         var_dev_id, nccl_ctx.comm_,
-                                         nccl_ctx.stream());
-          }
-        }
+        platform::dynload::ncclBcast(buffers[i], numel, data_type, 0,
+                                     nccl_ctx.comm_, nccl_ctx.stream());
       }
       member_->nccl_ctxs_->WaitAll();
     }
-
 #else
     PADDLE_THROW("Not compiled with CUDA");
 #endif
   } else {
     platform::CPUPlace cpu;
     for (size_t i = 0; i < member_->places_.size(); ++i) {
-      if ((initializing && i == 0) ||
-          (!initializing && static_cast<int>(i) == var_dev_id))
-        continue;
+      if (i == 0) continue;
 
       auto local_scope = member_->local_scopes_[i];
       auto *t = local_scope->Var(var)->GetMutable<LoDTensor>();
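For reference, a minimal standalone sketch of the broadcast pattern the simplified branch performs: every per-device buffer is filled from device 0 via ncclBcast with root 0, grouped so NCCL treats the per-device calls as one collective. This uses the raw NCCL/CUDA runtime APIs rather than Paddle's platform::NCCLGroupGuard / NCCLContextMap wrappers, and the function name BcastFromDevice0 is illustrative only.

// Standalone sketch (not Paddle code): broadcast each GPU buffer from device 0.
#include <cuda_runtime.h>
#include <nccl.h>
#include <vector>

void BcastFromDevice0(const std::vector<float *> &buffers,  // one buffer per GPU
                      size_t numel, int num_gpus) {
  std::vector<ncclComm_t> comms(num_gpus);
  std::vector<cudaStream_t> streams(num_gpus);
  ncclCommInitAll(comms.data(), num_gpus, nullptr);  // one communicator per visible GPU
  for (int i = 0; i < num_gpus; ++i) {
    cudaSetDevice(i);
    cudaStreamCreate(&streams[i]);
  }

  // Group the per-device calls, mirroring platform::NCCLGroupGuard in the patch.
  ncclGroupStart();
  for (int i = 0; i < num_gpus; ++i) {
    cudaSetDevice(i);
    // root = 0: device 0 sends its copy, every other device receives into buffers[i].
    ncclBcast(buffers[i], numel, ncclFloat, /*root=*/0, comms[i], streams[i]);
  }
  ncclGroupEnd();

  // Block until the broadcast completes, analogous to nccl_ctxs_->WaitAll().
  for (int i = 0; i < num_gpus; ++i) {
    cudaSetDevice(i);
    cudaStreamSynchronize(streams[i]);
  }
}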