|
|
|
@ -35,7 +35,7 @@ class RowwiseAddOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
class RowwiseAddOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
public:
|
|
|
|
|
RowWiseAddOpMaker(framework::OpProto *proto,
|
|
|
|
|
RowwiseAddOpMaker(framework::OpProto *proto,
|
|
|
|
|
framework::OpAttrChecker *op_checker)
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
|
AddInput("X", "The left input of row-wise add op, must be matrix");
|
|
|
|
@ -48,9 +48,9 @@ for i in xrange(X.shape[0]):
|
|
|
|
|
)DOC");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
class RowwiseAddGradOp : public OperatorWithKernel {
|
|
|
|
|
class RowwiseAddGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(const InferShapeContext &ctx) const override {
|
|
|
|
|
void InferShape(const framework::InferShapeContext &ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx.InputSize() == 4UL,
|
|
|
|
|
"RowwiseAddGrad inputs is I, O, OG, size must be 4");
|
|
|
|
|
PADDLE_ENFORCE(ctx.OutputSize() == 2,
|
|
|
|
|