[Paddle-TRT] Better Paddle-TensorRT support for PaddleSlim quant models (#25097)

* Paddle-TensorRT support slim QAT. test=develop

* add comments. test=develop

* use RenameInput instead of ResetInputs. test=develop
fix_copy_if_different
Pei Yang 5 years ago committed by GitHub
parent a965ac4c61
commit b2f5a149e7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1980,99 +1980,58 @@ PDNode *patterns::TransposeFlattenConcat::operator()(
return concat_out;
}
void patterns::QuantDequantOpFuse::operator()(PDNode *quant_op_input,
const std::string &op_type,
const std::string &weight_name,
int times,
const std::string &quant_type,
const std::string &dequant_type) {
int kNumFields = 5;
const int kQuantizedWeightOffset = 0;
const int kQuantizedOpOffset = 1;
const int kQuantizedOpOutOffset = 2;
const int kDequantOpOffset = 3;
const int kDequantOpOutOffset = 4;
const int kDequantOpWeightScaleOffset = 5;
// the quant op always be one.
auto quant_op_in_scale = pattern->NewNode(GetNodeName("quant_op_in_scale"))
void patterns::DeleteQuantOpFuse::operator()(PDNode *input_act_node,
const std::string &quant_type) {
auto *input_scale_node = pattern->NewNode(GetNodeName("input_scale_node"))
->assert_is_op_input(quant_type, "InScale")
->AsInput();
auto quant_op =
pattern->NewNode(GetNodeName("quant_op"))->assert_is_op(quant_type);
PDNode *quant_op_out_scale = nullptr;
auto *quant_node =
pattern->NewNode(GetNodeName("quant_node"))->assert_is_op(quant_type);
auto *output_scale_node = pattern->NewNode(GetNodeName("output_scale_node"))
->assert_is_op_output(quant_type, "OutScale")
->AsOutput();
auto *output_act_node = pattern->NewNode(GetNodeName("output_act_node"))
->assert_is_op_output(quant_type, "Out")
->AsOutput();
quant_node->LinksFrom({input_scale_node, input_act_node});
output_scale_node->LinksFrom({quant_node});
output_act_node->LinksFrom({quant_node});
}
void patterns::DequantOpFuse::operator()(PDNode *quantized_op_input,
const std::string &quantized_op_type,
const std::string &dequant_type,
const std::string &weight_name) {
auto *quantized_op_weight =
pattern->NewNode(GetNodeName("quantized_op_weight"))
->assert_is_op_input(quantized_op_type, weight_name)
->AsInput();
auto *quantized_op = pattern->NewNode(GetNodeName("quantized_op"))
->assert_is_op(quantized_op_type);
auto *quantized_op_out = pattern->NewNode(GetNodeName("quantized_op_out"))
->assert_is_op_output(quantized_op_type)
->assert_is_op_input(dequant_type, "X");
auto *dequant_op =
pattern->NewNode(GetNodeName("dequant_op"))->assert_is_op(dequant_type);
auto *dequant_op_out = pattern->NewNode(GetNodeName("dequant_op_out"))
->assert_is_op_output(dequant_type, "Out")
->AsOutput();
PDNode *dequant_channel_scale = nullptr;
if (dequant_type == "fake_channel_wise_dequantize_max_abs") {
kNumFields += 1;
quant_op_out_scale = pattern->NewNode(GetNodeName("quant_op_out_scale"))
->assert_is_op_output(quant_type, "OutScale")
->assert_is_op_nth_input(dequant_type, "Scales", 1)
->AsIntermediate();
} else {
quant_op_out_scale = pattern->NewNode(GetNodeName("quant_op_out_scale"))
->assert_is_op_output(quant_type, "OutScale")
->assert_is_op_input(dequant_type, "Scale")
->AsIntermediate();
dequant_channel_scale =
pattern->NewNode(GetNodeName("dequant_channel_scale"))
->assert_is_op_nth_input(dequant_type, "Scales", 0)
->AsInput();
}
quantized_op->LinksFrom({quantized_op_input, quantized_op_weight});
quantized_op_out->LinksFrom({quantized_op});
auto quant_op_out = pattern->NewNode(GetNodeName("quant_op_out"))
->assert_is_op_output(quant_type, "Out")
->assert_is_op_input(op_type)
->AsIntermediate();
// there are 'times' quantized and dequant op
std::vector<PDNode *> nodes;
for (int i = 0; i < times; i++) {
nodes.push_back(
pattern->NewNode(GetNodeName("quantized_op_weight") + std::to_string(i))
->assert_is_op_input(op_type, weight_name)
->AsInput());
nodes.push_back(
pattern->NewNode(GetNodeName("quantized_op") + std::to_string(i))
->assert_is_op(op_type));
nodes.push_back(
pattern->NewNode(GetNodeName("quantized_op_out") + std::to_string(i))
->assert_is_op_output(op_type)
->assert_is_op_input(dequant_type, "X")
->AsIntermediate());
nodes.push_back(
pattern->NewNode(GetNodeName("dequant_op") + std::to_string(i))
->assert_is_op(dequant_type));
nodes.push_back(
pattern->NewNode(GetNodeName("dequant_op_out") + std::to_string(i))
->assert_is_op_output(dequant_type, "Out")
->AsOutput());
if (dequant_type == "fake_channel_wise_dequantize_max_abs") {
nodes.push_back(pattern
->NewNode(GetNodeName("dequant_channel_scale") +
std::to_string(i))
->assert_is_op_nth_input(dequant_type, "Scales", 0)
->AsInput());
}
}
quant_op->LinksFrom({quant_op_input, quant_op_in_scale});
quant_op_out->LinksFrom({quant_op});
for (int i = 0; i < times; i++) {
nodes[i * kNumFields + kQuantizedOpOffset]->LinksFrom(
{quant_op_out, nodes[i * kNumFields + kQuantizedWeightOffset]});
nodes[i * kNumFields + kQuantizedOpOutOffset]->LinksFrom(
{nodes[i * kNumFields + kQuantizedOpOffset]});
if (dequant_type == "fake_channel_wise_dequantize_max_abs") {
nodes[i * kNumFields + kDequantOpOffset]->LinksFrom(
{nodes[i * kNumFields + kQuantizedOpOutOffset], quant_op_out_scale,
nodes[i * kNumFields + kDequantOpWeightScaleOffset]});
} else {
nodes[i * kNumFields + kDequantOpOffset]->LinksFrom(
{nodes[i * kNumFields + kQuantizedOpOutOffset], quant_op_out_scale});
}
nodes[i * kNumFields + kDequantOpOutOffset]->LinksFrom(
{nodes[i * kNumFields + kDequantOpOffset]});
if (dequant_type == "fake_channel_wise_dequantize_max_abs") {
dequant_op->LinksFrom({quantized_op_out, dequant_channel_scale});
} else {
dequant_op->LinksFrom({quantized_op_out});
}
dequant_op_out->LinksFrom({dequant_op});
}
void patterns::ShuffleChannelPattern::operator()(PDNode *reshape1_in) {

@ -1150,14 +1150,28 @@ struct TransposeFlattenConcat : public PatternBase {
}
};
struct QuantDequantOpFuse : public PatternBase {
QuantDequantOpFuse(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "quant_dequant_fuse") {}
void operator()(PDNode* quant_op_input, const std::string& op_name,
const std::string& weight_name, int times,
const std::string& quant_type,
const std::string& dequant_type);
struct DeleteQuantOpFuse : public PatternBase {
DeleteQuantOpFuse(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "delete_quant_fuse") {}
void operator()(PDNode* input_act_node, const std::string& quant_type);
std::string GetNodeName(const std::string& op_type) {
return PDNodeName(name_scope_, repr_, id_, op_type);
}
PDNode* GetPDNode(const std::string& op_type) {
return pattern->RetrieveNode(GetNodeName(op_type));
}
};
struct DequantOpFuse : public PatternBase {
DequantOpFuse(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "dequant_fuse") {}
void operator()(PDNode* quant_op_input, const std::string& quantized_op_type,
const std::string& dequant_type,
const std::string& weight_name);
std::string GetNodeName(const std::string& op_type) {
return PDNodeName(name_scope_, repr_, id_, op_type);

File diff suppressed because it is too large Load Diff

@ -22,6 +22,9 @@ namespace paddle {
namespace framework {
namespace ir {
///
/// Fuse quant + conv2d/depthwise_conv2d/mul/fc + dequant
///
class QuantDequantFusePass : public FusePassBase {
public:
virtual ~QuantDequantFusePass() {}

@ -365,6 +365,10 @@ const std::vector<std::string> &OpDesc::Output(const std::string &name) const {
return it->second;
}
bool OpDesc::HasOutput(const std::string &name) const {
return outputs_.find(name) != outputs_.end();
}
std::vector<std::string> OpDesc::OutputArgumentNames() const {
std::vector<std::string> retv;
for (auto &ipt : this->outputs_) {

@ -57,6 +57,8 @@ class OpDesc {
const std::vector<std::string> &Output(const std::string &name) const;
bool HasOutput(const std::string &name) const;
std::vector<std::string> OutputArgumentNames() const;
void SetOutput(const std::string &param_name,

@ -281,11 +281,8 @@ void AnalysisConfig::Update() {
if (use_tensorrt_) {
pass_builder()->ClearPasses();
bool use_calib_int8 =
(tensorrt_precision_mode_ == AnalysisConfig::Precision::kInt8) &&
trt_use_calib_mode_;
for (const auto &pass : kTRTSubgraphPasses) {
if (use_calib_int8 &&
if (tensorrt_precision_mode_ == AnalysisConfig::Precision::kInt8 &&
(pass == "conv_bn_fuse_pass" || pass == "fc_fuse_pass")) {
continue;
}

@ -52,7 +52,8 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op,
if (enable_int8) {
#if IS_TRT_VERSION_GE(5000)
CHECK(op_desc.HasAttr("Input_scale"));
float in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale"));
float in_scale =
BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")) * 127;
auto weight_scale =
BOOST_GET_CONST(std::vector<float>, op_desc.GetAttr("weight_scale"));
weight_data = engine->GetWeightCPUData(op_desc.Input("Filter").front(), Y_t,

@ -62,7 +62,7 @@ class FcOpConverter : public OpConverter {
#if IS_TRT_VERSION_GE(5000)
CHECK(op_desc.HasAttr(i_name + "_scale"));
float in_scale =
BOOST_GET_CONST(float, op_desc.GetAttr(i_name + "_scale"));
BOOST_GET_CONST(float, op_desc.GetAttr(i_name + "_scale")) * 127;
auto weight_scale =
BOOST_GET_CONST(std::vector<float>, op_desc.GetAttr("weight_scale"));
weight_data = engine_->GetWeightCPUData(op_desc.Input(w_name).front(),

@ -98,8 +98,33 @@ class OpConverter {
}
PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]",
op_desc.Type());
it->SetEngine(engine);
(*it)(op, scope, test_mode);
bool has_out_scale = op_desc.HasAttr("out_threshold");
if (has_out_scale) {
float out_scale =
BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold"));
std::string output_name = "";
if (op_desc.HasOutput("Output")) {
output_name = op_desc.Output("Output").front();
} else if (op_desc.HasOutput("Out")) {
output_name = op_desc.Output("Out").front();
} else if (op_desc.HasOutput("Y")) {
output_name = op_desc.Output("Y").front();
} else {
PADDLE_THROW(
platform::errors::NotFound("Op %s has out threshold but doesn't "
"have an output named \"Output\", "
"\"Out\" or \"Y\".",
op_desc.Type()));
}
auto* output_itensor = engine->GetITensor(output_name);
engine->SetTensorDynamicRange(output_itensor, out_scale);
VLOG(1) << "Set out scale = " << out_scale << " for tensor "
<< output_name << ".";
}
}
// Convert a fluid block to tensorrt network, NOTE it just convert operators,

@ -124,23 +124,42 @@ void TensorRTEngine::FreezeNetwork() {
<< ", this might be ok when trt does not need this range";
}
}
std::unordered_set<std::string> all_out_t_name;
for (int i = 0; i < network()->getNbOutputs(); i++) {
auto *temp = network()->getOutput(i);
temp->setDynamicRange(-1, 1);
all_out_t_name.insert(temp->getName());
}
for (int i = 0; i < network()->getNbLayers(); i++) {
auto layer = network()->getLayer(i);
auto is_layer_int8 = [&](nvinfer1::ILayer *layer) -> bool {
for (int j = 0; j < layer->getNbInputs(); j++) {
auto *temp_in = layer->getInput(j);
if (!temp_in->dynamicRangeIsSet()) {
VLOG(1) << "Layer(Name: " << layer->getName()
<< ") is set to float32 because its input("
<< temp_in->getName() << ") doesn't have dynamic range.";
return false;
}
}
for (int j = 0; j < layer->getNbOutputs(); j++) {
auto *temp_out = layer->getOutput(j);
if (std::find(all_out_t_name.begin(), all_out_t_name.end(),
temp_out->getName()) != all_out_t_name.end()) {
layer->setPrecision(nvinfer1::DataType::kFLOAT);
layer->setOutputType(j, nvinfer1::DataType::kFLOAT);
if (temp_out->isNetworkOutput()) {
VLOG(1) << "Layer(Name: " << layer->getName()
<< ") is set to float32 because its output("
<< temp_out->getName() << ") is the output of the network.";
return false;
}
if (!temp_out->dynamicRangeIsSet()) {
VLOG(1) << "Layer(Name: " << layer->getName()
<< ") is set to float32 because its output("
<< temp_out->getName() << ") doesn't have dynamic range.";
return false;
}
}
return true;
};
// If a layer's output is the network's output, or not all of its inputs
// and outputs have scales,
// this layer's precision and output type are set to float32.
// This step has no effect if this layer is fused during TRT optimization.
for (int i = 0; i < network()->getNbLayers(); i++) {
auto layer = network()->getLayer(i);
if (!is_layer_int8(layer)) {
layer->setPrecision(nvinfer1::DataType::kFLOAT);
}
}
#endif
}
@ -237,7 +256,6 @@ float *TensorRTEngine::GetWeightCPUData(const std::string &name,
std::string name_suffix = std::to_string(name_suffix_counter);
std::string splitter = "__";
std::string name_with_suffix = name + splitter + name_suffix;
auto w_dims = weight_tensor->dims();
platform::CPUPlace cpu_place;
PADDLE_ENFORCE_EQ(
weight_map.count(name_with_suffix), 0,
@ -250,25 +268,6 @@ float *TensorRTEngine::GetWeightCPUData(const std::string &name,
float *weight_data =
weight_map[name_with_suffix]->mutable_data<float>(cpu_place);
name_suffix_counter += 1;
if (enable_int8) {
// when the op is fc, scale's size should be 1
// when the op is conv, scale's size should be w_dims[0]
bool valid_scale_size =
(scale.size() == 1 || scale.size() == static_cast<size_t>(w_dims[0]));
PADDLE_ENFORCE(valid_scale_size, "TRT int8 quant: invalid scale size");
for (int i = 0; i < weight_tensor->numel(); i++) {
if (scale.size() == 1) {
weight_data[i] *= (scale[0] / 127);
} else {
PADDLE_ENFORCE(w_dims.size() == 4,
"TRT int8 quant : We only use the channel quant for "
"conv op, so the weight dims should be 4.");
int inner_size = w_dims[1] * w_dims[2] * w_dims[3];
weight_data[i] *= (scale[i / inner_size] / 127);
}
}
}
return weight_data;
}

@ -43,11 +43,18 @@ struct SimpleOpTypeSetTeller : public Teller {
private:
// use this set for no calib int8.
std::unordered_set<std::string> int8_teller_set{
"mul", "conv2d", "pool2d",
"relu", "depthwise_conv2d", "softmax",
"batch_norm", "elementwise_add", "leaky_relu",
"fc"};
std::unordered_set<std::string> int8_teller_set{"mul",
"conv2d",
"pool2d",
"relu",
"depthwise_conv2d",
"softmax",
"batch_norm",
"elementwise_add",
"leaky_relu",
"fc",
"relu6",
"concat"};
std::unordered_set<std::string> teller_set{
"mul",
"conv2d",

@ -405,6 +405,14 @@ if(WITH_GPU AND TENSORRT_FOUND)
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TRT_MODEL_QUANT_RESNET_DIR})
set(TRT_MODEL_QUANT_YOLOV3_DIR "${INFERENCE_DEMO_INSTALL_DIR}/yolov3_r50_quant_aware")
if (NOT EXISTS ${TRT_MODEL_QUANT_YOLOV3_DIR})
inference_download_and_uncompress(${INFERENCE_DEMO_INSTALL_DIR} ${INFERENCE_URL}/tensorrt_test "yolov3_r50_quant_aware.tgz")
endif()
inference_analysis_test(trt_quant_int8_yolov3_r50_test SRCS trt_quant_int8_yolov3_r50_test.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TRT_MODEL_QUANT_YOLOV3_DIR})
set(TEST_TRT_DYNAMIC_MODEL2 "${TRT_MODEL_INSTALL_DIR}/complex_model_dynamic")
if (NOT EXISTS ${TEST_TRT_DYNAMIC_MODEL2})
inference_download_and_uncompress(${TEST_TRT_DYNAMIC_MODEL2} ${INFERENCE_URL}/tensorrt_test "complex_model_dynamic2.tar.gz")

@ -0,0 +1,63 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <numeric>
#include "paddle/fluid/inference/tests/api/trt_test_helper.h"
namespace paddle {
namespace inference {
TEST(quant_int8, yolov3_resnet50) {
AnalysisConfig config;
config.EnableUseGpu(100, 0);
config.SetModel(FLAGS_infer_model + "/model", FLAGS_infer_model + "/params");
config.SwitchUseFeedFetchOps(false);
config.EnableTensorRtEngine(1 << 30, 1, 3, AnalysisConfig::Precision::kInt8,
false, false);
auto predictor = CreatePaddlePredictor(config);
auto input_names = predictor->GetInputNames();
int channels = 3;
int height = 608;
int width = 608;
int input_num = channels * height * width * 1;
float *input = new float[input_num];
int32_t *im_shape = new int32_t[2];
im_shape[0] = 608;
im_shape[1] = 608;
memset(input, 1.0, input_num * sizeof(float));
auto input_t = predictor->GetInputTensor(input_names[0]);
input_t->Reshape({1, channels, height, width});
input_t->copy_from_cpu(input);
auto input_t1 = predictor->GetInputTensor(input_names[1]);
input_t1->Reshape({1, 2});
input_t1->copy_from_cpu(im_shape);
ASSERT_TRUE(predictor->ZeroCopyRun());
std::vector<float> out_data;
auto output_names = predictor->GetOutputNames();
auto output_t = predictor->GetOutputTensor(output_names[0]);
std::vector<int> output_shape = output_t->shape();
int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1,
std::multiplies<int>());
out_data.resize(out_num);
output_t->copy_to_cpu(out_data.data());
}
} // namespace inference
} // namespace paddle
Loading…
Cancel
Save