|
|
|
@ -375,6 +375,12 @@ struct NCCLAllReduceOpHandle : public OpHandle {
|
|
|
|
|
if (this->inputs_.size() == 1) {
|
|
|
|
|
return; // No need to all reduce when GPU count = 1;
|
|
|
|
|
} else {
|
|
|
|
|
// Wait input done
|
|
|
|
|
for (auto *in : inputs_) {
|
|
|
|
|
auto &p = static_cast<VarHandle *>(in)->place_;
|
|
|
|
|
in->generated_op_->Wait(dev_ctx_[p]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto &var_name = static_cast<VarHandle *>(this->inputs_[0])->name_;
|
|
|
|
|
VLOG(3) << "Invoke NCCL AllReduce";
|
|
|
|
|
int dtype = -1;
|
|
|
|
|