|
|
|
@ -31,8 +31,13 @@ class NCCLInitOp : public framework::OperatorBase {
|
|
|
|
|
const auto &name = Output("Communicator");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(name),
|
|
|
|
|
"Can not find variable '%s' in the scope.", name);
|
|
|
|
|
std::vector<int> gpus = Attr<std::vector<int>>("gpus");
|
|
|
|
|
PADDLE_ENFORCE(!gpus.empty(), "Attr(gpus) should not be empty.");
|
|
|
|
|
|
|
|
|
|
int count = platform::GetCUDADeviceCount();
|
|
|
|
|
std::vector<int> gpus(count);
|
|
|
|
|
for (int i = 0; i < count; ++i) {
|
|
|
|
|
gpus[i] = i;
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE(!gpus.empty(), "NCCL init with 0 gpus.");
|
|
|
|
|
|
|
|
|
|
if (scope.FindVar(name) == nullptr) {
|
|
|
|
|
PADDLE_THROW("Output(Communicator) is needed for ncclInit operator.");
|
|
|
|
@ -50,11 +55,6 @@ class NCCLInitOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
|
AddOutput("Communicator",
|
|
|
|
|
"Create Communicator for communicating between gpus");
|
|
|
|
|
AddAttr<std::vector<int>>("gpus", "(vector<int>) GPU id lists");
|
|
|
|
|
AddAttr<int>("dtype",
|
|
|
|
|
"(int, default 5 (FP32)) "
|
|
|
|
|
"Output data type")
|
|
|
|
|
.SetDefault(framework::proto::DataType::FP32);
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
NCCLInit Operator.
|
|
|
|
|
|
|
|
|
@ -77,7 +77,7 @@ class NCCLAllReduceOp : public framework::OperatorWithKernel {
|
|
|
|
|
ctx->HasInput("Communicator"),
|
|
|
|
|
" Input(Communicator) of AllReduce op input should not be NULL");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
" Input(X) of AllReduce op input should not be NULL");
|
|
|
|
|
" Output(Out) of AllReduce op output should not be NULL");
|
|
|
|
|
|
|
|
|
|
auto x_dims = ctx->GetInputsDim("X");
|
|
|
|
|
|
|
|
|
|