|
|
|
@ -17,6 +17,8 @@
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
using framework::Tensor;
|
|
|
|
|
|
|
|
|
|
class RowwiseAddOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
@ -50,14 +52,23 @@ for i in xrange(X.shape[0]):
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
class RowwiseAddGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(const framework::InferShapeContext &ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx.InputSize() == 4UL,
|
|
|
|
|
"RowwiseAddGrad inputs is I, O, OG, size must be 4");
|
|
|
|
|
PADDLE_ENFORCE(ctx.OutputSize() == 2,
|
|
|
|
|
"RowwiseAddGrad output is IG, size must be 2");
|
|
|
|
|
ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
|
|
|
|
|
ctx.Output<Tensor>(1)->Resize(ctx.Input<Tensor>(1)->dims());
|
|
|
|
|
// PADDLE_ENFORCE(ctx.InputSize() == 4UL,
|
|
|
|
|
// "RowwiseAddGrad inputs is I, O, OG, size must be 4");
|
|
|
|
|
// PADDLE_ENFORCE(ctx.OutputSize() == 2,
|
|
|
|
|
// "RowwiseAddGrad output is IG, size must be 2");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "X should not be null");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("b"), "b should not be null");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
|
|
|
|
|
"Input(Out@GRAD) should not be null");
|
|
|
|
|
auto dims0 = ctx.Input<Tensor>("X")->dims();
|
|
|
|
|
auto dims1 = ctx.Input<Tensor>("b")->dims();
|
|
|
|
|
ctx.Output<Tensor>(framework::GradVarName("X"))->Resize(dims0);
|
|
|
|
|
ctx.Output<Tensor>(framework::GradVarName("b"))->Resize(dims1);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|