"add reduce hash function"

fix-typo
Dong Zhihong 7 years ago
parent 423d7438a1
commit ec47565c23

@ -289,6 +289,15 @@ class ExecutionContext {
return device_context_;
}
//! Get a input which has multiple variables.
const std::vector<std::string>& Inputs(const std::string& name) const {
return op_.Inputs(name);
}
//! Get an output which has multiple variables.
const std::vector<std::string>& Outputs(const std::string& name) const {
return op_.Outputs(name);
}
#ifdef PADDLE_WITH_CUDA
const platform::CUDADeviceContext& cuda_device_context() const {
PADDLE_ENFORCE(platform::is_gpu_place(device_context_.GetPlace()));

@ -81,9 +81,6 @@ class NCCLReduceOp : public framework::OperatorWithKernel {
" Input(Communicator) of Reduce op input should not be NULL");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
" Input(X) of Reduce op input should not be NULL");
ctx->SetOutputsDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
};
@ -137,8 +134,8 @@ class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
// BcastSendOp
class NCCLBcastSendOpMaker : public framework::OpProtoAndCheckerMaker {
public:
NCCLAllBcastSendOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
NCCLBcastSendOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of BcastSend op");
AddInput("Communicator", "Communicator for communicating between gpus");
@ -152,8 +149,8 @@ class NCCLBcastSendOpMaker : public framework::OpProtoAndCheckerMaker {
// BcastOp
class NCCLBcastRecvOpMaker : public framework::OpProtoAndCheckerMaker {
public:
NCCLAllBcastRecvOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
NCCLBcastRecvOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Communicator", "Communicator for communicating between gpus");
AddAttr<int>("root", "root gpu of BcastRecv");

@ -2,8 +2,8 @@
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
http://www.apache.org/licenseshashernless required by applicable law or agreed
to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
@ -27,25 +27,12 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
auto ins = ctx.MultiInput<Tensor>("X");
auto outs = ctx.MultiOutput<Tensor>("Out");
std::string reduction = ctx.Attr<std::string>("reduction");
ncclRedOp_t op_type;
if (reduction == "ncclSum") {
op_type = ncclSum;
} else if (reduction == "ncclProd") {
op_type = ncclProd;
} else if (reduction == "ncclMin") {
op_type = ncclMin;
} else if (reduction == "ncclMax") {
op_type = ncclMax;
} else {
PADDLE_ENFORCE(false, "reduction error.");
}
auto* comm = ctx.Input<Communicator>("Communicator");
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream();
// device id
int device_id =
boost::get<platform::GPUPlace>(ctx.GetPlace()).GetDeviceId();
@ -54,7 +41,7 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
for (size_t i = 0; i < ins.size(); ++i) {
PADDLE_ENFORCE(ncclAllReduce(
ins[i]->data<T>(), outs[i]->mutable_data<T>(ctx.GetPlace()),
outs[i]->numel() * sizeof(T), NCCLTypeWrapper<T>::type, op_type,
outs[i]->numel() * sizeof(T), NCCLTypeWrapper<T>::type, ncclSum,
comm->comms_[idx], stream));
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
}
@ -68,7 +55,7 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device.");
auto ins = ctx.MultiInput<Tensor>("X");
auto ins = ctx.MultiInput<Tensor>("X"); // x0, x1, x2
auto outs = ctx.MultiOutput<Tensor>("Out");
auto* comm = ctx.Input<Communicator>("Communicator");
@ -81,14 +68,16 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
boost::get<platform::GPUPlace>(ctx.GetPlace()).GetDeviceId();
int idx = comm->GetCommId(device_id);
auto ins_names = ctx.Inputs("X");
std::hash<std::string> hasher;
for (size_t i = 0; i < ins.size(); ++i) {
int root = std::hash() % comm->comms_.size();
int root = hasher(ins_names[i]) % comm->comms_.size();
T* recvbuffer = nullptr;
if (root == device_id) {
recvbuffer = outs[i]->mutable_data<T>(ctx.GetPlace());
}
PADDLE_ENFORCE(ncclReduce(ins[i]->data<T>(), recvbuffer, ins[i]->numel(),
NCCLTypeWrapper<T>::type, root, ncclSum,
NCCLTypeWrapper<T>::type, ncclSum, root,
comm->comms_[idx], stream));
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
}
@ -124,7 +113,7 @@ class NCCLBcastKernel : public framework::OpKernel<T> {
} else {
auto outs = ctx.MultiOutput<Tensor>("Out");
for (size_t i = 0; i < outs.size(); ++i) {
PADDLE_ENFORCE(ncclBcast((void*)outs[i]->mutable_data<T>(),
PADDLE_ENFORCE(ncclBcast(outs[i]->mutable_data<T>(ctx.GetPlace()),
outs[i]->numel(), NCCLTypeWrapper<T>::type,
root, comm->comms_[idx], stream));
PADDLE_ENFORCE(cudaStreamSynchronize(stream));

Loading…
Cancel
Save