code refine

wangkuiyi-patch-1
chengduoZH 7 years ago
parent 5a3c8bf813
commit 2d94697a82

@ -297,21 +297,26 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result,
op_handle->AddInput(in);
for (size_t i = 0; i < places_.size(); ++i) {
auto &vars = result->vars_.at(i).at(p_name);
auto &p = places_[i];
SetCommunicationContext(op_handle, p);
auto &vars = result->vars_.at(i).at(p_name);
auto *out_var = new VarHandle(vars.size(), i, p_name, p);
vars.emplace_back(out_var);
op_handle->AddOutput(out_var);
}
}
void MultiDevSSAGraphBuilder::SetCommunicationContext(
OpHandleBase *op_handle, const platform::Place &p) const {
#ifdef PADDLE_WITH_CUDA
if (nccl_ctxs_ == nullptr) {
op_handle->SetDeviceContext(
p, platform::DeviceContextPool::Instance().Get(p));
}
#else
if (nccl_ctxs_ == nullptr) {
op_handle->SetDeviceContext(p,
platform::DeviceContextPool::Instance().Get(p));
#endif
}
#else
op_handle->SetDeviceContext(p,
platform::DeviceContextPool::Instance().Get(p));
#endif
}
void MultiDevSSAGraphBuilder::CreateComputationalOp(SSAGraph *result,
@ -334,24 +339,12 @@ void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp(
for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i];
SetCommunicationContext(op_handle, p);
auto &vars = result->vars_[i][og];
PADDLE_ENFORCE(!vars.empty());
auto &prev_grad = vars.back();
op_handle->AddInput(prev_grad.get());
#ifdef PADDLE_WITH_CUDA
if (nccl_ctxs_ == nullptr) {
op_handle->SetDeviceContext(
p, platform::DeviceContextPool::Instance().Get(p));
}
#else
op_handle->SetDeviceContext(p,
platform::DeviceContextPool::Instance().Get(p));
#endif
VLOG(4) << "NCCL - - - " << p;
op_handle->DeviceContext(p)->Wait();
VLOG(4) << "NCCL - - - " << p << " " << op_handle->DeviceContext(p);
auto var = new VarHandle(vars.size() - 1, i, og, p);
vars.emplace_back(var);
op_handle->AddOutput(var);
@ -441,17 +434,9 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
auto *op_handle = result->ops_.back().get();
for (size_t i = 0; i < places_.size(); ++i) {
auto &vars = result->vars_[i][og];
auto &p = places_[i];
#ifdef PADDLE_WITH_CUDA
if (nccl_ctxs_ == nullptr) {
op_handle->SetDeviceContext(
p, platform::DeviceContextPool::Instance().Get(p));
}
#else
op_handle->SetDeviceContext(p,
platform::DeviceContextPool::Instance().Get(p));
#endif
SetCommunicationContext(op_handle, p);
auto &vars = result->vars_[i][og];
PADDLE_ENFORCE(!vars.empty());
auto &prev_grad = vars.back();
op_handle->AddInput(prev_grad.get());

@ -111,6 +111,9 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
private:
BuildStrategy strategy_;
void SetCommunicationContext(OpHandleBase *op_handle,
const platform::Place &p) const;
};
} // namespace details
} // namespace framework

Loading…
Cancel
Save