|
|
|
@ -9,26 +9,30 @@
|
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/operators/nccl_op.h"
|
|
|
|
|
#include "paddle/framework/op_registry.h"
|
|
|
|
|
#include "paddle/operators/nccl/nccl_gpu_common.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
// NCCLInitOp
|
|
|
|
|
class NCCLInitOp : public framework::OperatorWithKernel {
|
|
|
|
|
class NCCLInitOp : public framework::OperatorBase {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Communicator"),
|
|
|
|
|
" Output(Communicator) of ncclInitOp should not be NULL");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
framework::DataType IndicateDataType(
|
|
|
|
|
const framework::ExecutionContext &ctx) const override {
|
|
|
|
|
return static_cast<framework::DataType>(ctx.Attr<int>("data_type"));
|
|
|
|
|
NCCLInitOp(const std::string &type, const framework::VariableNameMap &inputs,
|
|
|
|
|
const framework::VariableNameMap &outputs,
|
|
|
|
|
const framework::AttributeMap &attrs)
|
|
|
|
|
: OperatorBase(type, inputs, outputs, attrs) {}
|
|
|
|
|
|
|
|
|
|
void Run(const framework::Scope &scope,
|
|
|
|
|
const platform::DeviceContext &dev_ctx) const override {
|
|
|
|
|
const auto &name = Output("Communicator");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(name),
|
|
|
|
|
"Can not find variable '%s' in the scope.", name);
|
|
|
|
|
std::vector<int> gpus = Attr<std::vector<int>>("gpus");
|
|
|
|
|
PADDLE_ENFORCE(!gpus.empty(), "Attr(gpus) should not be empty.");
|
|
|
|
|
platform::Communicator *comm =
|
|
|
|
|
scope.FindVar(name)->GetMutable<platform::Communicator>();
|
|
|
|
|
comm->InitAll(gpus);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -188,13 +192,14 @@ class NCCLReduceOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
REGISTER_OPERATOR(ncclInit, ops::NCCLInitOp,
|
|
|
|
|
paddle::framework::EmptyGradOpMaker, ops::NCCLInitOpMaker);
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_WITHOUT_GRADIENT(ncclAllReduce, ops::NCCLAllReduceOp,
|
|
|
|
|
ops::NCCLAllReduceOpMaker);
|
|
|
|
|
REGISTER_OP_WITHOUT_GRADIENT(ncclInit, ops::NCCLInitOp, ops::NCCLInitOpMaker);
|
|
|
|
|
REGISTER_OP_WITHOUT_GRADIENT(ncclBcastSend, ops::NCCLBcastSendOp,
|
|
|
|
|
ops::NCCLBcastSendOpMaker);
|
|
|
|
|
REGISTER_OP_WITHOUT_GRADIENT(ncclBcastRecv, ops::NCCLBcastRecvOp,
|
|
|
|
|
ops::NCCLBcastRecvOpMaker);
|
|
|
|
|
REGISTER_OP_WITHOUT_GRADIENT(ncclReduce, ops::NCCLReduceOp,
|
|
|
|
|
ops::NCCLReduceOpMaker);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(ncclInit, ops::NCCLInitKernel<float>);
|
|
|
|
|