|
|
|
@ -2,8 +2,8 @@
|
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
you may not use this file except in compliance with the License.
|
|
|
|
|
You may obtain a copy of the License at
|
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed
|
|
|
|
|
to in writing, software
|
|
|
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
@ -27,25 +27,12 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto ins = ctx.MultiInput<Tensor>("X");
|
|
|
|
|
auto outs = ctx.MultiOutput<Tensor>("Out");
|
|
|
|
|
std::string reduction = ctx.Attr<std::string>("reduction");
|
|
|
|
|
ncclRedOp_t op_type;
|
|
|
|
|
if (reduction == "ncclSum") {
|
|
|
|
|
op_type = ncclSum;
|
|
|
|
|
} else if (reduction == "ncclProd") {
|
|
|
|
|
op_type = ncclProd;
|
|
|
|
|
} else if (reduction == "ncclMin") {
|
|
|
|
|
op_type = ncclMin;
|
|
|
|
|
} else if (reduction == "ncclMax") {
|
|
|
|
|
op_type = ncclMax;
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE(false, "reduction error.");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto* comm = ctx.Input<Communicator>("Communicator");
|
|
|
|
|
|
|
|
|
|
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
|
|
|
|
|
ctx.device_context())
|
|
|
|
|
.stream();
|
|
|
|
|
|
|
|
|
|
// device id
|
|
|
|
|
int device_id =
|
|
|
|
|
boost::get<platform::GPUPlace>(ctx.GetPlace()).GetDeviceId();
|
|
|
|
@ -54,7 +41,7 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
|
|
|
|
|
for (size_t i = 0; i < ins.size(); ++i) {
|
|
|
|
|
PADDLE_ENFORCE(ncclAllReduce(
|
|
|
|
|
ins[i]->data<T>(), outs[i]->mutable_data<T>(ctx.GetPlace()),
|
|
|
|
|
outs[i]->numel() * sizeof(T), NCCLTypeWrapper<T>::type, op_type,
|
|
|
|
|
outs[i]->numel() * sizeof(T), NCCLTypeWrapper<T>::type, ncclSum,
|
|
|
|
|
comm->comms_[idx], stream));
|
|
|
|
|
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
|
|
|
|
|
}
|
|
|
|
@ -68,7 +55,7 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
|
|
|
|
|
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
|
|
|
|
|
"This kernel only runs on GPU device.");
|
|
|
|
|
|
|
|
|
|
auto ins = ctx.MultiInput<Tensor>("X");
|
|
|
|
|
auto ins = ctx.MultiInput<Tensor>("X"); // x0, x1, x2
|
|
|
|
|
auto outs = ctx.MultiOutput<Tensor>("Out");
|
|
|
|
|
|
|
|
|
|
auto* comm = ctx.Input<Communicator>("Communicator");
|
|
|
|
@ -81,14 +68,16 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
|
|
|
|
|
boost::get<platform::GPUPlace>(ctx.GetPlace()).GetDeviceId();
|
|
|
|
|
int idx = comm->GetCommId(device_id);
|
|
|
|
|
|
|
|
|
|
auto ins_names = ctx.Inputs("X");
|
|
|
|
|
std::hash<std::string> hasher;
|
|
|
|
|
for (size_t i = 0; i < ins.size(); ++i) {
|
|
|
|
|
int root = std::hash() % comm->comms_.size();
|
|
|
|
|
int root = hasher(ins_names[i]) % comm->comms_.size();
|
|
|
|
|
T* recvbuffer = nullptr;
|
|
|
|
|
if (root == device_id) {
|
|
|
|
|
recvbuffer = outs[i]->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE(ncclReduce(ins[i]->data<T>(), recvbuffer, ins[i]->numel(),
|
|
|
|
|
NCCLTypeWrapper<T>::type, root, ncclSum,
|
|
|
|
|
NCCLTypeWrapper<T>::type, ncclSum, root,
|
|
|
|
|
comm->comms_[idx], stream));
|
|
|
|
|
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
|
|
|
|
|
}
|
|
|
|
@ -124,7 +113,7 @@ class NCCLBcastKernel : public framework::OpKernel<T> {
|
|
|
|
|
} else {
|
|
|
|
|
auto outs = ctx.MultiOutput<Tensor>("Out");
|
|
|
|
|
for (size_t i = 0; i < outs.size(); ++i) {
|
|
|
|
|
PADDLE_ENFORCE(ncclBcast((void*)outs[i]->mutable_data<T>(),
|
|
|
|
|
PADDLE_ENFORCE(ncclBcast(outs[i]->mutable_data<T>(ctx.GetPlace()),
|
|
|
|
|
outs[i]->numel(), NCCLTypeWrapper<T>::type,
|
|
|
|
|
root, comm->comms_[idx], stream));
|
|
|
|
|
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
|
|
|
|
|