|
|
|
@ -25,27 +25,27 @@ class MulOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
// Validates the shapes of inputs "X" and "Y" for the mul operator and sets
// the shape of output "Out".
//
// Steps:
//   1. Read the dims of "X" and "Y" and the integer attributes
//      `x_num_row_dims` / `y_num_row_dims`.
//   2. Enforce that each input's rank is strictly larger than its
//      `*_num_row_dims` attribute.
//   3. Reduce both dims to 2-D via framework::flatten_to_2d and enforce
//      that the inner extents agree (x_mat_dims[1] == y_mat_dims[0]).
//   4. Resize "Out" to {x_mat_dims[0], y_mat_dims[1]}.
//
// Any failed check is reported through PADDLE_ENFORCE / PADDLE_ENFORCE_EQ.
void InferShape(const framework::InferShapeContext &ctx) const override {
  auto x_dims = ctx.Input<Tensor>("X")->dims();
  auto y_dims = ctx.Input<Tensor>("Y")->dims();
  int x_num_row_dims = GetAttr<int>("x_num_row_dims");
  int y_num_row_dims = GetAttr<int>("y_num_row_dims");

  // The rank must exceed `*_num_row_dims`.
  // NOTE(review): presumably so that at least one trailing dimension is left
  // to form the column extent of the flattened matrix — confirm against
  // flatten_to_2d.
  PADDLE_ENFORCE(x_dims.size() > x_num_row_dims,
                 "The rank of input tensor X(%s) should be larger than "
                 "`mul_op`'s `x_num_row_dims`.",
                 ctx.op().Input("X"));
  PADDLE_ENFORCE(y_dims.size() > y_num_row_dims,
                 "The rank of input tensor Y(%s) should be larger than "
                 "`mul_op`'s `y_num_row_dims`.",
                 ctx.op().Input("Y"));

  // Collapse each input to a two-element (rows, cols) shape; the split point
  // is given by the corresponding `*_num_row_dims` attribute.
  auto x_mat_dims = framework::flatten_to_2d(x_dims, x_num_row_dims);
  auto y_mat_dims = framework::flatten_to_2d(y_dims, y_num_row_dims);

  PADDLE_ENFORCE_EQ(
      x_mat_dims[1], y_mat_dims[0],
      "First matrix's width must be equal with second matrix's height.");
  ctx.Output<Tensor>("Out")->Resize({x_mat_dims[0], y_mat_dims[1]});
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -96,14 +96,18 @@ class MulOpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
|
|
|
|
|
auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
|
|
|
|
|
auto *y_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
product(x_dims, 0, x_dims.size() - GetAttr<int>("x_num_row_dims")) ==
|
|
|
|
|
out_dims[0],
|
|
|
|
|
|
|
|
|
|
auto x_mat_dims =
|
|
|
|
|
framework::flatten_to_2d(x_dims, GetAttr<int>("x_num_row_dims"));
|
|
|
|
|
auto y_mat_dims =
|
|
|
|
|
framework::flatten_to_2d(y_dims, GetAttr<int>("y_num_row_dims"));
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
x_mat_dims[0], out_dims[0],
|
|
|
|
|
"The first dimension of Out@GRAD must equal to the first dimension of "
|
|
|
|
|
"the first operand.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
product(y_dims, y_dims.size() - GetAttr<int>("y_num_row_dims"),
|
|
|
|
|
y_dims.size()) == out_dims[1],
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
y_mat_dims[1], out_dims[1],
|
|
|
|
|
"The second dimension of Out@GRAD must equal to the second "
|
|
|
|
|
"dimension of the second operand.");
|
|
|
|
|
|
|
|
|
|