|
|
@ -57,7 +57,7 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
|
|
|
|
std::vector<int64_t> y_shape = matmul_in_y->Var()->GetShape();
|
|
|
|
std::vector<int64_t> y_shape = matmul_in_y->Var()->GetShape();
|
|
|
|
size_t x_rank = x_shape.size();
|
|
|
|
size_t x_rank = x_shape.size();
|
|
|
|
size_t y_rank = y_shape.size();
|
|
|
|
size_t y_rank = y_shape.size();
|
|
|
|
flag = flag && x_rank == 2 && y_rank == 2;
|
|
|
|
flag = flag && (x_rank == 2 || x_rank == 3) && y_rank == 2;
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<Node*>& next_ops = matmul_out->outputs;
|
|
|
|
std::vector<Node*>& next_ops = matmul_out->outputs;
|
|
|
|
flag = flag && next_ops.size() == 1 &&
|
|
|
|
flag = flag && next_ops.size() == 1 &&
|
|
|
@ -69,7 +69,7 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
|
|
|
|
desc.SetInput("X", {matmul_in_x->Name()});
|
|
|
|
desc.SetInput("X", {matmul_in_x->Name()});
|
|
|
|
desc.SetInput("Y", {matmul_in_y->Name()});
|
|
|
|
desc.SetInput("Y", {matmul_in_y->Name()});
|
|
|
|
desc.SetOutput("Out", {matmul_out->Name()});
|
|
|
|
desc.SetOutput("Out", {matmul_out->Name()});
|
|
|
|
desc.SetAttr("x_num_col_dims", 1);
|
|
|
|
desc.SetAttr("x_num_col_dims", static_cast<int>(x_rank - 1));
|
|
|
|
desc.SetAttr("y_num_col_dims", 1);
|
|
|
|
desc.SetAttr("y_num_col_dims", 1);
|
|
|
|
if (matmul_op->Op()->HasAttr("enable_int8")) {
|
|
|
|
if (matmul_op->Op()->HasAttr("enable_int8")) {
|
|
|
|
desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
|
|
|
|
desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
|
|
|
|