@@ -229,8 +229,15 @@ class ParallelExecutorPrivate {
// TODO(yy): Move this function somewhere
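// Maps the element type of a framework tensor (given as std::type_index) to
// the matching ncclDataType_t so it can be handed to NCCL collectives such as
// the ncclBcast call in BCastParamsToGPUs below.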
ncclDataType_t ToNCCLDataType(std::type_index type) {
  if (type == typeid(float)) {  // NOLINT
    return ncclFloat;
  } else if (type == typeid(double)) {  // NOLINT
    return ncclDouble;
  } else if (type == typeid(int)) {  // NOLINT
    return ncclInt;
  } else {
    PADDLE_THROW("Not supported");
  }
}

ParallelExecutor::ParallelExecutor(
@@ -479,30 +486,32 @@ void ParallelExecutor::BCastParamsToGPUs(
      ncclDataType_t data_type = ToNCCLDataType(main_tensor.type());
      auto &dims = main_tensor.dims();
      size_t numel = main_tensor.numel();
      std::vector<std::pair<void *, ParallelExecutorPrivate::NCCLContext *>>
          mems;
      mems.emplace_back(const_cast<void *>(main_tensor.data<void>()),
                        &member_->GetNCCLCtx(member_->main_place_));
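
      // mems pairs each raw buffer with its NCCL context; mems[0] is the
      // source tensor on the main place (see the TODO about ncclBCast below).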
      for (auto &pair : member_->local_scopes_) {
        if (pair.first == member_->main_place_) {
          continue;
        }
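
        // The per-device ncclBcast calls below are all issued from this one
        // thread, so they are wrapped in ncclGroupStart()/ncclGroupEnd() to be
        // aggregated and launched together rather than blocking one by one.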
        platform::dynload::ncclGroupStart();
        for (auto &pair : member_->local_scopes_) {
          auto local_scope = pair.second;
          auto *t = local_scope->Var(var_desc->Name())->GetMutable<LoDTensor>();
          t->Resize(dims);
          mems.emplace_back(t->mutable_data(pair.first, main_tensor.type()),
                            &member_->GetNCCLCtx(member_->main_place_));
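          // Broadcast from root rank 0 (assumed here to be the communicator
          // rank of the main place) into this place's buffer on its own stream.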
          auto &nccl_ctx = member_->GetNCCLCtx(pair.first);
          platform::dynload::ncclBcast(
              t->mutable_data(pair.first, main_tensor.type()), numel, data_type,
              0, nccl_ctx.comm, nccl_ctx.stream());
        }
        platform::dynload::ncclGroupEnd();
      }
    }

    // TODO(yy): Invoke ncclBCast here. mems, numel, data_type. The mems[0]
    // is the src, rests are dests.
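
    // Debug check: synchronize each place's NCCL device context, then copy the
    // broadcast parameter "fc_1.b_0" back to the CPU and log its first element.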
    for (auto &pair : member_->local_scopes_) {
      member_->GetNCCLCtx(pair.first).ctx_->Wait();

      auto &b = pair.second->FindVar("fc_1.b_0")->Get<framework::LoDTensor>();
      framework::LoDTensor cpu;
      framework::TensorCopy(b, platform::CPUPlace(), &cpu);
      platform::DeviceContextPool::Instance().Get(b.place())->Wait();
      LOG(INFO) << *cpu.data<float>();
    }
  }
#else
PADDLE_THROW("Not compiled with CUDA");
#endif