|
|
|
|
@ -25,14 +25,19 @@ class RowwiseAddOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(const framework::InferShapeContext &ctx) const override {
|
|
|
|
|
auto dim0 = ctx.Input<Tensor>("X")->dims();
|
|
|
|
|
auto dim1 = ctx.Input<Tensor>("b")->dims();
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(dim0.size() == 2, "Input 0 must be matrix");
|
|
|
|
|
PADDLE_ENFORCE(dim1.size() == 1, "The second input must be vector");
|
|
|
|
|
PADDLE_ENFORCE(dim0[1] == dim1[0], "The width of two input must be same");
|
|
|
|
|
PADDLE_ENFORCE(ctx.OutputSize("Out") == 1, "The output size must be 1");
|
|
|
|
|
ctx.Output<Tensor>("Out")->Resize(ctx.Input<Tensor>("X")->dims());
|
|
|
|
|
auto x_dims = ctx.Input<Tensor>("X")->dims();
|
|
|
|
|
auto b_dims = ctx.Input<Tensor>("b")->dims();
|
|
|
|
|
PADDLE_ENFORCE_GT(
|
|
|
|
|
x_dims.size(), b_dims.size(),
|
|
|
|
|
"The rank of input `X` must be larger than the one of input `b`.");
|
|
|
|
|
|
|
|
|
|
int num_row_dims = b_dims.size();
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(framework::slice_ddim(
|
|
|
|
|
x_dims, x_dims.size() - num_row_dims, x_dims.size()),
|
|
|
|
|
b_dims, "The width of two operands must be same");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx.OutputSize("Out"), 1, "The output size must be 1");
|
|
|
|
|
ctx.Output<Tensor>("Out")->Resize(x_dims);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
@ -61,13 +66,20 @@ class RowwiseAddGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("b"), "b should not be null");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
|
|
|
|
|
"Input(Out@GRAD) should not be null");
|
|
|
|
|
auto dims0 = ctx.Input<Tensor>("X")->dims();
|
|
|
|
|
auto dims1 = ctx.Input<Tensor>("b")->dims();
|
|
|
|
|
PADDLE_ENFORCE_EQ(1, dims1.size(), "b dims should be 1")
|
|
|
|
|
auto x_dims = ctx.Input<Tensor>("X")->dims();
|
|
|
|
|
auto b_dims = ctx.Input<Tensor>("b")->dims();
|
|
|
|
|
PADDLE_ENFORCE_GT(
|
|
|
|
|
x_dims.size(), b_dims.size(),
|
|
|
|
|
"The rank of input `X` must be larger than the one of input `b`.");
|
|
|
|
|
|
|
|
|
|
int num_row_dims = b_dims.size();
|
|
|
|
|
PADDLE_ENFORCE_EQ(framework::slice_ddim(
|
|
|
|
|
x_dims, x_dims.size() - num_row_dims, x_dims.size()),
|
|
|
|
|
b_dims, "The width of two operands must be same");
|
|
|
|
|
auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
|
|
|
|
|
auto *db = ctx.Output<Tensor>(framework::GradVarName("b"));
|
|
|
|
|
if (dx) dx->Resize(dims0);
|
|
|
|
|
if (db) db->Resize(dims1);
|
|
|
|
|
if (dx) dx->Resize(x_dims);
|
|
|
|
|
if (db) db->Resize(b_dims);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|