|
|
|
@ -63,7 +63,7 @@ class RowwiseAddGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
"Input(Out@GRAD) should not be null");
|
|
|
|
|
auto dims0 = ctx.Input<Tensor>("X")->dims();
|
|
|
|
|
auto dims1 = ctx.Input<Tensor>("b")->dims();
|
|
|
|
|
PADDLE_ENFORCE_EQ(1, framework::product(dims1), "b dims should be 1")
|
|
|
|
|
PADDLE_ENFORCE_EQ(1, dims1.size(), "b dims should be 1")
|
|
|
|
|
ctx.Output<Tensor>(framework::GradVarName("X"))->Resize(dims0);
|
|
|
|
|
ctx.Output<Tensor>(framework::GradVarName("b"))->Resize(dims1);
|
|
|
|
|
}
|
|
|
|
|