|
|
|
@ -365,8 +365,6 @@ struct NCCLAllReduceOpHandle : public OpHandle {
|
|
|
|
|
auto &p = static_cast<VarHandle *>(in)->place_;
|
|
|
|
|
in->generated_op_->Wait(dev_ctx_[p]);
|
|
|
|
|
}
|
|
|
|
|
VLOG(3) << "Before NCCL";
|
|
|
|
|
PADDLE_ENFORCE(cudaDeviceSynchronize());
|
|
|
|
|
|
|
|
|
|
auto &var_name = static_cast<VarHandle *>(this->inputs_[0])->name_;
|
|
|
|
|
int dtype = -1;
|
|
|
|
@ -395,7 +393,6 @@ struct NCCLAllReduceOpHandle : public OpHandle {
|
|
|
|
|
nccl_ctx.comm, nccl_ctx.stream());
|
|
|
|
|
}
|
|
|
|
|
platform::dynload::ncclGroupEnd();
|
|
|
|
|
PADDLE_ENFORCE(cudaDeviceSynchronize());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|