@@ -49,7 +49,6 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
     auto outs = ctx.MultiOutput<LoDTensor>("Out");
     std::string reduction = ctx.Attr<std::string>("reduction");
     ncclRedOp_t reduction_op_ = ncclSum;
     if (reduction == "ncclMin") {
@@ -101,8 +100,23 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
     auto ins = ctx.MultiInput<LoDTensor>("X");  // x0, x1, x2
     auto outs = ctx.MultiOutput<LoDTensor>("Out");
-    int root = ctx.Attr<int>("root");
+    std::string reduction = ctx.Attr<std::string>("reduction");
+    ncclRedOp_t reduction_op_ = ncclSum;
+    if (reduction == "ncclMin") {
+      reduction_op_ = ncclMin;
+    } else if (reduction == "ncclMax") {
+      reduction_op_ = ncclMax;
+    } else if (reduction == "ncclSum") {
+      reduction_op_ = ncclSum;
+    } else if (reduction == "ncclProd") {
+      reduction_op_ = ncclProd;
+    } else {
+      PADDLE_ENFORCE(false, "Invalid reduction. default ncclSum.");
+    }
+    int root = ctx.Attr<int>("root");
     auto* comm = ctx.Input<Communicator>("Communicator");
     auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
@@ -128,7 +142,8 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
       PADDLE_ENFORCE(platform::dynload::ncclReduce(
           ins[i]->data<T>(), recvbuffer, ins[i]->numel(),
-          NCCLTypeWrapper<T>::type, ncclSum, root, comm->comms_[idx], stream));
+          NCCLTypeWrapper<T>::type, reduction_op_, root, comm->comms_[idx],
+          stream));
       PADDLE_ENFORCE(cudaStreamSynchronize(stream));
       VLOG(1) << "gpu : " << gpu_id << " finished reduce. send "
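
Note on the change above: the same "reduction"-string to ncclRedOp_t dispatch now appears in both NCCLAllReduceKernel and NCCLReduceKernel. As a purely illustrative sketch (the ToNcclRedOp name and the std::invalid_argument error handling are assumptions, not part of this patch), the dispatch could live in one helper:

    // Hypothetical helper, not part of this diff: centralizes the
    // "reduction" attribute -> ncclRedOp_t mapping used by both kernels.
    #include <nccl.h>

    #include <stdexcept>
    #include <string>

    static ncclRedOp_t ToNcclRedOp(const std::string& reduction) {
      if (reduction == "ncclMin") return ncclMin;
      if (reduction == "ncclMax") return ncclMax;
      if (reduction == "ncclSum") return ncclSum;
      if (reduction == "ncclProd") return ncclProd;
      // The kernels report this case via PADDLE_ENFORCE; a plain exception
      // keeps the sketch self-contained.
      throw std::invalid_argument("Invalid reduction: " + reduction);
    }

Each kernel could then obtain its op with ncclRedOp_t reduction_op_ = ToNcclRedOp(ctx.Attr<std::string>("reduction"));, keeping the PADDLE_ENFORCE-based error reporting if preferred.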