|
|
|
@ -14,7 +14,33 @@
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
// AllreduceOp
|
|
|
|
|
// NCCLInitOp
|
|
|
|
|
class NCCLInitOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Communicator"),
|
|
|
|
|
" Input(X) of AllReduce op input should not be NULL");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Proto/checker maker for the NCCL init op.
// Registers one attribute ("gpus", a list of ints), one output
// ("Communicator") and the op's descriptive comment.
class NCCLInitOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  NCCLInitOpMaker(framework::OpProto *proto,
                  framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    // Attribute: ids of the GPUs.
    AddAttr<std::vector<int>>("gpus", "gpu id lists");

    // Output: the communicator object.
    AddOutput("Communicator",
              "Create Communicator for communicating between gpus");

    AddComment(R"DOC(
create communicator.
)DOC");
  }
};
|
|
|
|
|
|
|
|
|
|
// AllReduceOp
|
|
|
|
|
class NCCLAllReduceOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
@ -23,6 +49,9 @@ class NCCLAllReduceOp : public framework::OperatorWithKernel {
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"),
|
|
|
|
|
" Input(X) of AllReduce op input should not be NULL");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasInput("Communicator"),
|
|
|
|
|
" Input(Communicator) of AllReduce op input should not be NULL");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
" Input(X) of AllReduce op input should not be NULL");
|
|
|
|
|
|
|
|
|
@ -45,6 +74,7 @@ class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
framework::OpAttrChecker *op_checker)
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
|
AddInput("X", "The input of AllReduce op");
|
|
|
|
|
AddInput("Communicator", "Communicator for communicating between gpus");
|
|
|
|
|
AddOutput("Out", "The output of AllReduce op");
|
|
|
|
|
AddAttr<std::string>("reduction",
|
|
|
|
|
"{'ncclmin', 'ncclmax', 'ncclprod', 'ncclsum'}.");
|
|
|
|
@ -55,31 +85,31 @@ class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// BcastSendOp
|
|
|
|
|
// Proto/checker maker for the BcastSend op.
// Registers the single input "X" and the op's descriptive comment.
class NCCLBcastSendOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  // The constructor carries this class's own name; a declaration named
  // after another class (NCCLAllReduceOpMaker) is not a constructor and
  // does not compile.
  NCCLBcastSendOpMaker(framework::OpProto *proto,
                       framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "The input of BcastSend op");
    AddComment(R"DOC(
BcastSend the tensors.
)DOC");
  }
};
|
|
|
|
|
// // BcastSendOp
|
|
|
|
|
// class NCCLBcastSendOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
// public:
|
|
|
|
|
// NCCLAllReduceOpMaker(framework::OpProto *proto,
|
|
|
|
|
// framework::OpAttrChecker *op_checker)
|
|
|
|
|
// : OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
|
// AddInput("X", "The input of BcastSend op");
|
|
|
|
|
// AddComment(R"DOC(
|
|
|
|
|
// BcastSend the tensors.
|
|
|
|
|
// )DOC");
|
|
|
|
|
// }
|
|
|
|
|
// };
|
|
|
|
|
|
|
|
|
|
// BcastRecvOp
|
|
|
|
|
// Proto/checker maker for the BcastRecv op.
// Registers the single output "Out" and the op's descriptive comment.
class NCCLBcastRecvOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  // The constructor carries this class's own name; a declaration named
  // after another class (NCCLAllReduceOpMaker) is not a constructor and
  // does not compile.
  NCCLBcastRecvOpMaker(framework::OpProto *proto,
                       framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddOutput("Out", "The output of BcastRecv op");
    AddComment(R"DOC(
BcastRecv the tensors.
)DOC");
  }
};
|
|
|
|
|
// // BcastRecvOp
|
|
|
|
|
// class NCCLBcastRecvOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
// public:
|
|
|
|
|
// NCCLAllReduceOpMaker(framework::OpProto *proto,
|
|
|
|
|
// framework::OpAttrChecker *op_checker)
|
|
|
|
|
// : OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
|
// AddOutput("Out", "The output of BcastRecv op");
|
|
|
|
|
// AddComment(R"DOC(
|
|
|
|
|
// BcastRecv the tensors.
|
|
|
|
|
// )DOC");
|
|
|
|
|
// }
|
|
|
|
|
// };
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|