"add multioperator testcase"

fix-typo
Dong Zhihong 8 years ago
parent 94992a990b
commit 38d3adfeb6

@@ -100,8 +100,8 @@ class NCCLReduceOp : public framework::OperatorWithKernel {
   }
 };
-// BcastSendOp
-class NCCLBcastSendOp : public framework::OperatorWithKernel {
+// BcastOp
+class NCCLBcastOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
@@ -111,20 +111,12 @@ class NCCLBcastSendOp : public framework::OperatorWithKernel {
                    " Input(X) of Bcast op input should not be NULL");
     PADDLE_ENFORCE(ctx->HasInput("Communicator"),
                    " Input(Communicator) of Bcast op input should not be NULL");
-  }
-};
-// BcastRecvOp
-class NCCLBcastRecvOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
- protected:
-  void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("Communicator"),
-                   " Input(Communicator) of Bcast op input should not be NULL");
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
                    " Output(Out) of Bcast op output should not be NULL");
+    auto x_dims = ctx->GetInputsDim("X");
+    ctx->SetOutputsDim("Out", x_dims);
+    ctx->ShareLoD("X", /*->*/ "Out");
   }
 };
@@ -146,52 +138,41 @@ class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
   }
 };
-// BcastSend should be in the root
-// BcastSendOp
-class NCCLBcastSendOpMaker : public framework::OpProtoAndCheckerMaker {
+// ReduceOp
+class NCCLReduceOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  NCCLBcastSendOpMaker(framework::OpProto *proto,
-                       framework::OpAttrChecker *op_checker)
+  NCCLReduceOpMaker(framework::OpProto *proto,
+                    framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "The input of BcastSend op");
+    AddInput("X", "The input of Reduce op");
     AddInput("Communicator", "Communicator for communicating between gpus");
-    AddAttr<int>("root", "root gpu of Bcast");
+    AddOutput("Out", "The output of Reduce op");
+    AddAttr<int>("root",
+                 "root gpu of the parameter. if not set(-1). hashed by name.")
+        .SetDefault(-1);
     AddComment(R"DOC(
-            Bcast the tensors.
-        )DOC");
+            Reduce the tensors)DOC");
   }
 };
 // BcastOp
-class NCCLBcastRecvOpMaker : public framework::OpProtoAndCheckerMaker {
+class NCCLBcastOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  NCCLBcastRecvOpMaker(framework::OpProto *proto,
-                       framework::OpAttrChecker *op_checker)
+  NCCLBcastOpMaker(framework::OpProto *proto,
+                   framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "The input of BcastSend op");
     AddInput("Communicator", "Communicator for communicating between gpus");
-    AddAttr<int>("root", "root gpu of BcastRecv");
     AddOutput("Out", "The output of Bcast");
+    AddAttr<int>("root",
+                 "root gpu of the parameter. if not set(-1). hashed by name.")
+        .SetDefault(-1);
     AddComment(R"DOC(
             Bcast the tensors.
         )DOC");
   }
 };
-// BcastRecvOp
-class NCCLReduceOpMaker : public framework::OpProtoAndCheckerMaker {
- public:
-  NCCLReduceOpMaker(framework::OpProto *proto,
-                    framework::OpAttrChecker *op_checker)
-      : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "The input of Reduce op");
-    AddInput("Communicator", "Communicator for communicating between gpus");
-    AddOutput("Out", "The output of Reduce op");
-    AddComment(R"DOC(
-            Reduce the tensors.
-        )DOC");
-  }
-};
 }  // namespace operators
 }  // namespace paddle
@@ -201,9 +182,7 @@ REGISTER_OPERATOR(ncclInit, ops::NCCLInitOp,
 REGISTER_OP_WITHOUT_GRADIENT(ncclAllReduce, ops::NCCLAllReduceOp,
                              ops::NCCLAllReduceOpMaker);
-REGISTER_OP_WITHOUT_GRADIENT(ncclBcastSend, ops::NCCLBcastSendOp,
-                             ops::NCCLBcastSendOpMaker);
-REGISTER_OP_WITHOUT_GRADIENT(ncclBcastRecv, ops::NCCLBcastRecvOp,
-                             ops::NCCLBcastRecvOpMaker);
+REGISTER_OP_WITHOUT_GRADIENT(ncclBcast, ops::NCCLBcastOp,
+                             ops::NCCLBcastOpMaker);
 REGISTER_OP_WITHOUT_GRADIENT(ncclReduce, ops::NCCLReduceOp,
                              ops::NCCLReduceOpMaker);
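Note: after this hunk the former ncclBcastSend / ncclBcastRecv operators are registered as a single ncclBcast operator, and the root attribute decides which rank acts as the sender. A minimal sketch of the underlying call pattern, using plain NCCL rather than the Paddle wrappers (BroadcastParam and its arguments are hypothetical, not part of this diff):

    #include <cuda_runtime.h>
    #include <nccl.h>

    // Every rank issues the identical collective call. The rank whose id
    // equals `root` supplies the data; all other ranks receive into the
    // same device buffer.
    void BroadcastParam(float* dev_buf, size_t count, int root,
                        ncclComm_t comm, cudaStream_t stream) {
      ncclBcast(dev_buf, count, ncclFloat, root, comm, stream);
      cudaStreamSynchronize(stream);
    }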

@@ -83,6 +83,7 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
     auto ins = ctx.MultiInput<LoDTensor>("X");  // x0, x1, x2
     auto outs = ctx.MultiOutput<LoDTensor>("Out");
+    int root = ctx.Attr<int>("root");
     auto* comm = ctx.Input<Communicator>("Communicator");
@@ -97,7 +98,9 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
     auto ins_names = ctx.Inputs("X");
     std::hash<std::string> hasher;
     for (size_t i = 0; i < ins.size(); ++i) {
-      int root = hasher(ins_names[i]) % comm->comms_.size();
+      if (root == -1) {
+        root = hasher(ins_names[i]) % comm->comms_.size();
+      }
       T* recvbuffer = nullptr;
       if (root == device_id) {
         recvbuffer = outs[i]->mutable_data<T>(ctx.GetPlace());
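Note: the hunk above changes the root from being recomputed unconditionally to a defaulted attribute. When root is left at -1, every device hashes the parameter's name and takes it modulo the number of communicators, so all ranks independently agree on which GPU holds the reduced result. A standalone sketch of that selection rule (illustrative only; PickRoot and its parameters are not part of the diff):

    #include <functional>
    #include <string>

    // Deterministic root selection: an explicit root wins, otherwise the
    // parameter name is hashed so every rank picks the same device.
    int PickRoot(int root_attr, const std::string& param_name,
                 size_t num_devices) {
      if (root_attr != -1) {
        return root_attr;
      }
      std::hash<std::string> hasher;
      return static_cast<int>(hasher(param_name) % num_devices);
    }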
@@ -135,8 +138,9 @@ class NCCLBcastKernel : public framework::OpKernel<T> {
     int device_id =
         boost::get<platform::GPUPlace>(ctx.GetPlace()).GetDeviceId();
     int idx = comm->GetCommId(device_id);
     if (idx == root) {
-      auto ins = ctx.MultiInput<Tensor>("X");
+      auto ins = ctx.MultiInput<LoDTensor>("X");
       for (size_t i = 0; i < ins.size(); ++i) {
         PADDLE_ENFORCE(platform::dynload::ncclBcast(
             (void*)ins[i]->data<T>(), ins[i]->numel(), NCCLTypeWrapper<T>::type,
@@ -144,7 +148,7 @@ class NCCLBcastKernel : public framework::OpKernel<T> {
         PADDLE_ENFORCE(cudaStreamSynchronize(stream));
       }
     } else {
-      auto outs = ctx.MultiOutput<Tensor>("Out");
+      auto outs = ctx.MultiOutput<LoDTensor>("Out");
       for (size_t i = 0; i < outs.size(); ++i) {
         PADDLE_ENFORCE(platform::dynload::ncclBcast(
             outs[i]->mutable_data<T>(ctx.GetPlace()), outs[i]->numel(),
@@ -160,6 +164,5 @@ class NCCLBcastKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 REGISTER_OP_GPU_KERNEL(ncclAllReduce, ops::NCCLAllReduceKernel<float>);
-REGISTER_OP_GPU_KERNEL(ncclBcastSend, ops::NCCLBcastKernel<float>);
+REGISTER_OP_GPU_KERNEL(ncclBcast, ops::NCCLBcastKernel<float>);
 REGISTER_OP_GPU_KERNEL(ncclReduce, ops::NCCLReduceKernel<float>);
-REGISTER_OP_GPU_KERNEL(ncclBcastRecv, ops::NCCLBcastKernel<float>);

File diff suppressed because it is too large