|
|
|
@ -78,7 +78,7 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
|
|
|
|
|
PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
|
|
|
|
|
ins[i]->data<T>(), outs[i]->mutable_data<T>(ctx.GetPlace()),
|
|
|
|
|
outs[i]->numel(), NCCLTypeWrapper<T>::type, reduction_op_,
|
|
|
|
|
comm->comms_[idx], stream));
|
|
|
|
|
comm->comms().at(idx), stream));
|
|
|
|
|
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
|
|
|
|
|
|
|
|
|
|
VLOG(1) << "gpu : "
|
|
|
|
@ -127,7 +127,7 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
|
|
|
|
|
std::hash<std::string> hasher;
|
|
|
|
|
for (size_t i = 0; i < ins.size(); ++i) {
|
|
|
|
|
if (root == platform::kInvalidGPUId) {
|
|
|
|
|
root = hasher(ins_names[i]) % comm->comms_.size();
|
|
|
|
|
root = hasher(ins_names[i]) % comm->comms().size();
|
|
|
|
|
}
|
|
|
|
|
T* recvbuffer = nullptr;
|
|
|
|
|
if (root == gpu_id) {
|
|
|
|
@ -139,7 +139,7 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(platform::dynload::ncclReduce(
|
|
|
|
|
ins[i]->data<T>(), recvbuffer, ins[i]->numel(),
|
|
|
|
|
NCCLTypeWrapper<T>::type, reduction_op_, root, comm->comms_[idx],
|
|
|
|
|
NCCLTypeWrapper<T>::type, reduction_op_, root, comm->comms().at(idx),
|
|
|
|
|
stream));
|
|
|
|
|
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
|
|
|
|
|
|
|
|
|
@ -176,7 +176,7 @@ class NCCLBcastKernel : public framework::OpKernel<T> {
|
|
|
|
|
VLOG(1) << " before ncclBcast";
|
|
|
|
|
PADDLE_ENFORCE(platform::dynload::ncclBcast(
|
|
|
|
|
(void*)ins[i]->data<T>(), ins[i]->numel(), NCCLTypeWrapper<T>::type,
|
|
|
|
|
root, comm->comms_[idx], stream));
|
|
|
|
|
root, comm->comms().at(idx), stream));
|
|
|
|
|
VLOG(1) << " after ncclBcast";
|
|
|
|
|
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
|
|
|
|
|
|
|
|
|
@ -190,7 +190,7 @@ class NCCLBcastKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(platform::dynload::ncclBcast(
|
|
|
|
|
outs[i]->mutable_data<T>(ctx.GetPlace()), outs[i]->numel(),
|
|
|
|
|
NCCLTypeWrapper<T>::type, root, comm->comms_[idx], stream));
|
|
|
|
|
NCCLTypeWrapper<T>::type, root, comm->comms().at(idx), stream));
|
|
|
|
|
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
|
|
|
|
|
|
|
|
|
|
VLOG(1) << "gpu : " << gpu_id << " finished Bcast. recv "
|
|
|
|
|