map_matmul_to_mul_pass support 3dim (#31958)

develop
Pei Yang 4 years ago committed by GitHub
parent a37a7f67e1
commit 98e803e04f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -57,7 +57,7 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
std::vector<int64_t> y_shape = matmul_in_y->Var()->GetShape(); std::vector<int64_t> y_shape = matmul_in_y->Var()->GetShape();
size_t x_rank = x_shape.size(); size_t x_rank = x_shape.size();
size_t y_rank = y_shape.size(); size_t y_rank = y_shape.size();
flag = flag && x_rank == 2 && y_rank == 2; flag = flag && (x_rank == 2 || x_rank == 3) && y_rank == 2;
std::vector<Node*>& next_ops = matmul_out->outputs; std::vector<Node*>& next_ops = matmul_out->outputs;
flag = flag && next_ops.size() == 1 && flag = flag && next_ops.size() == 1 &&
@ -69,7 +69,7 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
desc.SetInput("X", {matmul_in_x->Name()}); desc.SetInput("X", {matmul_in_x->Name()});
desc.SetInput("Y", {matmul_in_y->Name()}); desc.SetInput("Y", {matmul_in_y->Name()});
desc.SetOutput("Out", {matmul_out->Name()}); desc.SetOutput("Out", {matmul_out->Name()});
desc.SetAttr("x_num_col_dims", 1); desc.SetAttr("x_num_col_dims", static_cast<int>(x_rank - 1));
desc.SetAttr("y_num_col_dims", 1); desc.SetAttr("y_num_col_dims", 1);
if (matmul_op->Op()->HasAttr("enable_int8")) { if (matmul_op->Op()->HasAttr("enable_int8")) {
desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8")); desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));

Loading…
Cancel
Save