From e3f8e5cf5c8c8c28e14ecc04d75253947c54f924 Mon Sep 17 00:00:00 2001
From: Pei Yang
Date: Fri, 28 Aug 2020 14:11:37 +0800
Subject: [PATCH] trt int8 support conv2d_transpose (#26636)

---
 .../ir/quant_conv2d_dequant_fuse_pass.cc      | 30 +++++++++++++------
 .../inference/tensorrt/convert/conv2d_op.cc   |  8 ++++-
 .../slim/quantization/quantization_pass.py    |  1 +
 3 files changed, 29 insertions(+), 10 deletions(-)

diff --git a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc
index 4506c162fa..56ae02d49e 100644
--- a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc
@@ -81,7 +81,8 @@ void DeleteQuant(ir::Graph* graph, Scope* scope,
       if (quantized_op_type == "conv2d" ||
           quantized_op_type == "conv2d_fusion" ||
           quantized_op_type == "depthwise_conv2d" ||
-          quantized_op_type == "fc") {
+          quantized_op_type == "fc" ||
+          quantized_op_type == "conv2d_transpose") {
         op_desc->SetAttr("Input_scale", scale_value);
       } else if (quantized_op_type == "mul") {
         op_desc->SetAttr("X_scale", scale_value);
@@ -111,7 +112,8 @@ void FuseDequant(ir::Graph* graph, Scope* scope,
   std::string input_name = "";
   if (quantized_op_type == "conv2d" ||
       quantized_op_type == "depthwise_conv2d" ||
-      quantized_op_type == "conv2d_fusion") {
+      quantized_op_type == "conv2d_fusion" ||
+      quantized_op_type == "conv2d_transpose") {
     weight_name = "Filter";
     input_name = "Input";
   } else if (quantized_op_type == "mul") {
@@ -122,7 +124,8 @@ void FuseDequant(ir::Graph* graph, Scope* scope,
     input_name = "Input";
   } else {
     PADDLE_THROW(platform::errors::Unimplemented(
-        "QuantDequantFuse: We only support conv2d, conv2d_fusion, fc, mul for "
+        "QuantDequantFuse: We only support conv2d, conv2d_fusion, "
+        "conv2d_transpose, fc, mul for "
         "now."));
   }
   const std::string pattern_name = "dequant_fuse";
@@ -192,10 +195,12 @@ void FuseDequant(ir::Graph* graph, Scope* scope,
         scope->Var(quantized_op_weight_node->Name())->GetMutable<LoDTensor>();
     auto w_dims = weight_tensor->dims();
     // If quantized op is fc, weight scale size = 1;
-    // If quantized op is conv, weight scale size = weight dims[0]
+    // If quantized op is conv2d, weight scale size = weight dims[0]
+    // If quantized op is conv2d_transpose, weight scale size = weight dims[1]
     bool valid_scale_size =
         (weight_scale.size() == 1 ||
-         weight_scale.size() == static_cast<size_t>(w_dims[0]));
+         weight_scale.size() == static_cast<size_t>(w_dims[0]) ||
+         weight_scale.size() == static_cast<size_t>(w_dims[1]));
     PADDLE_ENFORCE_EQ(
         valid_scale_size, true,
         platform::errors::InvalidArgument(
@@ -206,8 +211,14 @@ void FuseDequant(ir::Graph* graph, Scope* scope,
       if (weight_scale.size() == 1) {
         quantized_weight_data[j] *= weight_scale[0];
       } else {
-        int inner_size = w_dims[1] * w_dims[2] * w_dims[3];
-        quantized_weight_data[j] *= weight_scale[j / inner_size];
+        if (quantized_op_type == "conv2d_transpose") {
+          int inner_size = w_dims[2] * w_dims[3];
+          quantized_weight_data[j] *=
+              weight_scale[(j / inner_size) % w_dims[1]];
+        } else {
+          int inner_size = w_dims[1] * w_dims[2] * w_dims[3];
+          quantized_weight_data[j] *= weight_scale[j / inner_size];
+        }
       }
     }

@@ -220,7 +231,8 @@ void FuseDequant(ir::Graph* graph, Scope* scope,
     new_op_desc.SetType(quantized_op_type);
     new_op_desc.SetAttr("enable_int8", true);
     if (quantized_op_type == "conv2d" || quantized_op_type == "conv2d_fusion" ||
-        quantized_op_type == "depthwise_conv2d") {
+        quantized_op_type == "depthwise_conv2d" ||
+        quantized_op_type == "conv2d_transpose") {
       new_op_desc.SetInput("Input", {new_input});
       new_op_desc.SetOutput("Output", {new_output});
     } else if (quantized_op_type == "fc") {
@@ -253,7 +265,7 @@ void QuantDequantFusePass::ApplyImpl(ir::Graph* graph) const {
   std::unordered_set<std::string> quant_types = {
       "fake_quantize_range_abs_max", "fake_quantize_moving_average_abs_max"};
   std::unordered_set<std::string> quantized_op_types = {
-      "conv2d", "mul", "depthwise_conv2d", "fc"};
+      "conv2d", "mul", "depthwise_conv2d", "fc", "conv2d_transpose"};
   auto* scope = param_scope();

   for (auto& quant_type : quant_types) {
diff --git a/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
index 97d09925b1..10c212c0b4 100644
--- a/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
@@ -51,7 +51,13 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op,

   if (enable_int8) {
 #if IS_TRT_VERSION_GE(5000)
-    CHECK(op_desc.HasAttr("Input_scale"));
+    if (op_desc.Type() != "conv2d_transpose") {
+      PADDLE_ENFORCE_EQ(
+          op_desc.HasAttr("Input_scale"), true,
+          platform::errors::InvalidArgument("Input scale not found. TRT int8"
+                                            " requires conv/deconv to have "
+                                            "input quantization scales."));
+    }
     float in_scale =
         BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")) * 127;
     auto weight_scale =
diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
index 14d1114a8f..b5a8d90194 100644
--- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
+++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
@@ -68,6 +68,7 @@ _out_scale_op_list = [
     "scale",
     "hard_swish",
     "hard_sigmoid",
+    "conv2d_transpose",
 ]

 # list op real input and output names, to avoid processing input such as AxisTensor.
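
Note (reviewer aside, not part of the patch): the core of this change is the
per-output-channel dequantization indexing in FuseDequant. conv2d filters are
laid out [out_c, in_c, h, w], so element j belongs to output channel
j / (in_c * h * w); conv2d_transpose filters are laid out [in_c, out_c, h, w],
so the output channel is (j / (h * w)) % out_c. Below is a minimal standalone
C++ sketch of that indexing; DequantizeFilter and its signature are
illustrative only, not Paddle API.

#include <cstdint>
#include <string>
#include <vector>

// Multiplies quantized filter values (stored as float here) by their
// per-output-channel scales, mirroring the indexing added in this patch.
void DequantizeFilter(const std::string& op_type,
                      const std::vector<int64_t>& w_dims,  // 4-D filter dims
                      const std::vector<float>& weight_scale,
                      std::vector<float>* weight_data) {
  for (size_t j = 0; j < weight_data->size(); ++j) {
    if (weight_scale.size() == 1) {
      // Single scale for the whole tensor (the fc case).
      (*weight_data)[j] *= weight_scale[0];
    } else if (op_type == "conv2d_transpose") {
      // Layout [in_c, out_c, h, w]: the output-channel index advances every
      // h*w elements and wraps around after out_c (= w_dims[1]) channels.
      int64_t inner_size = w_dims[2] * w_dims[3];
      (*weight_data)[j] *= weight_scale[(j / inner_size) % w_dims[1]];
    } else {
      // conv2d / depthwise_conv2d, layout [out_c, in_c, h, w]: each output
      // channel's elements are contiguous, so divide by the channel size.
      int64_t inner_size = w_dims[1] * w_dims[2] * w_dims[3];
      (*weight_data)[j] *= weight_scale[j / inner_size];
    }
  }
}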