@@ -1,3 +1,14 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
 #include "paddle/operators/nccl/nccl_ops.h"
 
 namespace paddle {
@@ -9,54 +20,27 @@ class NCCLAllReduceOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  // allreduce do nothing in infershape
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(
-        ctx.InputVar("X"),
-        " Input(X) of AllReduce op input should not be NULL");
-    auto ins = ctx.MultiInput<framework::Tensor>("X");
-    auto outs = ctx.MultiOutput<framework::Tensor>("Out");
-    PADDLE_ENFORCE(ins.size() == outs.size(),
-                   "Input(X) and Output(Out) must have same size");
-    for (size_t i = 0; i < ins.size(); ++i) {
-      outs[i]->Resize(ins[i]->dims());
-    }
-    std::string reduction = ctx.Attr<std::string>("reduction");
-    PADDLE_ENFORCE((reduction == "ncclSum" || reduction == "ncclProd" ||
-                    reduction == "ncclMin" || reduction == "ncclMax"),
-                   "invalid reduction!");
-  }
-};
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of AllReduce op should not be NULL");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of AllReduce op should not be NULL");
 
-// BcastSendOp
-template <typename T>
-class NCCLBcastSendOp final : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
- protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(
-        ctx.InputVar("X"),
-        " Input(X) of BcastSend op input should not be NULL");
-  }
-};
+    auto x_dims = ctx->GetInputsDim("X");
 
-// BcastRecvOp
-template <typename T>
-class NCCLBcastRecvOp final : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
+    std::string reduction = ctx->Attrs().Get<std::string>("reduction");
+    PADDLE_ENFORCE((reduction == "ncclSum" || reduction == "ncclProd" ||
+                    reduction == "ncclMin" || reduction == "ncclMax"),
+                   "invalid reduction.");
 
- protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(
-        ctx.OutputVar("Out"),
-        " Input(X) of BcastRecv op input should not be NULL");
+    ctx->SetOutputsDim("Out", x_dims);
+    ctx->ShareLoD("X", /*->*/ "Out");
   }
 };
+
+// AllreduceOp
 class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   NCCLAllReduceOpMaker(framework::OpProto *proto,
                        framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
@@ -71,7 +55,9 @@ class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
   }
 };
+
+// BcastSendOp
 class NCCLBcastSendOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   NCCLBcastSendOpMaker(framework::OpProto *proto,
                        framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
@@ -82,7 +68,9 @@ class NCCLBcastSendOpMaker : public framework::OpProtoAndCheckerMaker {
   }
 };
+
+// BcastRecvOp
 class NCCLBcastRecvOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   NCCLBcastRecvOpMaker(framework::OpProto *proto,
                        framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
@@ -93,5 +81,9 @@ class NCCLBcastRecvOpMaker : public framework::OpProtoAndCheckerMaker {
   }
 };
 
-} // operators
-} // paddle
+} // namespace operators
+} // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_WITHOUT_GRADIENT(ncclAllReduce, ops::NCCLAllReduceOp,
+                             ops::NCCLAllReduceOpMaker);