diff --git a/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.cc b/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.cc index 9c984a23e3..c0ebf6de9d 100644 --- a/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.cc @@ -238,11 +238,11 @@ REGISTER_PASS(conv_eltwiseadd_affine_channel_fuse_pass, REGISTER_PASS_CAPABILITY(conv_affine_channel_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .EQ("affine_channel", 0)); REGISTER_PASS_CAPABILITY(conv_eltwiseadd_affine_channel_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .EQ("elementwise_add", 0) .EQ("affine_channel", 0)); diff --git a/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc b/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc index a915015bf5..72ac7c3b0e 100644 --- a/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc @@ -383,11 +383,11 @@ REGISTER_PASS(depthwise_conv_eltwiseadd_bn_fuse_pass, REGISTER_PASS_CAPABILITY(conv_bn_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .EQ("batch_norm", 0)); REGISTER_PASS_CAPABILITY(conv_eltwiseadd_bn_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .EQ("elementwise_add", 0) .EQ("batch_norm", 0)); diff --git a/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc b/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc index ad6af69ae0..545beb34e7 100644 --- a/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc @@ -12,7 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.h" + #include + #include "paddle/fluid/framework/op_version_registry.h" namespace paddle { @@ -119,7 +121,7 @@ REGISTER_PASS(conv_elementwise_add2_act_fuse_pass, REGISTER_PASS_CAPABILITY(conv_elementwise_add2_act_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .EQ("elementwise_add", 0) .EQ("relu", 0) .EQ("identity", 0)); diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc b/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc index 93e6e13ff7..d01a2f2622 100644 --- a/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.h" + #include #include "paddle/fluid/framework/ir/graph_viz_pass.h" @@ -107,7 +108,7 @@ REGISTER_PASS(conv_elementwise_add_act_fuse_pass, REGISTER_PASS_CAPABILITY(conv_elementwise_add_act_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .EQ("elementwise_add", 0) .EQ("relu", 0) .EQ("identity", 0)); diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc b/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc index e4396f227f..e34a2d9658 100644 --- a/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.h" + #include #include "paddle/fluid/framework/ir/graph_viz_pass.h" @@ -93,5 +94,5 @@ REGISTER_PASS(conv_elementwise_add_fuse_pass, REGISTER_PASS_CAPABILITY(conv_elementwise_add_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .EQ("elementwise_add", 0)); diff --git a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc index c33398553e..d0bdeb9ad8 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc @@ -13,7 +13,9 @@ // limitations under the License. #include "paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h" + #include + #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/platform/enforce.h" @@ -107,7 +109,7 @@ REGISTER_PASS(conv_relu_mkldnn_fuse_pass, REGISTER_PASS_CAPABILITY(conv_relu_mkldnn_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .EQ("relu", 0)); REGISTER_PASS(conv_leaky_relu_mkldnn_fuse_pass, @@ -115,7 +117,7 @@ REGISTER_PASS(conv_leaky_relu_mkldnn_fuse_pass, REGISTER_PASS_CAPABILITY(conv_leaky_relu_mkldnn_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .LE("leaky_relu", 1)); REGISTER_PASS(conv_relu6_mkldnn_fuse_pass, @@ -123,7 +125,7 @@ REGISTER_PASS(conv_relu6_mkldnn_fuse_pass, REGISTER_PASS_CAPABILITY(conv_relu6_mkldnn_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .EQ("relu6", 0)); REGISTER_PASS(conv_swish_mkldnn_fuse_pass, @@ -131,5 +133,5 @@ REGISTER_PASS(conv_swish_mkldnn_fuse_pass, REGISTER_PASS_CAPABILITY(conv_swish_mkldnn_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .EQ("swish", 0)); diff --git a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc index 716c49dcb1..b0849d74b6 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc @@ -13,8 +13,10 @@ // limitations under the License. #include "paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h" + #include #include + #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/platform/enforce.h" @@ -150,7 +152,7 @@ REGISTER_PASS(conv_bias_mkldnn_fuse_pass, REGISTER_PASS_CAPABILITY(conv_bias_mkldnn_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .EQ("elementwise_add", 0)); REGISTER_PASS(conv_transpose_bias_mkldnn_fuse_pass, diff --git a/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc index 76e1021255..c4d7a12037 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc @@ -13,7 +13,9 @@ // limitations under the License. #include "paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.h" + #include + #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/platform/enforce.h" @@ -128,6 +130,6 @@ REGISTER_PASS(conv_concat_relu_mkldnn_fuse_pass, REGISTER_PASS_CAPABILITY(conv_concat_relu_mkldnn_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .EQ("concat", 0) .EQ("relu", 0)); diff --git a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc index 2fb131acea..a837b42b3e 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc @@ -13,11 +13,13 @@ // limitations under the License. #include "paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h" + #include #include #include #include #include + #include "paddle/fluid/framework/ir/graph_traits.h" #include "paddle/fluid/framework/op_version_registry.h" @@ -226,19 +228,20 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX( pattern->NewNode(elementwise_add_pattern.elementwise_add_y_repr())); conv_output->AsIntermediate(); - auto get_node_from_elementwise_add = [&elementwise_add_pattern]( - const GraphPatternDetector::subgraph_t& subgraph) + auto get_node_from_elementwise_add = + [&elementwise_add_pattern]( + const GraphPatternDetector::subgraph_t& subgraph) -> std::tuple { - GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op, - elementwise_add_pattern); - GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_y, elementwise_add_y, - elementwise_add_pattern); - GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out, - elementwise_add_pattern); - - return std::make_tuple(elementwise_add_op, elementwise_add_y, - elementwise_add_out); - }; + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op, + elementwise_add_pattern); + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_y, elementwise_add_y, + elementwise_add_pattern); + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out, + elementwise_add_pattern); + + return std::make_tuple(elementwise_add_op, elementwise_add_y, + elementwise_add_out); + }; return ExecuteHandleOnGraph( &gpd, graph_with_stats, @@ -263,19 +266,20 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY( conv_output); conv_output->AsIntermediate(); - auto get_node_from_elementwise_add = [&elementwise_add_pattern]( - const GraphPatternDetector::subgraph_t& subgraph) + auto get_node_from_elementwise_add = + [&elementwise_add_pattern]( + const GraphPatternDetector::subgraph_t& subgraph) -> std::tuple { - GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op, - elementwise_add_pattern); - GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x, - elementwise_add_pattern); - GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out, - elementwise_add_pattern); - - return std::make_tuple(elementwise_add_op, elementwise_add_x, - elementwise_add_out); - }; + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op, + elementwise_add_pattern); + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x, + elementwise_add_pattern); + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out, + elementwise_add_pattern); + + return std::make_tuple(elementwise_add_op, elementwise_add_x, + elementwise_add_out); + }; return ExecuteHandleOnGraph( &gpd, graph_with_stats, @@ -302,16 +306,17 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv( conv_x_output->AsIntermediate(); conv_y_output->AsIntermediate(); - auto get_node_from_elementwise_add = [&elementwise_add_pattern]( - const GraphPatternDetector::subgraph_t& subgraph) + auto get_node_from_elementwise_add = + [&elementwise_add_pattern]( + const GraphPatternDetector::subgraph_t& subgraph) -> std::tuple { - GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op, - elementwise_add_pattern); - GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out, - elementwise_add_pattern); + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op, + elementwise_add_pattern); + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out, + elementwise_add_pattern); - return std::make_tuple(elementwise_add_op, elementwise_add_out); - }; + return std::make_tuple(elementwise_add_op, elementwise_add_out); + }; return ExecuteHandleOnGraph( &gpd, graph_with_stats, @@ -345,5 +350,5 @@ REGISTER_PASS(conv_elementwise_add_mkldnn_fuse_pass, REGISTER_PASS_CAPABILITY(conv_elementwise_add_mkldnn_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .EQ("elementwise_add", 0)); diff --git a/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.cc b/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.cc index b2c0afdc75..39f47406a7 100644 --- a/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.cc @@ -63,5 +63,5 @@ REGISTER_PASS(depthwise_conv_mkldnn_pass, paddle::framework::ir::DepthwiseConvMKLDNNPass); REGISTER_PASS_CAPABILITY(depthwise_conv_mkldnn_pass) .AddCombination( - paddle::framework::compatible::OpVersionComparatorCombination().EQ( - "depthwise_conv2d", 0)); + paddle::framework::compatible::OpVersionComparatorCombination().LE( + "depthwise_conv2d", 1)); diff --git a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc index 895c396e1e..96c5546d21 100644 --- a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc +++ b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc @@ -12,13 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.h" + #include #include #include #include #include "paddle/fluid/framework/ir/graph_viz_pass.h" -#include "paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.h" #include "paddle/fluid/framework/op_version_registry.h" namespace paddle { @@ -331,7 +332,7 @@ REGISTER_PASS_CAPABILITY(quant_conv2d_dequant_fuse_pass); REGISTER_PASS_CAPABILITY(tensorrt_subgraph_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .EQ("fc", 0) .LE("conv2d_transpose", 1) .EQ("fake_quantize_abs_max", 0) diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc index 08f3d609fa..bf0d87da91 100644 --- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h" + #include #include #include @@ -20,7 +22,6 @@ #include "paddle/fluid/framework/ir/subgraph_detector.h" #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/inference/analysis/helper.h" -#include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h" #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/tensorrt/engine.h" #include "paddle/fluid/inference/tensorrt/op_teller.h" @@ -309,6 +310,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp( min_input_shape, max_input_shape, opt_input_shape, disable_trt_plugin_fp16); trt_engine->SetUseOSS(Get("use_oss")); + trt_engine->SetWithErnie( graph->Has(framework::ir::kEmbEltwiseLayernormPass) && graph->Has(framework::ir::kMultiheadMatmulPass)); @@ -367,13 +369,13 @@ REGISTER_PASS(tensorrt_subgraph_pass, REGISTER_PASS_CAPABILITY(tensorrt_subgraph_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() - .EQ("conv2d", 0) + .LE("conv2d", 1) .EQ("pool2d", 0) .EQ("relu", 0) .EQ("softmax", 0) .EQ("sigmoid", 0) .EQ("hard_swish", 0) - .EQ("depthwise_conv2d", 0) + .LE("depthwise_conv2d", 1) .EQ("batch_norm", 0) .EQ("concat", 0) .EQ("tanh", 0) diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index ef8a2b38f2..76ff1084fa 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -18,6 +18,8 @@ limitations under the License. */ #include #include +#include "paddle/fluid/framework/op_version_registry.h" + #ifdef PADDLE_WITH_CUDA #include "paddle/fluid/operators/conv_cudnn_op_cache.h" #include "paddle/fluid/platform/cudnn_helper.h" @@ -817,3 +819,36 @@ REGISTER_OP_CPU_KERNEL( conv3d_grad_grad, ops::GemmConvDoubleGradKernel, ops::GemmConvDoubleGradKernel); + +REGISTER_OP_VERSION(conv2d) + .AddCheckpoint( + R"ROC( + Upgrade conv2d, add a new attribute [use_addto]. + )ROC", + paddle::framework::compatible::OpVersionDesc().NewAttr( + "use_addto", + "In order to support new feature (inplace addto strategy) for " + "gradient accumulation.", + false)); + +REGISTER_OP_VERSION(depthwise_conv2d) + .AddCheckpoint( + R"ROC( + Upgrade depthwise_conv2d, add a new attribute [use_addto]. + )ROC", + paddle::framework::compatible::OpVersionDesc().NewAttr( + "use_addto", + "In order to support new feature (inplace addto strategy) for " + "gradient accumulation.", + false)); + +REGISTER_OP_VERSION(conv3d) + .AddCheckpoint( + R"ROC( + Upgrade conv3d, add a new attribute [use_addto]. + )ROC", + paddle::framework::compatible::OpVersionDesc().NewAttr( + "use_addto", + "In order to support new feature (inplace addto strategy) for " + "gradient accumulation.", + false));