From cc2f94620c537d2ff05862fe8445ad379008047c Mon Sep 17 00:00:00 2001 From: wawltor Date: Wed, 30 Dec 2020 18:43:43 +0800 Subject: [PATCH] add the support the op version check for matmul, test=op_version (#30011) * add the support the op version check for matmul, test=op_version --- paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc | 6 +++--- .../ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc | 2 +- .../framework/ir/mkldnn/scale_matmul_fuse_pass.cc | 2 +- .../fluid/framework/ir/multihead_matmul_fuse_pass.cc | 2 +- .../fluid/framework/ir/squared_mat_sub_fuse_pass.cc | 2 +- .../analysis/ir_passes/tensorrt_subgraph_pass.cc | 2 +- paddle/fluid/operators/matmul_op.cc | 12 ++++++++++++ 7 files changed, 20 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc b/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc index 76148a9007..8c4e6f3305 100644 --- a/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc +++ b/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc @@ -227,7 +227,7 @@ REGISTER_PASS(map_matmul_to_mul_pass, paddle::framework::ir::MapMatmul2MulPass); REGISTER_PASS_CAPABILITY(map_matmul_to_mul_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("matmul", 0) + .LE("matmul", 1) .EQ("mul", 0)); REGISTER_PASS(squeeze2_matmul_fuse_pass, @@ -235,7 +235,7 @@ REGISTER_PASS(squeeze2_matmul_fuse_pass, REGISTER_PASS_CAPABILITY(squeeze2_matmul_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("matmul", 0) + .LE("matmul", 1) .EQ("squeeze2", 0) .EQ("mul", 0)); @@ -244,6 +244,6 @@ REGISTER_PASS(reshape2_matmul_fuse_pass, REGISTER_PASS_CAPABILITY(reshape2_matmul_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("matmul", 0) + .LE("matmul", 1) .EQ("reshape2", 0) .EQ("mul", 0)); diff --git a/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc index 41b859f0af..fbc97a0a92 100644 --- a/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.cc @@ -103,6 +103,6 @@ REGISTER_PASS(matmul_transpose_reshape_fuse_pass, REGISTER_PASS_CAPABILITY(matmul_transpose_reshape_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("matmul", 0) + .LE("matmul", 1) .EQ("transpose", 0) .EQ("reshape", 0)); diff --git a/paddle/fluid/framework/ir/mkldnn/scale_matmul_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/scale_matmul_fuse_pass.cc index 0784a1a024..a552e42619 100644 --- a/paddle/fluid/framework/ir/mkldnn/scale_matmul_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/scale_matmul_fuse_pass.cc @@ -96,4 +96,4 @@ REGISTER_PASS_CAPABILITY(scale_matmul_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() .EQ("scale", 0) - .EQ("matmul", 0)); + .LE("matmul", 1)); diff --git a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc index bb9613d0c1..224272a5a0 100644 --- a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc +++ b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc @@ -720,5 +720,5 @@ REGISTER_PASS_CAPABILITY(multihead_matmul_fuse_pass_v2) .EQ("reshape2", 0) .EQ("transpose2", 0) .EQ("scale", 0) - .EQ("matmul", 0) + .LE("matmul", 1) .EQ("softmax", 0)); diff --git a/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc b/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc index d17212f4aa..c0420e6b5f 100644 --- a/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc +++ b/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc @@ -389,7 +389,7 @@ REGISTER_PASS(squared_mat_sub_fuse_pass, REGISTER_PASS_CAPABILITY(squared_mat_sub_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("matmul", 0) + .LE("matmul", 1) .EQ("matmul_v2", 0) .EQ("square", 0) .LE("elementwise_mul", 1) diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc index a67908fe7f..4bd804dfca 100644 --- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc @@ -396,4 +396,4 @@ REGISTER_PASS_CAPABILITY(tensorrt_subgraph_pass) .EQ("gelu", 0) .EQ("layer_norm", 0) .EQ("scale", 0) - .EQ("matmul", 0)); + .LE("matmul", 1)); diff --git a/paddle/fluid/operators/matmul_op.cc b/paddle/fluid/operators/matmul_op.cc index d45669a9f0..668445d242 100644 --- a/paddle/fluid/operators/matmul_op.cc +++ b/paddle/fluid/operators/matmul_op.cc @@ -16,6 +16,7 @@ limitations under the License. */ #include #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/operators/math/blas.h" #ifdef PADDLE_WITH_MKLDNN #include "paddle/fluid/platform/mkldnn_helper.h" @@ -932,3 +933,14 @@ REGISTER_OP_CUDA_KERNEL( ops::MatMulDoubleGradKernel, ops::MatMulDoubleGradKernel); #endif + +REGISTER_OP_VERSION(matmul) + .AddCheckpoint( + R"ROC(Register matmul for adding the attribute of + fused_reshape_Y)ROC", + paddle::framework::compatible::OpVersionDesc().NewAttr( + "fused_reshape_Y", + "In order to support the function of fused the input Y " + " and input X into the input X when " + "using the operator of matmul, and get raw shape of input Y.", + std::vector{}));