|
|
|
@ -16,7 +16,7 @@
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
class RowWiseAddOp : public OperatorWithKernel {
|
|
|
|
|
class RowwiseAddOp : public OperatorWithKernel {
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(const InferShapeContext &ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx.InputSize() == 2UL,
|
|
|
|
@ -32,9 +32,9 @@ protected:
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class RowWiseAddOpMaker : public OpProtoAndCheckerMaker {
|
|
|
|
|
class RowwiseAddOpMaker : public OpProtoAndCheckerMaker {
|
|
|
|
|
public:
|
|
|
|
|
RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
|
|
|
|
|
RowwiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
|
AddInput("X", "The left input of row-wise add op, must be matrix");
|
|
|
|
|
AddInput("b", "The right input of row-wise add op, must be vector");
|
|
|
|
@ -46,13 +46,13 @@ for i in xrange(X.shape[0]):
|
|
|
|
|
)DOC");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
class RowWiseAddGradOp : public OperatorWithKernel {
|
|
|
|
|
class RowwiseAddGradOp : public OperatorWithKernel {
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(const InferShapeContext &ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx.InputSize() == 4UL,
|
|
|
|
|
"RowWiseAddGrad inputs is I, O, OG, size must be 4");
|
|
|
|
|
"RowwiseAddGrad inputs is I, O, OG, size must be 4");
|
|
|
|
|
PADDLE_ENFORCE(ctx.OutputSize() == 2,
|
|
|
|
|
"RowWiseAddGrad output is IG, size must be 2");
|
|
|
|
|
"RowwiseAddGrad output is IG, size must be 2");
|
|
|
|
|
ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
|
|
|
|
|
ctx.Output<Tensor>(1)->Resize(ctx.Input<Tensor>(1)->dims());
|
|
|
|
|
}
|
|
|
|
@ -61,10 +61,10 @@ protected:
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
REGISTER_OP(rowwise_add, ops::RowWiseAddOp, ops::RowWiseAddOpMaker);
|
|
|
|
|
REGISTER_OP(rowwise_add, ops::RowwiseAddOp, ops::RowwiseAddOpMaker);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(rowwise_add,
|
|
|
|
|
ops::RowWiseAddKernel<ops::CPUPlace, float>);
|
|
|
|
|
ops::RowwiseAddKernel<ops::CPUPlace, float>);
|
|
|
|
|
|
|
|
|
|
REGISTER_GRADIENT_OP(rowwise_add, rowwise_add_grad, ops::RowWiseAddGradOp);
|
|
|
|
|
REGISTER_GRADIENT_OP(rowwise_add, rowwise_add_grad, ops::RowwiseAddGradOp);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(rowwise_add_grad,
|
|
|
|
|
ops::RowWiseAddGradKernel<ops::CPUPlace, float>);
|
|
|
|
|
ops::RowwiseAddGradKernel<ops::CPUPlace, float>);
|
|
|
|
|