|
|
|
@ -46,6 +46,17 @@ for i in xrange(X.shape[0]):
|
|
|
|
|
)DOC");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
class RowWiseAddGradOp : public OperatorWithKernel {
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(const InferShapeContext &ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx.InputSize() == 4UL,
|
|
|
|
|
"RowWiseAddGrad inputs is I, O, OG, size must be 4");
|
|
|
|
|
PADDLE_ENFORCE(ctx.OutputSize() == 2,
|
|
|
|
|
"RowWiseAddGrad output is IG, size must be 2");
|
|
|
|
|
ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
|
|
|
|
|
ctx.Output<Tensor>(1)->Resize(ctx.Input<Tensor>(1)->dims());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
@ -53,3 +64,7 @@ for i in xrange(X.shape[0]):
|
|
|
|
|
REGISTER_OP(rowwise_add, ops::RowWiseAddOp, ops::RowWiseAddOpMaker);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(rowwise_add,
|
|
|
|
|
ops::RowWiseAddKernel<ops::CPUPlace, float>);
|
|
|
|
|
|
|
|
|
|
REGISTER_GRADIENT_OP(rowwise_add, rowwise_add_grad, ops::RowWiseAddGradOp);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(rowwise_add_grad,
|
|
|
|
|
ops::RowWiseAddGradKernel<ops::CPUPlace, float>);
|
|
|
|
|