|
|
|
@ -68,16 +68,16 @@ class MulOpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
"Input(Out@GRAD) should not be null");
|
|
|
|
|
auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
|
|
|
|
|
auto *y_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));
|
|
|
|
|
auto dim0 = ctx.Output<Tensor>(framework::GradVarName("X"))->dims();
|
|
|
|
|
auto dim1 = ctx.Output<Tensor>(framework::GradVarName("Y"))->dims();
|
|
|
|
|
auto x_dims = ctx.Output<Tensor>(framework::GradVarName("X"))->dims();
|
|
|
|
|
auto y_dims = ctx.Output<Tensor>(framework::GradVarName("Y"))->dims();
|
|
|
|
|
auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
|
|
|
|
|
PADDLE_ENFORCE(dim0[0] * dim1[0] == out_dims[0],
|
|
|
|
|
"Out@GRAD[0] must equal to X[0] * Y[0]");
|
|
|
|
|
PADDLE_ENFORCE(dim0[1] * dim1[1] == out_dims[1],
|
|
|
|
|
"Out@GRAD shape must equal to X[1] * Y[1]");
|
|
|
|
|
PADDLE_ENFORCE(x_dims[0] == out_dims[0],
|
|
|
|
|
"Out@GRAD M X N must equal to X dims 0, M ");
|
|
|
|
|
PADDLE_ENFORCE(y_dims[1] == out_dims[1],
|
|
|
|
|
"Out@GRAD M X N must equal to Y dims 1, N ");
|
|
|
|
|
|
|
|
|
|
x_grad->Resize(dim1);
|
|
|
|
|
y_grad->Resize(dim0);
|
|
|
|
|
x_grad->Resize(x_dims);
|
|
|
|
|
y_grad->Resize(y_dims);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|