helinwang-patch-1
Yu Yang 8 years ago
parent 9fc0b596a9
commit d470763f6c

@@ -154,6 +154,8 @@ class ParallelExecutorPrivate {
   std::unordered_map<platform::Place, Scope *, platform::PlaceHash>
       local_scopes_;
+  std::vector<platform::Place> places_;
+
 #ifdef PADDLE_WITH_CUDA
   struct NCCLContext {
     std::unique_ptr<platform::CUDADeviceContext> ctx_;
@@ -246,6 +248,8 @@ ParallelExecutor::ParallelExecutor(
     const ProgramDesc &startup_program, const ProgramDesc &main_program,
     const std::string &loss_var_name, Scope *scope)
     : member_(new ParallelExecutorPrivate()) {
+  member_->places_ = places;
+
   // Step 1. RunStartupProgram and Bcast the params to devs.
   Executor exe(places[0]);
   exe.Run(startup_program, scope, 0);
@@ -489,14 +493,14 @@ void ParallelExecutor::BCastParamsToGPUs(
       platform::dynload::ncclGroupStart();
-      for (auto &pair : member_->local_scopes_) {
-        auto local_scope = pair.second;
+      for (auto &place : member_->places_) {
+        auto local_scope = member_->local_scopes_[place];
         auto *t = local_scope->Var(var_desc->Name())->GetMutable<LoDTensor>();
         t->Resize(dims);
-        auto &nccl_ctx = member_->GetNCCLCtx(pair.first);
-        platform::dynload::ncclBcast(
-            t->mutable_data(pair.first, main_tensor.type()), numel, data_type,
-            0, nccl_ctx.comm, nccl_ctx.stream());
+        auto &nccl_ctx = member_->GetNCCLCtx(place);
+        platform::dynload::ncclBcast(t->mutable_data(place, main_tensor.type()),
+                                     numel, data_type, 0, nccl_ctx.comm,
+                                     nccl_ctx.stream());
       }
       platform::dynload::ncclGroupEnd();
     }
@@ -506,7 +510,7 @@ void ParallelExecutor::BCastParamsToGPUs(
   for (auto &pair : member_->local_scopes_) {
     member_->GetNCCLCtx(pair.first).ctx_->Wait();
-    auto &b = pair.second->FindVar("fc_1.b_0")->Get<framework::LoDTensor>();
+    auto &b = pair.second->FindVar("fc_0.b_0")->Get<framework::LoDTensor>();
     framework::LoDTensor cpu;
     framework::TensorCopy(b, platform::CPUPlace(), &cpu);
     platform::DeviceContextPool::Instance().Get(b.place())->Wait();
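
Note on the change above (a sketch, not part of the commit): the broadcast loop now walks member_->places_ and looks up each device's scope and NCCL context by place, so every ncclBcast issued inside the ncclGroupStart()/ncclGroupEnd() pair uses the communicator and stream that belong to that device. The standalone snippet below illustrates the same grouped-broadcast pattern with the raw NCCL API, outside of Paddle; the device count, buffer size, and variable names are hypothetical and assume NCCL 2.x with at least two visible GPUs.

// Minimal sketch of a grouped parameter broadcast: one ncclBcast per
// device, issued between ncclGroupStart() and ncclGroupEnd(), with the
// copy on device 0 acting as the root (the "main tensor" in the diff).
#include <cuda_runtime.h>
#include <nccl.h>
#include <vector>

int main() {
  const int ndev = 2;         // hypothetical device count
  const size_t numel = 1024;  // hypothetical parameter size

  // One communicator per GPU; a NULL device list means devices 0..ndev-1.
  std::vector<ncclComm_t> comms(ndev);
  ncclCommInitAll(comms.data(), ndev, nullptr);

  // Per-device buffer and stream, analogous to the per-place scope and
  // NCCLContext looked up in the loop above.
  std::vector<float *> bufs(ndev);
  std::vector<cudaStream_t> streams(ndev);
  for (int i = 0; i < ndev; ++i) {
    cudaSetDevice(i);
    cudaMalloc(&bufs[i], numel * sizeof(float));
    cudaStreamCreate(&streams[i]);
  }

  // Broadcast the buffer on device 0 to every device in one group.
  ncclGroupStart();
  for (int i = 0; i < ndev; ++i) {
    ncclBcast(bufs[i], numel, ncclFloat, /*root=*/0, comms[i], streams[i]);
  }
  ncclGroupEnd();

  // Wait for the broadcasts, then release resources.
  for (int i = 0; i < ndev; ++i) {
    cudaSetDevice(i);
    cudaStreamSynchronize(streams[i]);
    cudaStreamDestroy(streams[i]);
    cudaFree(bufs[i]);
    ncclCommDestroy(comms[i]);
  }
  return 0;
}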
