|
|
|
@ -25,18 +25,26 @@ class MulOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
// Infers the shape of output "Out" from inputs "X" and "Y".
//
// Each input may have more than two dimensions. The integer attributes
// `x_num_row_dims` / `y_num_row_dims` (registered under exactly these names
// by MulOpMaker) say how many trailing dimensions of the corresponding
// input are folded into the second dimension of its flattened matrix; the
// remaining leading dimensions are folded into the first dimension.
//
// Enforces that:
//   * the rank of each input is strictly larger than its `*_num_row_dims`;
//   * the flattened width of X equals the flattened height of Y.
//
// On success "Out" is resized to the 2-D shape
//   {flattened height of X, flattened width of Y}.
//
// NOTE(review): `product(dims, begin, end)` is presumably the product of
// the extents in the half-open index range [begin, end) -- confirm against
// its definition, which is not visible here.
void InferShape(const framework::InferShapeContext &ctx) const override {
  auto x_dim = ctx.Input<Tensor>("X")->dims();
  auto y_dim = ctx.Input<Tensor>("Y")->dims();

  // The attribute names must match the ones registered in MulOpMaker,
  // otherwise the lookup cannot find the attribute.
  int x_num_row_dims = GetAttr<int>("x_num_row_dims");
  int y_num_row_dims = GetAttr<int>("y_num_row_dims");

  // At least one leading dimension must remain after the trailing
  // `*_num_row_dims` dimensions have been folded away.
  PADDLE_ENFORCE(x_dim.size() > x_num_row_dims,
                 "The rank of input tensor X(%s) should be larger than "
                 "`mul_op`'s `x_num_row_dims`.",
                 ctx.op_.Input("X"));
  PADDLE_ENFORCE(y_dim.size() > y_num_row_dims,
                 "The rank of input tensor Y(%s) should be larger than "
                 "`mul_op`'s `y_num_row_dims`.",
                 ctx.op_.Input("Y"));

  // Flattened width of X (its trailing dims) must equal the flattened
  // height of Y (its leading dims).
  PADDLE_ENFORCE_EQ(
      product(x_dim, x_dim.size() - x_num_row_dims, x_dim.size()),
      product(y_dim, 0, y_dim.size() - y_num_row_dims),
      "First matrix's width must be equal with second matrix's height.");

  // Out = {flattened height of X, flattened width of Y}.
  ctx.Output<Tensor>("Out")->Resize(
      {product(x_dim, 0, x_dim.size() - x_num_row_dims),
       product(y_dim, y_dim.size() - y_num_row_dims, y_dim.size())});
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -47,6 +55,23 @@ class MulOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
AddInput("X", "The first input of mul op");
|
|
|
|
|
AddInput("Y", "The second input of mul op");
|
|
|
|
|
AddOutput("Out", "The output of mul op");
|
|
|
|
|
AddAttr<int>(
|
|
|
|
|
"x_num_row_dims",
|
|
|
|
|
"mul_op can take tensors with more than two dimensions as input `X`, "
|
|
|
|
|
"in that case, tensors will be flattened to a matrix. The matrix's "
|
|
|
|
|
"second dimension(row length) will be the product of tensor's last "
|
|
|
|
|
"`num_row_dims` dimensions, and the matrix's first dimension(column "
|
|
|
|
|
"length) will be the product of tensor's first `rank - num_row_dims` "
|
|
|
|
|
"dimensions.")
|
|
|
|
|
.SetDefault(1)
|
|
|
|
|
.EqualLargerThan(1);
|
|
|
|
|
AddAttr<int>(
|
|
|
|
|
"y_num_row_dims",
|
|
|
|
|
"mul_op can take tensors with more than two dimensions as input `Y`, "
|
|
|
|
|
"in that case, tensors will be flattened to a matrix. Just like input "
|
|
|
|
|
"`X`.")
|
|
|
|
|
.SetDefault(1)
|
|
|
|
|
.EqualLargerThan(1);
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
Two Element Mul Operator.
|
|
|
|
|
|
|
|
|
@ -70,10 +95,14 @@ class MulOpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
|
|
|
|
|
auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
|
|
|
|
|
auto *y_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));
|
|
|
|
|
PADDLE_ENFORCE(x_dims[0] == out_dims[0],
|
|
|
|
|
"Out@GRAD M X N must equal to X dims 0, M ");
|
|
|
|
|
PADDLE_ENFORCE(y_dims[1] == out_dims[1],
|
|
|
|
|
"Out@GRAD M X N must equal to Y dims 1, N ");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
product(x_dim, 0, x_dims.size() - x_num_row_dims) == out_dims[0],
|
|
|
|
|
"The first dimension of Out@GRAD must equal to the first dimension of "
|
|
|
|
|
"the first operand.");
|
|
|
|
|
PADDLE_ENFORCE(product(y_dim, y_dims.size() - y_num_row_dims,
|
|
|
|
|
y_dims.size()) == out_dims[1],
|
|
|
|
|
"The second dimension of Out@GRAD must equal to the second "
|
|
|
|
|
"dimension of the second operand.");
|
|
|
|
|
|
|
|
|
|
x_grad->Resize(x_dims);
|
|
|
|
|
y_grad->Resize(y_dims);
|
|
|
|
|