|
|
|
@ -297,21 +297,26 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result,
|
|
|
|
|
op_handle->AddInput(in);
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < places_.size(); ++i) {
|
|
|
|
|
auto &vars = result->vars_.at(i).at(p_name);
|
|
|
|
|
auto &p = places_[i];
|
|
|
|
|
SetCommunicationContext(op_handle, p);
|
|
|
|
|
auto &vars = result->vars_.at(i).at(p_name);
|
|
|
|
|
auto *out_var = new VarHandle(vars.size(), i, p_name, p);
|
|
|
|
|
vars.emplace_back(out_var);
|
|
|
|
|
op_handle->AddOutput(out_var);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Attaches a device context for place `p` to `op_handle`.
///
/// The context is fetched from the process-wide
/// `platform::DeviceContextPool` and registered on the op handle via
/// `SetDeviceContext(p, ctx)`.
///
/// @param op_handle  Op handle that receives the device context.
/// @param p          Place whose device context is looked up in the pool.
void MultiDevSSAGraphBuilder::SetCommunicationContext(
    OpHandleBase *op_handle, const platform::Place &p) const {
#ifdef PADDLE_WITH_CUDA
  // In CUDA builds the pool context is installed only when no NCCL context
  // map was supplied.
  // NOTE(review): presumably the op obtains its context from nccl_ctxs_
  // otherwise -- confirm against the op handle implementation.
  if (nccl_ctxs_ == nullptr) {
    op_handle->SetDeviceContext(p,
                                platform::DeviceContextPool::Instance().Get(p));
  }
#else
  // Without CUDA, always use the context from the pool.
  op_handle->SetDeviceContext(p,
                              platform::DeviceContextPool::Instance().Get(p));
#endif
}
|
|
|
|
|
|
|
|
|
|
void MultiDevSSAGraphBuilder::CreateComputationalOp(SSAGraph *result,
|
|
|
|
@ -334,24 +339,12 @@ void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp(
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < places_.size(); ++i) {
|
|
|
|
|
auto &p = places_[i];
|
|
|
|
|
SetCommunicationContext(op_handle, p);
|
|
|
|
|
auto &vars = result->vars_[i][og];
|
|
|
|
|
PADDLE_ENFORCE(!vars.empty());
|
|
|
|
|
auto &prev_grad = vars.back();
|
|
|
|
|
op_handle->AddInput(prev_grad.get());
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
if (nccl_ctxs_ == nullptr) {
|
|
|
|
|
op_handle->SetDeviceContext(
|
|
|
|
|
p, platform::DeviceContextPool::Instance().Get(p));
|
|
|
|
|
}
|
|
|
|
|
#else
|
|
|
|
|
op_handle->SetDeviceContext(p,
|
|
|
|
|
platform::DeviceContextPool::Instance().Get(p));
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
VLOG(4) << "NCCL - - - " << p;
|
|
|
|
|
op_handle->DeviceContext(p)->Wait();
|
|
|
|
|
VLOG(4) << "NCCL - - - " << p << " " << op_handle->DeviceContext(p);
|
|
|
|
|
auto var = new VarHandle(vars.size() - 1, i, og, p);
|
|
|
|
|
vars.emplace_back(var);
|
|
|
|
|
op_handle->AddOutput(var);
|
|
|
|
@ -441,17 +434,9 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
|
|
|
|
|
auto *op_handle = result->ops_.back().get();
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < places_.size(); ++i) {
|
|
|
|
|
auto &vars = result->vars_[i][og];
|
|
|
|
|
auto &p = places_[i];
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
if (nccl_ctxs_ == nullptr) {
|
|
|
|
|
op_handle->SetDeviceContext(
|
|
|
|
|
p, platform::DeviceContextPool::Instance().Get(p));
|
|
|
|
|
}
|
|
|
|
|
#else
|
|
|
|
|
op_handle->SetDeviceContext(p,
|
|
|
|
|
platform::DeviceContextPool::Instance().Get(p));
|
|
|
|
|
#endif
|
|
|
|
|
SetCommunicationContext(op_handle, p);
|
|
|
|
|
auto &vars = result->vars_[i][og];
|
|
|
|
|
PADDLE_ENFORCE(!vars.empty());
|
|
|
|
|
auto &prev_grad = vars.back();
|
|
|
|
|
op_handle->AddInput(prev_grad.get());
|
|
|
|
|