|
|
|
@ -100,8 +100,8 @@ class NCCLReduceOp : public framework::OperatorWithKernel {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// BcastSendOp
|
|
|
|
|
class NCCLBcastSendOp : public framework::OperatorWithKernel {
|
|
|
|
|
// BcastOp
|
|
|
|
|
class NCCLBcastOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
@ -111,20 +111,12 @@ class NCCLBcastSendOp : public framework::OperatorWithKernel {
|
|
|
|
|
" Input(X) of Bcast op input should not be NULL");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Communicator"),
|
|
|
|
|
" Input(Communicator) of Bcast op input should not be NULL");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// BcastRecvOp
|
|
|
|
|
class NCCLBcastRecvOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Communicator"),
|
|
|
|
|
" Input(Communicator) of Bcast op input should not be NULL");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
" Output(Out) of Bcast op output should not be NULL");
|
|
|
|
|
|
|
|
|
|
auto x_dims = ctx->GetInputsDim("X");
|
|
|
|
|
ctx->SetOutputsDim("Out", x_dims);
|
|
|
|
|
ctx->ShareLoD("X", /*->*/ "Out");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -146,52 +138,41 @@ class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// BcastSend should be in the root
|
|
|
|
|
// BcastSendOp
|
|
|
|
|
class NCCLBcastSendOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
// ReduceOp
|
|
|
|
|
class NCCLReduceOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
public:
|
|
|
|
|
NCCLBcastSendOpMaker(framework::OpProto *proto,
|
|
|
|
|
framework::OpAttrChecker *op_checker)
|
|
|
|
|
NCCLReduceOpMaker(framework::OpProto *proto,
|
|
|
|
|
framework::OpAttrChecker *op_checker)
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
|
AddInput("X", "The input of BcastSend op");
|
|
|
|
|
AddInput("X", "The input of Reduce op");
|
|
|
|
|
AddInput("Communicator", "Communicator for communicating between gpus");
|
|
|
|
|
AddAttr<int>("root", "root gpu of Bcast");
|
|
|
|
|
AddOutput("Out", "The output of Reduce op");
|
|
|
|
|
AddAttr<int>("root",
|
|
|
|
|
"root gpu of the parameter. if not set(-1). hashed by name.")
|
|
|
|
|
.SetDefault(-1);
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
Bcast the tensors.
|
|
|
|
|
)DOC");
|
|
|
|
|
Reduce the tensors)DOC");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// BcastOp
|
|
|
|
|
class NCCLBcastRecvOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
class NCCLBcastOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
public:
|
|
|
|
|
NCCLBcastRecvOpMaker(framework::OpProto *proto,
|
|
|
|
|
framework::OpAttrChecker *op_checker)
|
|
|
|
|
NCCLBcastOpMaker(framework::OpProto *proto,
|
|
|
|
|
framework::OpAttrChecker *op_checker)
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
|
AddInput("X", "The input of BcastSend op");
|
|
|
|
|
AddInput("Communicator", "Communicator for communicating between gpus");
|
|
|
|
|
AddAttr<int>("root", "root gpu of BcastRecv");
|
|
|
|
|
AddOutput("Out", "The output of Bcast");
|
|
|
|
|
AddAttr<int>("root",
|
|
|
|
|
"root gpu of the parameter. if not set(-1). hashed by name.")
|
|
|
|
|
.SetDefault(-1);
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
Bcast the tensors.
|
|
|
|
|
)DOC");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// BcastRecvOp
|
|
|
|
|
class NCCLReduceOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
public:
|
|
|
|
|
NCCLReduceOpMaker(framework::OpProto *proto,
|
|
|
|
|
framework::OpAttrChecker *op_checker)
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
|
AddInput("X", "The input of Reduce op");
|
|
|
|
|
AddInput("Communicator", "Communicator for communicating between gpus");
|
|
|
|
|
AddOutput("Out", "The output of Reduce op");
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
Reduce the tensors.
|
|
|
|
|
)DOC");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
@ -201,9 +182,7 @@ REGISTER_OPERATOR(ncclInit, ops::NCCLInitOp,
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_WITHOUT_GRADIENT(ncclAllReduce, ops::NCCLAllReduceOp,
|
|
|
|
|
ops::NCCLAllReduceOpMaker);
|
|
|
|
|
REGISTER_OP_WITHOUT_GRADIENT(ncclBcastSend, ops::NCCLBcastSendOp,
|
|
|
|
|
ops::NCCLBcastSendOpMaker);
|
|
|
|
|
REGISTER_OP_WITHOUT_GRADIENT(ncclBcastRecv, ops::NCCLBcastRecvOp,
|
|
|
|
|
ops::NCCLBcastRecvOpMaker);
|
|
|
|
|
REGISTER_OP_WITHOUT_GRADIENT(ncclBcast, ops::NCCLBcastOp,
|
|
|
|
|
ops::NCCLBcastOpMaker);
|
|
|
|
|
REGISTER_OP_WITHOUT_GRADIENT(ncclReduce, ops::NCCLReduceOp,
|
|
|
|
|
ops::NCCLReduceOpMaker);
|
|
|
|
|