@@ -34,7 +34,7 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
       nccl_ctxs_(ctxs) {
   if (nccl_ctxs_) {
     for (auto &p : places_) {
-      this->dev_ctxes_[p] = nccl_ctxs_->DevCtx(p);
+      this->SetDeviceContext(p, nccl_ctxs_->DevCtx(p));
     }
   }
 }
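
The hunk above stops writing into the protected dev_ctxes_ map directly and routes the registration through SetDeviceContext, whose signature is inferred from the call site. A self-contained sketch of that encapsulation pattern, with plain std types standing in for platform::Place and platform::DeviceContext, and with the class name OpHandleBaseLike and the string keys invented purely for illustration:

    // Sketch only: a minimal analogue of the setter pattern used in the hunk.
    #include <map>
    #include <string>

    struct DeviceContext { int device_id; };

    class OpHandleBaseLike {
     public:
      void SetDeviceContext(const std::string &place, DeviceContext *ctx) {
        dev_ctxes_[place] = ctx;  // the map is touched only inside the class
      }

     protected:
      std::map<std::string, DeviceContext *> dev_ctxes_;
    };

    int main() {
      OpHandleBaseLike handle;
      DeviceContext gpu0{0};
      handle.SetDeviceContext("gpu:0", &gpu0);  // mirrors this->SetDeviceContext(p, ...)
      return 0;
    }
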
@@ -46,7 +46,7 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
 #endif

 void AllReduceOpHandle::RunImpl() {
-  platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);
+  platform::RecordEvent record_event(Name(), dev_ctxes_.cbegin()->second);

   if (NoDummyInputSize() == 1) {
     return;  // No need to all reduce when GPU count = 1;
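
The only change in this hunk is begin() -> cbegin() on dev_ctxes_. A small standalone illustration of the difference, again with a plain std::map instead of Paddle's place-to-context map: cbegin() always yields a const_iterator, which documents that the RecordEvent call only reads the first device context.

    #include <map>
    #include <string>

    int main() {
      std::map<std::string, int> dev_ctxes{{"gpu:0", 1}, {"gpu:1", 2}};

      auto it = dev_ctxes.cbegin();   // const_iterator even on a non-const map
      int first = it->second;         // reading through it is fine
      // it->second = 42;             // would not compile: the element is const

      return first;
    }
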
@@ -127,7 +127,7 @@ void AllReduceOpHandle::RunImpl() {
           *local_scopes_[i]->FindVar(kLocalExecScopeName)->Get<Scope *>();
       auto &p = places_[i];
       auto *var = scope.FindVar(out_var_handles[i]->name_);
-      auto *dev_ctx = dev_ctxes_[p];
+      auto *dev_ctx = dev_ctxes_.at(p);

       RunAndRecordEvent(p, [&trg, var, dev_ctx, p] {
         auto &tensor_gpu = *var->GetMutable<framework::LoDTensor>();
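
The lookup in the last hunk switches from operator[] to at(). The practical difference, shown below with plain std types rather than Paddle's map: operator[] default-constructs and inserts an entry for a missing key (and cannot be called through a const reference), whereas at() throws std::out_of_range and leaves the map untouched, so an unregistered place fails loudly instead of handing back a null device context.

    #include <cassert>
    #include <map>
    #include <stdexcept>
    #include <string>

    int main() {
      std::map<std::string, int *> dev_ctxes;
      int gpu0 = 0;
      dev_ctxes["gpu:0"] = &gpu0;

      assert(dev_ctxes.at("gpu:0") == &gpu0);  // existing key: same result as []

      bool threw = false;
      try {
        dev_ctxes.at("gpu:1");                 // missing key: throws, no insert
      } catch (const std::out_of_range &) {
        threw = true;
      }
      assert(threw && dev_ctxes.size() == 1);

      int *p = dev_ctxes["gpu:1"];             // operator[] inserts a null entry
      assert(p == nullptr && dev_ctxes.size() == 2);
      return 0;
    }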