|
|
|
@ -27,8 +27,8 @@ class MulOp : public framework::OperatorWithKernel {
|
|
|
|
|
void InferShape(const framework::InferShapeContext &ctx) const override {
|
|
|
|
|
auto x_dims = ctx.Input<Tensor>("X")->dims();
|
|
|
|
|
auto y_dims = ctx.Input<Tensor>("Y")->dims();
|
|
|
|
|
int x_num_col_dims = GetAttr<int>("x_num_col_dims");
|
|
|
|
|
int y_num_col_dims = GetAttr<int>("y_num_col_dims");
|
|
|
|
|
int x_num_col_dims = Attr<int>("x_num_col_dims");
|
|
|
|
|
int y_num_col_dims = Attr<int>("y_num_col_dims");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(x_dims.size() > x_num_col_dims,
|
|
|
|
|
"The rank of input tensor X(%s) should be larger than "
|
|
|
|
@ -58,19 +58,19 @@ class MulOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
AddOutput("Out", "The output of mul op");
|
|
|
|
|
AddAttr<int>(
|
|
|
|
|
"x_num_col_dims",
|
|
|
|
|
"mul_op can take tensors with more than two dimensions as input `X`, "
|
|
|
|
|
"in that case, tensors will be reshaped to a matrix. The matrix's "
|
|
|
|
|
"first dimension(column length) will be the product of tensor's last "
|
|
|
|
|
"`num_col_dims` dimensions, and the matrix's second dimension(row "
|
|
|
|
|
"length) will be the product of tensor's first `rank - num_col_dims` "
|
|
|
|
|
"dimensions.")
|
|
|
|
|
R"DOC(mul_op can take tensors with more than two dimensions as input `X`,
|
|
|
|
|
in that case, tensors will be reshaped to a matrix. The matrix's first
|
|
|
|
|
dimension(column length) will be the product of tensor's last
|
|
|
|
|
`num_col_dims` dimensions, and the matrix's second dimension(row length)
|
|
|
|
|
will be the product of tensor's first `rank - num_col_dims` dimensions.
|
|
|
|
|
)DOC")
|
|
|
|
|
.SetDefault(1)
|
|
|
|
|
.EqualLargerThan(1);
|
|
|
|
|
AddAttr<int>(
|
|
|
|
|
"y_num_col_dims",
|
|
|
|
|
"mul_op can take tensors with more than two dimensions as input `Y`, "
|
|
|
|
|
"in that case, tensors will be reshaped to a matrix. Just like input "
|
|
|
|
|
"`X`.")
|
|
|
|
|
R"DOC(mul_op can take tensors with more than two dimensions as input `Y`,
|
|
|
|
|
in that case, tensors will be reshaped to a matrix. Just like input `X`.
|
|
|
|
|
)DOC")
|
|
|
|
|
.SetDefault(1)
|
|
|
|
|
.EqualLargerThan(1);
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
@ -98,9 +98,9 @@ class MulOpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
auto *y_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));
|
|
|
|
|
|
|
|
|
|
auto x_mat_dims =
|
|
|
|
|
framework::flatten_to_2d(x_dims, GetAttr<int>("x_num_col_dims"));
|
|
|
|
|
framework::flatten_to_2d(x_dims, Attr<int>("x_num_col_dims"));
|
|
|
|
|
auto y_mat_dims =
|
|
|
|
|
framework::flatten_to_2d(y_dims, GetAttr<int>("y_num_col_dims"));
|
|
|
|
|
framework::flatten_to_2d(y_dims, Attr<int>("y_num_col_dims"));
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
x_mat_dims[0], out_dims[0],
|
|
|
|
|