|
|
|
@ -17,7 +17,9 @@
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
class RowWiseAddOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::Tensor;
|
|
|
|
|
|
|
|
|
|
class RowwiseAddOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
@ -34,9 +36,9 @@ class RowWiseAddOp : public framework::OperatorWithKernel {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class RowWiseAddOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
class RowwiseAddOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
public:
|
|
|
|
|
RowWiseAddOpMaker(framework::OpProto *proto,
|
|
|
|
|
RowwiseAddOpMaker(framework::OpProto *proto,
|
|
|
|
|
framework::OpAttrChecker *op_checker)
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
|
AddInput("X", "The left input of row-wise add op, must be matrix");
|
|
|
|
@ -49,12 +51,32 @@ for i in xrange(X.shape[0]):
|
|
|
|
|
)DOC");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
class RowwiseAddGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(const framework::InferShapeContext &ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "X should not be null");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("b"), "b should not be null");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
|
|
|
|
|
"Input(Out@GRAD) should not be null");
|
|
|
|
|
auto dims0 = ctx.Input<Tensor>("X")->dims();
|
|
|
|
|
auto dims1 = ctx.Input<Tensor>("b")->dims();
|
|
|
|
|
PADDLE_ENFORCE_EQ(1, dims1.size(), "b dims should be 1")
|
|
|
|
|
ctx.Output<Tensor>(framework::GradVarName("X"))->Resize(dims0);
|
|
|
|
|
ctx.Output<Tensor>(framework::GradVarName("b"))->Resize(dims1);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
REGISTER_OP_WITHOUT_GRADIENT(rowwise_add, ops::RowWiseAddOp,
|
|
|
|
|
ops::RowWiseAddOpMaker);
|
|
|
|
|
REGISTER_OP(rowwise_add, ops::RowwiseAddOp, ops::RowwiseAddOpMaker,
|
|
|
|
|
rowwise_add_grad, ops::RowwiseAddGradOp);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
rowwise_add, ops::RowwiseAddKernel<paddle::platform::CPUPlace, float>);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
rowwise_add, ops::RowWiseAddKernel<paddle::platform::CPUPlace, float>);
|
|
|
|
|
rowwise_add_grad,
|
|
|
|
|
ops::RowwiseAddGradKernel<paddle::platform::CPUPlace, float>);
|
|
|
|
|