From 6d86efc1d83ec19f99299841d5f4df000e9898e3 Mon Sep 17 00:00:00 2001
From: fuzhiye
Date: Mon, 4 Jan 2021 14:42:00 +0800
Subject: [PATCH] fix quantized rounding

---
 .../assembly/arm64/ConvDwInt8PostAlign4.S      | 35 +++++++++++++----
 .../arm64/ConvDwInt8PostAlign4PerChannel.S     | 15 ++++++--
 .../lite/nnacl/int8/quant_dtype_cast_int8.c    |  9 ++++-
 .../lite/nnacl/int8/quant_dtype_cast_int8.h    |  3 +-
 mindspore/lite/nnacl/op_base.h                 |  6 +++
 .../lite/nnacl/quantization/fixed_point.c      | 15 +++++++-
 .../lite/nnacl/quantization/fixed_point.h      |  5 +++
 mindspore/lite/nnacl/quantization/quantize.c   | 28 +++++++++++++-
 mindspore/lite/nnacl/quantization/quantize.h   |  9 ++++-
 mindspore/lite/schema/model.fbs                |  2 +
 mindspore/lite/src/lite_session.cc             |  3 ++
 .../kernel/arm/base/convolution_base.cc        | 38 +++++++++++++++++--
 .../kernel/arm/base/convolution_base.h         |  1 +
 .../kernel/arm/base/quant_dtype_cast.cc        | 12 +++++-
 .../convolution_depthwise_slidewindow_int8.cc  |  4 +-
 .../kernel/arm/int8/fullconnection_int8.cc     |  4 +-
 .../runtime/kernel/arm/int8/matmul_int8.cc     |  4 +-
 .../src/runtime/kernel/arm/int8/relux_int8.cc  |  3 +-
 .../runtime/kernel/arm/int8/resize_int8.cc     |  4 +-
 mindspore/lite/src/tensor.h                    |  3 ++
 .../kernel/arm/int8/matmul_int8_tests.cc       |  2 +-
 .../lite/tools/anf_exporter/anf_exporter.cc    |  5 +++
 mindspore/lite/tools/common/graph_util.cc      | 10 ++++-
 mindspore/lite/tools/common/tensor_util.cc     |  2 +
 mindspore/lite/tools/converter/converter.cc    |  1 +
 .../lite/tools/converter/converter_flags.cc    |  2 +-
 .../lite/tools/converter/converter_flags.h     |  9 ++++-
 .../graph/dtype_trans_pass.cc                  |  2 +-
 .../set_unused_quant_param_to_default_pass.cc  |  1 -
 .../graph/tensor_quant_pass.cc                 |  7 ++--
 .../parser/caffe/caffe_model_parser.cc         |  1 +
 .../parser/onnx/onnx_model_parser.cc           | 10 ++++-
 .../converter/parser/tf/tf_model_parser.cc     |  2 +
 .../parser/tflite/tflite_model_parser.cc       | 14 +++++--
 .../parser/tflite/tflite_model_parser.h        |  3 +-
 .../quantizer/post_training_quantizer.cc       |  4 ++
 36 files changed, 232 insertions(+), 46 deletions(-)
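This change makes two things explicit that the int8 kernels previously hard-coded. First, the rounding used when an int32 accumulator is rescaled by a right shift: RoundingDivideByPOT rounds half away from zero (-12.5 -> -13), while the new UpwardRounding rounds half upward (-12.5 -> -12). Second, how the fixed-point multiplier is derived from the real scale ratio: the existing double-precision derivation, or a new single-precision one that works on the fp32 bit pattern. Both choices are recorded per tensor in the schema (roundType, multiplier), filled in by the model parsers, and carried through the runtime QuantArg into the conv kernels. The remaining uint8-related changes let graphs that were originally uint8 keep their original rounding and zero points after conversion to int8.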
diff --git a/mindspore/lite/nnacl/assembly/arm64/ConvDwInt8PostAlign4.S b/mindspore/lite/nnacl/assembly/arm64/ConvDwInt8PostAlign4.S
index 2345a3714c..d78589dbe1 100644
--- a/mindspore/lite/nnacl/assembly/arm64/ConvDwInt8PostAlign4.S
+++ b/mindspore/lite/nnacl/assembly/arm64/ConvDwInt8PostAlign4.S
@@ -53,10 +53,22 @@ ConvDwInt8PostAlign4:
     sqrdmulh v2.4s, v2.4s, v27.4s
     sqrdmulh v3.4s, v3.4s, v27.4s

-    sqrshl v0.4s, v0.4s, v28.4s
-    sqrshl v1.4s, v1.4s, v28.4s
-    sqrshl v2.4s, v2.4s, v28.4s
-    sqrshl v3.4s, v3.4s, v28.4s
+    and v4.16b, v0.16b, v28.16b
+    sshr v4.4s, v4.4s, #31
+    sqadd v0.4s, v0.4s, v4.4s
+    srshl v0.4s, v0.4s, v28.4s
+    and v5.16b, v1.16b, v28.16b
+    sshr v5.4s, v5.4s, #31
+    sqadd v1.4s, v1.4s, v5.4s
+    srshl v1.4s, v1.4s, v28.4s
+    and v6.16b, v2.16b, v28.16b
+    sshr v6.4s, v6.4s, #31
+    sqadd v2.4s, v2.4s, v6.4s
+    srshl v2.4s, v2.4s, v28.4s
+    and v7.16b, v3.16b, v28.16b
+    sshr v7.4s, v7.4s, #31
+    sqadd v3.4s, v3.4s, v7.4s
+    srshl v3.4s, v3.4s, v28.4s

 AddZpDepth16:
     add v0.4s, v0.4s, v29.4s
@@ -109,8 +121,14 @@ ConvDwInt8PostAlign4:
 RightShiftDepth8:
     sqrdmulh v0.4s, v0.4s, v27.4s
     sqrdmulh v1.4s, v1.4s, v27.4s
-    sqrshl v0.4s, v0.4s, v28.4s
-    sqrshl v1.4s, v1.4s, v28.4s
+    and v4.16b, v0.16b, v28.16b
+    sshr v4.4s, v4.4s, #31
+    sqadd v0.4s, v0.4s, v4.4s
+    srshl v0.4s, v0.4s, v28.4s
+    and v5.16b, v1.16b, v28.16b
+    sshr v5.4s, v5.4s, #31
+    sqadd v1.4s, v1.4s, v5.4s
+    srshl v1.4s, v1.4s, v28.4s

 AddZpDepth8:
     add v0.4s, v0.4s, v29.4s
@@ -140,7 +158,10 @@ ConvDwInt8PostAlign4:
     sqshl v0.4s, v0.4s, v26.4s
     sqrdmulh v0.4s, v0.4s, v27.4s
-    sqrshl v0.4s, v0.4s, v28.4s
+    and v4.16b, v0.16b, v28.16b
+    sshr v4.4s, v4.4s, #31
+    sqadd v0.4s, v0.4s, v4.4s
+    srshl v0.4s, v0.4s, v28.4s

     add v0.4s, v0.4s, v29.4s
     smax v0.4s, v0.4s, v30.4s
diff --git a/mindspore/lite/nnacl/assembly/arm64/ConvDwInt8PostAlign4PerChannel.S b/mindspore/lite/nnacl/assembly/arm64/ConvDwInt8PostAlign4PerChannel.S
index 5d54dd79ab..35c2eb7dd8 100644
--- a/mindspore/lite/nnacl/assembly/arm64/ConvDwInt8PostAlign4PerChannel.S
+++ b/mindspore/lite/nnacl/assembly/arm64/ConvDwInt8PostAlign4PerChannel.S
@@ -43,8 +43,14 @@ ConvDwInt8PostAlign4PerChannel:
     sqrdmulh v0.4s, v0.4s, v4.4s
     sqrdmulh v1.4s, v1.4s, v5.4s
-    sqrshl v0.4s, v0.4s, v6.4s
-    sqrshl v1.4s, v1.4s, v7.4s
+    and v16.16b, v0.16b, v6.16b
+    sshr v16.4s, v16.4s, #31
+    sqadd v0.4s, v0.4s, v16.4s
+    srshl v0.4s, v0.4s, v6.4s
+    and v17.16b, v1.16b, v7.16b
+    sshr v17.4s, v17.4s, #31
+    sqadd v1.4s, v1.4s, v17.4s
+    srshl v1.4s, v1.4s, v7.4s

     add v0.4s, v0.4s, v29.4s
     add v1.4s, v1.4s, v29.4s
@@ -80,7 +86,10 @@ ConvDwInt8PostAlign4PerChannel:
     sqrdmulh v0.4s, v0.4s, v4.4s
     ld1 {v6.4s}, [x6], #16
-    sqrshl v0.4s, v0.4s, v6.4s
+    and v16.16b, v0.16b, v6.16b
+    sshr v16.4s, v16.4s, #31
+    sqadd v0.4s, v0.4s, v16.4s
+    srshl v0.4s, v0.4s, v6.4s

     add v0.4s, v0.4s, v29.4s
     smax v0.4s, v0.4s, v30.4s
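The sqrshl removed above is a saturating rounding shift: with a negative shift count it adds 2^(n-1) before shifting, so ties round upward. The four-instruction replacement (and, sshr #31, sqadd, srshl — v28, v6 and v7 hold the negative per-channel shift counts) biases negative inputs by -1 first, which turns the same rounding shift into round-half-away-from-zero, matching the scalar RoundingDivideByPOT used elsewhere in nnacl. A scalar sketch of what the sequence computes, assuming exponent > 0 (illustrative, not the shipped kernel):

  #include <stdint.h>

  static int32_t RoundShiftAwayFromZero(int32_t x, int exponent) {
    int32_t fixup = (x < 0) ? -1 : 0;                /* and + sshr #31          */
    int64_t biased = (int64_t)x + fixup;             /* sqadd                   */
    int64_t offset = (int64_t)1 << (exponent - 1);   /* rounding constant       */
    return (int32_t)((biased + offset) >> exponent); /* srshl (rounding shift)  */
  }

  /* x = -25, exponent = 1, real quotient -12.5:
   *   old tail (sqrshl):     (-25 + 1) >> 1 = -12   (ties upward)
   *   new tail:          (-25 - 1 + 1) >> 1 = -13   (ties away from zero) */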
diff --git a/mindspore/lite/nnacl/int8/quant_dtype_cast_int8.c b/mindspore/lite/nnacl/int8/quant_dtype_cast_int8.c
index 102ab6b091..0ec6fc72f5 100644
--- a/mindspore/lite/nnacl/int8/quant_dtype_cast_int8.c
+++ b/mindspore/lite/nnacl/int8/quant_dtype_cast_int8.c
@@ -29,17 +29,24 @@ int DoDequantizeInt8ToFp32(const int8_t *quant_values, float *real_values, float
   return NNACL_OK;
 }

-int DoQuantizeFp32ToInt8(const float *real_values, int8_t *quant_values, float scale, int32_t zp, int size) {
+int DoQuantizeFp32ToInt8(const float *real_values, int8_t *quant_values, float scale, int32_t zp, int size,
+                         bool uint8_flag) {
   if (quant_values == NULL || real_values == NULL) {
     return NNACL_PARAM_INVALID;
   }
+  if (uint8_flag) {
+    zp += 128;
+  }
   const float inverse_scale = 1.0f / scale;
   for (int i = 0; i < size; ++i) {
     if (isinf(real_values[i])) {
       quant_values[i] = 127;
     } else {
       int temp = round(real_values[i] * inverse_scale + zp);
+      if (uint8_flag) {
+        temp -= 128;
+      }
       temp = temp < 127 ? temp : 127;
       temp = temp > -128 ? temp : -128;
       quant_values[i] = (int8_t)temp;
diff --git a/mindspore/lite/nnacl/int8/quant_dtype_cast_int8.h b/mindspore/lite/nnacl/int8/quant_dtype_cast_int8.h
index e5e843f9ec..cc61782c6b 100644
--- a/mindspore/lite/nnacl/int8/quant_dtype_cast_int8.h
+++ b/mindspore/lite/nnacl/int8/quant_dtype_cast_int8.h
@@ -29,7 +29,8 @@ typedef struct QuantDTypeCastParameter {
 extern "C" {
 #endif
 int DoDequantizeInt8ToFp32(const int8_t *quant_values, float *real_values, float scale, int32_t zp, int size);
-int DoQuantizeFp32ToInt8(const float *real_values, int8_t *quant_values, float scale, int32_t zp, int size);
+int DoQuantizeFp32ToInt8(const float *real_values, int8_t *quant_values, float scale, int32_t zp, int size,
+                         bool uint8_flag);
 int DoDequantizeUInt8ToFp32(const uint8_t *quant_values, float *real_values, float scale, int32_t zp, int size);
 int DoQuantizeFp32ToUInt8(const float *real_values, uint8_t *quant_values, float scale, int32_t zp, int size);
 int Int8ToUInt8(const int8_t *quant_values, uint8_t *real_values, int size);
diff --git a/mindspore/lite/nnacl/op_base.h b/mindspore/lite/nnacl/op_base.h
index 645c429dcf..eb3b64fbe0 100644
--- a/mindspore/lite/nnacl/op_base.h
+++ b/mindspore/lite/nnacl/op_base.h
@@ -80,6 +80,12 @@ typedef struct OpParameter {
 typedef enum ActType { ActType_No, ActType_Relu, ActType_Sigmod, ActType_Relu6, ActType_Prelu } ActType;
 typedef enum PadMode { Pad_No, Pad_Same, Pad_Valid } PadMode;
+typedef enum RoundingMode { Rounding_No, Rounding_Away_from_zero, Rounding_Up } RoundingMode;
+typedef enum CalFixedMultiplierMode {
+  Method_No,
+  Method_SinglePrecision,
+  Method_DoublePrecision
+} CalFixedMultiplierMode;

 #ifdef ENABLE_ARM
 #define MS_FLOAT32X4 float32x4_t
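The new uint8_flag only changes results on ties: round() rounds half away from zero, so the outcome depends on the sign of real_values[i] * inverse_scale + zp. Quantizing in the uint8 domain (zp + 128, always non-negative) and shifting back afterwards reproduces what the model would have produced before it was converted from uint8 to int8. For example, with scale 0.25 and zp 0:

  float v[1] = {-3.125f};  /* v / scale + zp = -12.5, an exact tie */
  int8_t q[1];
  DoQuantizeFp32ToInt8(v, q, 0.25f, 0, 1, false);  /* round(-12.5)       -> q[0] = -13 */
  DoQuantizeFp32ToInt8(v, q, 0.25f, 0, 1, true);   /* round(115.5) - 128 -> q[0] = -12 */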
diff --git a/mindspore/lite/nnacl/quantization/fixed_point.c b/mindspore/lite/nnacl/quantization/fixed_point.c
index bfdc17cacb..da4b940c4f 100644
--- a/mindspore/lite/nnacl/quantization/fixed_point.c
+++ b/mindspore/lite/nnacl/quantization/fixed_point.c
@@ -42,7 +42,7 @@ int16_t SaturatingRoundingDoublingHighMulInt16(int16_t a, int16_t b) {
 }

 // division by a 2^exponent with rounding
-// or arithmetic right shift with rouding
+// or arithmetic right shift with rounding
 int RoundingDivideByPOT(int x, int exponent) {
   const int mask = (1ll << exponent) - 1;
   const int remainder = x & mask;
@@ -50,10 +50,23 @@ int RoundingDivideByPOT(int x, int exponent) {
   return (x >> exponent) + (remainder > threshold ? 1 : 0);
 }

+int UpwardRounding(int x, int exponent) {
+  const int32_t rounding_offset = (exponent > 0) ? (1 << (exponent - 1)) : 0;
+  if (x > INT32_MAX - rounding_offset) {
+    return 1 << (31 - exponent);
+  }
+  return (x + rounding_offset) >> exponent;
+}
+
 int MultiplyByQuantizedMultiplier(int32_t value, int32_t multiplier, int32_t left_shift, int32_t right_shift) {
   return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(value * (1 << left_shift), multiplier), -right_shift);
 }

+int MultiplyByQuantizedMultiplierWithUpwardRounding(int32_t value, int32_t multiplier, int32_t left_shift,
+                                                    int32_t right_shift) {
+  return UpwardRounding(SaturatingRoundingDoublingHighMul(value * (1 << left_shift), multiplier), -right_shift);
+}
+
 int MultiplyByMultiplierAndRightShift(int32_t value, int32_t multiplier, int32_t right_shift) {
   return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(value, multiplier), right_shift);
 }
diff --git a/mindspore/lite/nnacl/quantization/fixed_point.h b/mindspore/lite/nnacl/quantization/fixed_point.h
index 6d81a0a8e6..ed106b184d 100644
--- a/mindspore/lite/nnacl/quantization/fixed_point.h
+++ b/mindspore/lite/nnacl/quantization/fixed_point.h
@@ -40,8 +40,13 @@ int16_t SaturatingRoundingDoublingHighMulInt16(int16_t a, int16_t b);
 // or arithmetic right shift with rouding
 int RoundingDivideByPOT(int x, int exponent);

+int UpwardRounding(int x, int exponent);
+
 int MultiplyByQuantizedMultiplier(int32_t value, int32_t multiplier, int32_t left_shift, int32_t right_shift);

+int MultiplyByQuantizedMultiplierWithUpwardRounding(int32_t value, int32_t multiplier, int32_t left_shift,
+                                                    int32_t right_shift);
+
 int MultiplyByMultiplierAndRightShift(int32_t value, int32_t multiplier, int32_t right_shift);

 int SaturatingRoundingMultiplyByPOT(int32_t x, int exponent);
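UpwardRounding is the round-half-up counterpart of RoundingDivideByPOT: add 2^(exponent-1), then shift, with the addition guarded against int32 overflow (a saturated 1 << (31 - exponent) is returned instead). The two only disagree on negative ties:

  RoundingDivideByPOT(-25, 1); /* -12.5 -> -13: remainder (1) is not above threshold (1) */
  UpwardRounding(-25, 1);      /* -12.5 -> (-25 + 1) >> 1 = -12                          */
  /* positive inputs agree: both round 25 / 2 = 12.5 up to 13 */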
diff --git a/mindspore/lite/nnacl/quantization/quantize.c b/mindspore/lite/nnacl/quantization/quantize.c
index ce3ed1ca4d..2da215a567 100644
--- a/mindspore/lite/nnacl/quantization/quantize.c
+++ b/mindspore/lite/nnacl/quantization/quantize.c
@@ -15,6 +15,7 @@
  */

 #include "nnacl/quantization/quantize.h"
+#include <stdio.h>

 const uint64_t dSignMask = 1ull << 63;
 const uint64_t dExponentMask = 0x7ffull << 52;
@@ -35,8 +36,8 @@ void QuantizeMultiplierSmallerThanOne(double double_multiplier, int32_t *quantiz
   *right_shift = -shift;
 }

-void QuantizeRoundParameter(double double_multiplier, int32_t *quantized_multiplier, int *left_shift,
-                            int *right_shift) {
+void QuantizeRoundParameterWithDoublePrecision(double double_multiplier, int32_t *quantized_multiplier, int *left_shift,
+                                               int *right_shift) {
   int shift = 0;
   QuantizeMultiplierSmallerThanOne(double_multiplier, quantized_multiplier, &shift);
   shift = -shift;
@@ -49,6 +50,29 @@ void QuantizeRoundParameter(double double_multiplier, int32_t *quantized_multipl
   }
 }

+void QuantizeRoundParameterWithSinglePrecision(double double_multiplier, int32_t *quantized_multiplier, int *left_shift,
+                                               int *right_shift) {
+  int shift = 0;
+  const uint32_t scale_bits = (uint32_t)(double_multiplier);
+  /* multiplier is in [0x40000000, 0x7FFFFF80] range */
+  *quantized_multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
+  if (quantized_multiplier[0] < INT32_C(0x40000000) || quantized_multiplier[0] > INT32_C(0x7FFFFF80)) {
+    printf("quantized multiplier must be in [0x40000000, 0x7FFFFF80] range, now multiplier is %d\n",
+           quantized_multiplier[0]);
+    return;
+  }
+  /* shift is in [0, 31] range */
+  shift = 127 + 31 - 32 - ((uint32_t)(double_multiplier) >> 23);
+  shift = -shift;
+  if (shift < 0) {
+    *left_shift = 0;
+    *right_shift = shift;
+  } else {
+    *left_shift = shift;
+    *right_shift = 0;
+  }
+}
+
 uint8_t QuantizeToUint8(float real_value, float scale, int32_t zp) { return round(real_value / scale + zp); }

 int32_t QuantizeToInt8(float real_value, float scale, int32_t zp) { return round(real_value / scale + zp); }
diff --git a/mindspore/lite/nnacl/quantization/quantize.h b/mindspore/lite/nnacl/quantization/quantize.h
index 0d04d9aa13..971a533642 100644
--- a/mindspore/lite/nnacl/quantization/quantize.h
+++ b/mindspore/lite/nnacl/quantization/quantize.h
@@ -34,6 +34,8 @@ typedef struct QuantArg {
 } QuantArg;

 typedef struct ConvQuantArg {
+  RoundingMode round_mode_;
+  CalFixedMultiplierMode quant_multiplier_mode_;
   QuantArg *input_quant_args_;
   QuantArg *filter_quant_args_;
   QuantArg *output_quant_args_;
@@ -46,7 +48,6 @@ typedef struct ConvQuantArg {
   size_t input_arg_num_;
   size_t filter_arg_num_;
   size_t output_arg_num_;
-  uint8_t asymmetric_;
   uint8_t per_channel_;
 } ConvQuantArg;

@@ -282,7 +283,11 @@ void QuantizeMultiplier(double double_multiplier, int32_t *quantized_multiplier,
 void QuantizeMultiplierSmallerThanOne(double double_multiplier, int32_t *quantized_multiplier, int *right_shift);

-void QuantizeRoundParameter(double double_multiplier, int32_t *quantized_multiplier, int *left_shift, int *right_shift);
+void QuantizeRoundParameterWithDoublePrecision(double double_multiplier, int32_t *quantized_multiplier, int *left_shift,
+                                               int *right_shift);
+
+void QuantizeRoundParameterWithSinglePrecision(double double_multiplier, int32_t *quantized_multiplier, int *left_shift,
+                                               int *right_shift);

 uint8_t QuantizeToUint8(float real_value, float scale, int32_t zp);
diff --git a/mindspore/lite/schema/model.fbs b/mindspore/lite/schema/model.fbs
index b7a851c5b5..ca934c6ccf 100644
--- a/mindspore/lite/schema/model.fbs
+++ b/mindspore/lite/schema/model.fbs
@@ -40,6 +40,8 @@ table QuantParam {
     varCorr: float = 1;
     meanCorr: float = 0;
     dstDtype: int = 32;
+    roundType: int = 1;
+    multiplier: int = -1; // calculate fixed point multiplier method
 }

 table Tensor {
diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc
index 7ddc0296e1..8e01ac7ced 100644
--- a/mindspore/lite/src/lite_session.cc
+++ b/mindspore/lite/src/lite_session.cc
@@ -69,6 +69,9 @@ void LiteSession::ConvertTensorsQuantParam(const schema::Tensor *src_tensor, lit
       quant_arg.var_corr = quant_params->Get(j)->varCorr();
       quant_arg.mean_corr = quant_params->Get(j)->meanCorr();
       quant_arg.inited = quant_params->Get(j)->inited();
+      quant_arg.roundType = quant_params->Get(j)->roundType();
+      quant_arg.multiplier = quant_params->Get(j)->multiplier();
+      quant_arg.dstDtype = quant_params->Get(j)->dstDtype();
       dst_tensor->AddQuantParam(quant_arg);
     }
   }
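QuantizeRoundParameterWithSinglePrecision reads the multiplier straight out of an IEEE-754 single: the mask/shift arithmetic takes the 23 mantissa bits, restores the implicit leading one, and left-aligns the result into [0x40000000, 0x7FFFFF80] as a Q31 value, while the biased exponent field (bits 23..30, bias 127) supplies the power-of-two shift. A sketch of that derivation, assuming the fp32 bit pattern is what reaches scale_bits (the helper name below is illustrative):

  #include <stdint.h>
  #include <string.h>

  static void SinglePrecisionParams(float real_multiplier, int32_t *quantized_multiplier, int *shift) {
    uint32_t scale_bits;
    memcpy(&scale_bits, &real_multiplier, sizeof(scale_bits));  /* reinterpret, not convert */
    *quantized_multiplier =
        (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
    *shift = 127 + 31 - 32 - (int)(scale_bits >> 23);
  }

  /* real_multiplier = 0.25f: quantized_multiplier = 0x40000000 (0.5 in Q31),
   * shift = 1, and 0.5 * 2^-1 = 0.25. */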
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc
index f7e6cc9262..c7771f80be 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc
@@ -261,12 +261,43 @@ int ConvolutionBaseCPUKernel::SetQuantMultiplier() {
       static_cast<double>(conv_quant_arg_->input_quant_args_[0].scale_ * conv_quant_arg_->filter_quant_args_[i].scale_);
     double real_multiplier = in_scale / static_cast<double>(conv_quant_arg_->output_quant_args_[0].scale_);
     conv_quant_arg_->real_multiplier_[i] = real_multiplier;
-    QuantizeRoundParameter(real_multiplier, &conv_quant_arg_->quant_multiplier_[i], &conv_quant_arg_->left_shift_[i],
-                           &conv_quant_arg_->right_shift_[i]);
+    if (conv_quant_arg_->quant_multiplier_mode_ == Method_SinglePrecision) {
+      QuantizeRoundParameterWithSinglePrecision(real_multiplier, &conv_quant_arg_->quant_multiplier_[i],
+                                                &conv_quant_arg_->left_shift_[i], &conv_quant_arg_->right_shift_[i]);
+    } else if (conv_quant_arg_->quant_multiplier_mode_ == Method_DoublePrecision) {
+      QuantizeRoundParameterWithDoublePrecision(real_multiplier, &conv_quant_arg_->quant_multiplier_[i],
+                                                &conv_quant_arg_->left_shift_[i], &conv_quant_arg_->right_shift_[i]);
+    }
   }
   return RET_OK;
 }

+void ConvolutionBaseCPUKernel::SetRoundingAndMultipilerMode() {
+  auto input_quant_arg = in_tensors_.at(kInputIndex)->quant_params().front();
+  int round_type = input_quant_arg.roundType;
+  switch (round_type) {
+    case 1:
+      conv_quant_arg_->round_mode_ = Rounding_Away_from_zero;
+      break;
+    case 2:
+      conv_quant_arg_->round_mode_ = Rounding_Up;
+      break;
+    default:
+      conv_quant_arg_->round_mode_ = Rounding_No;
+  }
+  int cal_multiplier_type = input_quant_arg.multiplier;
+  switch (cal_multiplier_type) {
+    case 0:
+      conv_quant_arg_->quant_multiplier_mode_ = Method_SinglePrecision;
+      break;
+    case 1:
+      conv_quant_arg_->quant_multiplier_mode_ = Method_DoublePrecision;
+      break;
+    default:
+      conv_quant_arg_->quant_multiplier_mode_ = Method_No;
+  }
+}
+
 int ConvolutionBaseCPUKernel::SetQuantParam() {
   auto ret = MallocQuantParam();
   if (ret != RET_OK) {
@@ -288,13 +319,12 @@ int ConvolutionBaseCPUKernel::SetQuantParam() {
     MS_LOG(ERROR) << "Set Output Tensor Quant Param Failed.";
     return ret;
   }
-
   ret = SetIfPerChannel();
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Set if per tensor channel failed.";
     return ret;
   }
-
+  SetRoundingAndMultipilerMode();
   ret = SetQuantMultiplier();
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Set Quant Multiplier Failed.";
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.h b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.h
index 61779ec3e8..7e287fd224 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.h
+++ b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.h
@@ -53,6 +53,7 @@ class ConvolutionBaseCPUKernel : public LiteKernel {
   int SetFilterTensorQuantParam();
   int SetOutputTensorQuantParam();
   int SetQuantMultiplier();
+  void SetRoundingAndMultipilerMode();
   int CheckResizeValid();
   void FreeQuantParam();
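SetQuantMultiplier decomposes, per output channel i,

  real_multiplier = (S_input * S_filter[i]) / S_output
                  ~ (quant_multiplier / 2^31) * 2^(-right_shift)   (plus a left_shift for multipliers >= 1)

so a kernel can rescale an int32 accumulator as SaturatingRoundingDoublingHighMul(acc << left_shift, quant_multiplier) followed by a rounding right shift. round_mode_ selects how that final shift rounds, and quant_multiplier_mode_ selects which derivation fills quant_multiplier and the shifts (Method_No leaves them untouched). The switch values mirror the schema fields: roundType 1 -> Rounding_Away_from_zero, 2 -> Rounding_Up; multiplier 0 -> Method_SinglePrecision, 1 -> Method_DoublePrecision.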
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.cc b/mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.cc
index 677e1e02d9..411ecd8619 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.cc
@@ -120,8 +120,12 @@ int QuantDTypeCastCPUKernel::QuantDTypeCast(int task_id) {
     ret = DoDequantizeInt8ToFp32(int8_ptr_ + thread_offset, float32_ptr_ + thread_offset, quant_arg.scale,
                                  quant_arg.zeroPoint, num_unit_thread);
   } else if (src_dtype == TypeId::kNumberTypeFloat32 && dst_dtype == TypeId::kNumberTypeInt8) {
+    bool from_uint8_src = false;
+    if (quant_arg.dstDtype == TypeId::kNumberTypeUInt8) {
+      from_uint8_src = true;
+    }
     ret = DoQuantizeFp32ToInt8(float32_ptr_ + thread_offset, int8_ptr_ + thread_offset, quant_arg.scale,
-                               quant_arg.zeroPoint, num_unit_thread);
+                               quant_arg.zeroPoint, num_unit_thread, from_uint8_src);
   } else if (src_dtype == TypeId::kNumberTypeInt8 && dst_dtype == TypeId::kNumberTypeUInt8) {
     ret = Int8ToUInt8(int8_ptr_ + thread_offset, uint8_ptr_ + thread_offset, num_unit_thread);
   } else if (src_dtype == TypeId::kNumberTypeUInt8 && dst_dtype == TypeId::kNumberTypeFloat32) {
@@ -138,8 +142,12 @@ int QuantDTypeCastCPUKernel::QuantDTypeCast(int task_id) {
                                  input_quant_arg.scale, input_quant_arg.zeroPoint);
     if (ret) {
       auto output_quant_arg = out_tensors_.front()->quant_params().front();
+      bool from_uint8_src = false;
+      if (quant_arg.dstDtype == TypeId::kNumberTypeUInt8) {
+        from_uint8_src = true;
+      }
       ret = DoQuantizeFp32ToInt8(float32_ptr_ + thread_offset, int8_out_ptr_ + thread_offset, output_quant_arg.scale,
-                                 output_quant_arg.zeroPoint, num_unit_thread);
+                                 output_quant_arg.zeroPoint, num_unit_thread, from_uint8_src);
     }
   }
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_slidewindow_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_slidewindow_int8.cc
index 691b9e5aaa..4c06585bac 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_slidewindow_int8.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_slidewindow_int8.cc
@@ -254,8 +254,8 @@ int ConvolutionDepthwiseSWInt8CPUKernel::ReinitQuantParam() {
     const double in_scale = static_cast<double>(input_scale_[i] * weight_scale_[i]);
     double real_multiplier = in_scale / static_cast<double>(output_scale_[i]);
     conv_quant_arg_->real_multiplier_[i] = real_multiplier;
-    QuantizeRoundParameter(real_multiplier, &conv_quant_arg_->quant_multiplier_[i], &conv_quant_arg_->left_shift_[i],
-                           &conv_quant_arg_->right_shift_[i]);
+    QuantizeRoundParameterWithDoublePrecision(real_multiplier, &conv_quant_arg_->quant_multiplier_[i],
+                                              &conv_quant_arg_->left_shift_[i], &conv_quant_arg_->right_shift_[i]);
   }

   // now only consider per tensor for output
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc
index 5870fd5618..ba8cbcf420 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc
@@ -132,8 +132,8 @@ int FullconnectionInt8CPUKernel::Init() {
   for (int i = 0; i < weight_quant_num; ++i) {
     const double in_scale = static_cast<double>(quant_.input_.scale_ * quant_.filter_scale_[i]);
     double real_multiplier = in_scale / static_cast<double>(quant_.output_.scale_);
-    QuantizeRoundParameter(real_multiplier, &quant_.quant_multiplier_[i], &quant_.left_shift_[i],
-                           &quant_.right_shift_[i]);
+    QuantizeRoundParameterWithDoublePrecision(real_multiplier, &quant_.quant_multiplier_[i], &quant_.left_shift_[i],
+                                              &quant_.right_shift_[i]);
   }

   CalculateActivationRangeQuantized(fc_param_->act_type_ == ActType_Relu, fc_param_->act_type_ == ActType_Relu6,
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc
index d3b2a34227..42d6ddce2f 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc
@@ -138,8 +138,8 @@ int MatmulInt8CPUKernel::ReSize() {
     }
   }
   double real_multiplier = quant_params_.input.scale_ * quant_params_.weight.scale_ / quant_params_.output.scale_;
-  QuantizeRoundParameter(real_multiplier, &quant_params_.quant_multiplier, &quant_params_.left_shift,
-                         &quant_params_.right_shift);
+  QuantizeRoundParameterWithDoublePrecision(real_multiplier, &quant_params_.quant_multiplier, &quant_params_.left_shift,
+                                            &quant_params_.right_shift);
   return RET_OK;
 }
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/relux_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/relux_int8.cc
index 848c57e8a2..494539b4a4 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/relux_int8.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/relux_int8.cc
@@ -39,7 +39,8 @@ int ReluXInt8CPUKernel::Init() {
   quant_arg_.output_arg.zp_ = output->quant_params().front().zeroPoint;

   const double multiplier = quant_arg_.input_arg.scale_ / quant_arg_.output_arg.scale_;
-  QuantizeRoundParameter(multiplier, &quant_arg_.input_multiplier_, &quant_arg_.left_shift_, &quant_arg_.right_shift_);
+  QuantizeRoundParameterWithDoublePrecision(multiplier, &quant_arg_.input_multiplier_, &quant_arg_.left_shift_,
+                                            &quant_arg_.right_shift_);

   return RET_OK;
 }
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/resize_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/resize_int8.cc
index 969032b4d9..528e2f3919 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/resize_int8.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/resize_int8.cc
@@ -86,8 +86,8 @@ int ResizeInt8CPUKernel::Init() {
   quant_out_->zp_ = output->quant_params().front().zeroPoint;
   quant_out_->scale_ = output->quant_params().front().scale;

-  QuantizeRoundParameter(quant_in_->scale_ / quant_out_->scale_, &multiplier_->multiplier_, &multiplier_->left_shift_,
-                         &multiplier_->right_shift_);
+  QuantizeRoundParameterWithDoublePrecision(quant_in_->scale_ / quant_out_->scale_, &multiplier_->multiplier_,
+                                            &multiplier_->left_shift_, &multiplier_->right_shift_);
   if (!InferShapeDone()) {
     return RET_OK;
   }
diff --git a/mindspore/lite/src/tensor.h b/mindspore/lite/src/tensor.h
index 69e44494ad..555641920d 100644
--- a/mindspore/lite/src/tensor.h
+++ b/mindspore/lite/src/tensor.h
@@ -38,6 +38,9 @@ struct QuantArg {
   bool inited;
   std::vector<float> clusters{};
   int bitNum;
+  int roundType;
+  int multiplier;
+  int dstDtype;
 };

 class Tensor : public mindspore::tensor::MSTensor {
diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/matmul_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/matmul_int8_tests.cc
index e28f2b2e1b..f1cd13aa25 100644
--- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/matmul_int8_tests.cc
+++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/matmul_int8_tests.cc
@@ -118,7 +118,7 @@ TEST_F(TestMatmulInt8, simple) {
   int a_sums[ROW4] = {0};
   int bias[COL4] = {0};
   int multiplier, ls, rs;
-  QuantizeRoundParameter(1.0f, &multiplier, &ls, &rs);
+  QuantizeRoundParameterWithDoublePrecision(1.0f, &multiplier, &ls, &rs);
 #ifdef ENABLE_ARM64
   MatmulInt8Neon64(a_r4x16, b_c16x4, output, ROW4, COL4, DEPTH16, a_sums, bias, INT8_MIN, INT8_MAX, 0, &multiplier,
                    &ls, &rs, ROW, COL, COL, false);
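The runtime QuantArg now mirrors the three schema fields one-to-one, and ConvertTensorsQuantParam (lite_session.cc above) copies them verbatim, so a kernel sees both the requested rounding behaviour and the tensor's pre-conversion dtype:

  /* schema (model.fbs)          runtime (src/tensor.h QuantArg)
   *   roundType:  int = 1   ->  roundType   1: half away from zero, 2: half up
   *   multiplier: int = -1  ->  multiplier  0: single precision, 1: double precision
   *   dstDtype:   int = 32  ->  dstDtype    dtype the tensor had in the source model */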
diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc
index 174cd86f02..998ebfcf3e 100644
--- a/mindspore/lite/tools/anf_exporter/anf_exporter.cc
+++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc
@@ -121,6 +121,7 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me
           std::make_unique<schema::QuantParamT>(input_quant_param);
         MS_LOG(DEBUG) << "[input][" << i << "]node: " << dst_node->name << " scale: " << input_quant_param_ptr->scale
                       << " zp: " << input_quant_param_ptr->zeroPoint;
+        input_quant_param_ptr->dstDtype = tensor_input->dataType;
         tensor_input->quantParams.emplace_back(std::move(input_quant_param_ptr));
       }
     }
@@ -151,6 +152,7 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me
           std::make_unique<schema::QuantParamT>(channel_quant_param);
         MS_LOG(DEBUG) << "[output]node: " << dst_node->name << " scale: " << output_quant_param_ptr->scale
                       << " zp: " << output_quant_param_ptr->zeroPoint;
+        output_quant_param_ptr->dstDtype = output_tensor->dataType;
         output_tensor->quantParams.emplace_back(std::move(output_quant_param_ptr));
       }
     }
@@ -258,6 +260,9 @@ int AnfExporter::ExportSubgraph(const FuncGraphPtr &func_graph, const std::uniqu
   auto subgraph_name = func_graph->get_attr("graph_name");
   MS_ASSERT(subgraph_name != nullptr);
   sub_graphT->name = GetValue<std::string>(subgraph_name);
+  auto fmk = func_graph->get_attr("fmk");
+  MS_ASSERT(fmk != nullptr);
+  meta_graphT->fmkType = GetValue<int>(fmk);

   auto cnodes = func_graph->GetOrderedCnodes();
   for (const auto &cnode : cnodes) {
diff --git a/mindspore/lite/tools/common/graph_util.cc b/mindspore/lite/tools/common/graph_util.cc
index 09c29174dc..6cf1ed5851 100644
--- a/mindspore/lite/tools/common/graph_util.cc
+++ b/mindspore/lite/tools/common/graph_util.cc
@@ -448,6 +448,8 @@ NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, si
       toAddTensor->dataType = prim->dstT;
       if (prim->srcT == TypeId::kNumberTypeUInt8 && prim->dstT == TypeId::kNumberTypeInt8) {
         preTensor->quantParams.front()->zeroPoint += 128;
+      } else if (prim->srcT == TypeId::kNumberTypeInt8 && prim->dstT == TypeId::kNumberTypeUInt8) {
+        toAddTensor->quantParams.front()->zeroPoint += 128;
       }
     }
     graphT->allTensors.emplace_back(std::move(toAddTensor));
@@ -491,6 +493,8 @@ NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, si
       toAddTensor->dataType = prim->dstT;
       if (prim->srcT == TypeId::kNumberTypeUInt8 && prim->dstT == TypeId::kNumberTypeInt8) {
         preTensor->quantParams.front()->zeroPoint += 128;
+      } else if (prim->srcT == TypeId::kNumberTypeInt8 && prim->dstT == TypeId::kNumberTypeUInt8) {
+        toAddTensor->quantParams.front()->zeroPoint += 128;
       }
     }
     graphT->allTensors.emplace_back(std::move(toAddTensor));
@@ -552,8 +556,10 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz
       MS_ASSERT(prim != nullptr);
       postTensor->dataType = prim->srcT;
       toAddTensor->dataType = prim->dstT;
-      if (prim->dstT == TypeId::kNumberTypeUInt8 && prim->srcT == TypeId::kNumberTypeInt8) {
+      if (prim->srcT == TypeId::kNumberTypeInt8 && prim->dstT == TypeId::kNumberTypeUInt8) {
         toAddTensor->quantParams.front()->zeroPoint += 128;
+      } else if (prim->srcT == TypeId::kNumberTypeUInt8 && prim->dstT == TypeId::kNumberTypeInt8) {
+        postTensor->quantParams.front()->zeroPoint += 128;
       }
     }
     graphT->allTensors.emplace_back(std::move(toAddTensor));
@@ -624,6 +630,8 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz
       toAddTensor->dataType = prim->dstT;
       if (prim->dstT == TypeId::kNumberTypeUInt8 && prim->srcT == TypeId::kNumberTypeInt8) {
         toAddTensor->quantParams.front()->zeroPoint += 128;
+      } else if (prim->srcT == TypeId::kNumberTypeUInt8 && prim->dstT == TypeId::kNumberTypeInt8) {
+        postTensor->quantParams.front()->zeroPoint += 128;
       }
     }
     graphT->allTensors.emplace_back(std::move(toAddTensor));
diff --git a/mindspore/lite/tools/common/tensor_util.cc b/mindspore/lite/tools/common/tensor_util.cc
index 389ab8bb62..b7443ad034 100644
--- a/mindspore/lite/tools/common/tensor_util.cc
+++ b/mindspore/lite/tools/common/tensor_util.cc
@@ -38,6 +38,8 @@ std::unique_ptr<schema::QuantParamT> CopyQuantParamT(const std::unique_ptr<schem
   dstQuantParam->max = srcQuantParam->max;
   dstQuantParam->narrowRange = srcQuantParam->narrowRange;
   dstQuantParam->numBits = srcQuantParam->numBits;
+  dstQuantParam->dstDtype = srcQuantParam->dstDtype;
+  dstQuantParam->multiplier = srcQuantParam->multiplier;
   return dstQuantParam;
 }
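The graph_util changes preserve one invariant whenever an Int8 <-> UInt8 cast node is spliced into the graph: both views must describe the same real values, which pins the uint8 zero point 128 above the int8 one.

  /* real = scale * (q_u8 - zp_u8)
   *      = scale * ((q_u8 - 128) - (zp_u8 - 128))
   *      = scale * (q_i8 - zp_i8)     with q_i8 = q_u8 - 128, zp_i8 = zp_u8 - 128
   * e.g. zp_u8 = 138 on the uint8 side <-> zp_i8 = 10 on the int8 side */

Hence the tensor sitting on the uint8 side of an inserted cast gets its zeroPoint bumped by +128 in all four insertion paths, covering both cast directions.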
diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc
index ce5ba7829d..9d4a665b0a 100644
--- a/mindspore/lite/tools/converter/converter.cc
+++ b/mindspore/lite/tools/converter/converter.cc
@@ -71,6 +71,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
       return nullptr;
     }
     graph->set_attr("graph_name", MakeValue("main_graph"));
+    graph->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_MS)));
   } else {
     MS_ASSERT(nullptr != modelParser);
     const std::string modelFile = flag->modelFile;
diff --git a/mindspore/lite/tools/converter/converter_flags.cc b/mindspore/lite/tools/converter/converter_flags.cc
index 2327b17122..145f5662da 100644
--- a/mindspore/lite/tools/converter/converter_flags.cc
+++ b/mindspore/lite/tools/converter/converter_flags.cc
@@ -158,7 +158,7 @@ int Flags::Init(int argc, const char **argv) {
     return RET_INPUT_PARAM_INVALID;
   }

-  if (this->trainModel == true) {
+  if (this->trainModel) {
    if (this->fmk != FmkType_MS) {
       std::cerr << "INPUT ILLEGAL: train model convertor supporting only MINDIR format";
       return RET_INPUT_PARAM_INVALID;
diff --git a/mindspore/lite/tools/converter/converter_flags.h b/mindspore/lite/tools/converter/converter_flags.h
index 22fb2f0fb8..1c1d6fea34 100644
--- a/mindspore/lite/tools/converter/converter_flags.h
+++ b/mindspore/lite/tools/converter/converter_flags.h
@@ -30,7 +30,14 @@ using mindspore::schema::QuantType_PostTraining;
 using mindspore::schema::QuantType_QUANT_NONE;
 using mindspore::schema::QuantType_WeightQuant;
 namespace converter {
-enum FmkType { FmkType_TF = 0, FmkType_CAFFE = 1, FmkType_ONNX = 2, FmkType_MS = 3, FmkType_TFLITE = 4 };
+enum FmkType {
+  FmkType_TF = 0,
+  FmkType_CAFFE = 1,
+  FmkType_ONNX = 2,
+  FmkType_MS = 3,
+  FmkType_TFLITE = 4,
+  FmkType_ONNX_LOW_VERSION = 5
+};

 class Flags : public virtual mindspore::lite::FlagParser {
  public:
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc
index d1fa82617d..1e149e2106 100644
--- a/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc
+++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc
@@ -161,7 +161,7 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) {
       if (postTensor->dataType != TypeId::kNumberTypeInt8) {
         continue;
       }
-      iter = InsertDTypeTransNode(graph, iter, kAfter, i, kNumberTypeFloat32, kNumberTypeInt8, &status);
+      iter = InsertDTypeTransNode(graph, iter, kAfter, i, kNumberTypeFloat32, kNumberTypeUInt8, &status);
       if (status != RET_OK) {
         MS_LOG(ERROR) << "InsertFloat32ToUint8Node after " << nodeName.c_str() << " failed";
         return RET_ERROR;
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/set_unused_quant_param_to_default_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/set_unused_quant_param_to_default_pass.cc
index 20b1c12051..2cf9006bcd 100644
--- a/mindspore/lite/tools/converter/legacy_optimizer/graph/set_unused_quant_param_to_default_pass.cc
+++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/set_unused_quant_param_to_default_pass.cc
@@ -25,7 +25,6 @@ STATUS SetUnusedQuantParamToDefaultPass::Run(schema::MetaGraphT *graph) {
       quant_param->min = 0;
       quant_param->max = 0;
       quant_param->narrowRange = true;
-      quant_param->dstDtype = TypeId::kNumberTypeFloat32;
     }
   }
   return RET_OK;
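Two converter fixes travel with this: SetUnusedQuantParamToDefaultPass no longer resets dstDtype to float32, since downstream passes and the runtime now rely on it to tell whether a tensor was originally uint8, and DoNodeInoutDTypeTrans now inserts the fp32 -> uint8 transform its error message ("InsertFloat32ToUint8Node") always claimed, instead of an fp32 -> int8 one.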
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.cc
index 1464086811..521c35d729 100644
--- a/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.cc
+++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.cc
@@ -44,7 +44,7 @@ STATUS TensorQuantPass::Run(schema::MetaGraphT *graph) {
       }
     }
   }
-  int index = -1;
+  unsigned int index = -1;
   for (auto &tensor : graph->allTensors) {
     index++;
     if (tensor->quantParams.empty() || !tensor->quantParams.front()->inited) {
@@ -59,7 +59,8 @@ STATUS TensorQuantPass::Run(schema::MetaGraphT *graph) {
     auto &quantParam = tensor->quantParams.front();
     size_t wShapeSize = tensor->data.empty() ? 0 : GetShapeSize(*(tensor.get()));
     void *oriWeightData = tensor->data.data();
-    if (quantParam->dstDtype == TypeId::kNumberTypeInt8) {
+    if (quantParam->dstDtype == TypeId::kNumberTypeUInt8 || quantParam->dstDtype == TypeId::kNumberTypeFloat32 ||
+        quantParam->dstDtype == TypeId::kNumberTypeFloat) {
       std::vector<int8_t> qDatas(wShapeSize);
       auto weightQauntParam = GetTensorQuantParam(tensor);
       if (tensor->dataType == TypeId::kNumberTypeFloat ||
@@ -71,7 +72,7 @@ STATUS TensorQuantPass::Run(schema::MetaGraphT *graph) {
         for (size_t j = 0; j < wShapeSize; j++) {
           qDatas[j] = quant::QuantizeData<int8_t>(weightData[j], weightQauntParam.get());
         }
-      } else {  // tflite awareing quant
+      } else {  // convert uint8 to int8
         auto *weightData = static_cast<uint8_t *>(oriWeightData);
         for (size_t j = 0; j < wShapeSize; j++) {
           qDatas[j] = (int32_t)weightData[j] - 128;
diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc
index c71c64aa59..6860ac7bdb 100644
--- a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc
+++ b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc
@@ -55,6 +55,7 @@ FuncGraphPtr CaffeModelParser::Parse(const std::string &model_file, const std::s
     return nullptr;
   }
   func_graph_ptr_->set_attr("graph_name", MakeValue("main_graph"));
+  func_graph_ptr_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_CAFFE)));
   return func_graph_ptr_;
 }
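With the widened dstDtype check above, weight payloads still stored as uint8 are rewritten to int8 by the same fixed -128 shift used everywhere else:

  #include <stdint.h>

  /* uint8 -> int8 weight rewrite: {0, 127, 128, 255} -> {-128, -1, 0, 127};
   * the tensor's zeroPoint drops by 128 to match. */
  static inline int8_t U8WeightToI8(uint8_t w) { return (int8_t)((int32_t)w - 128); }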
diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc
index 97dd4bbccf..dcd830a086 100644
--- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc
+++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc
@@ -40,6 +40,7 @@ std::set<std::string> SPECIAL_NODE = {"Gemm"};
 FuncGraphPtr OnnxModelParser::Parse(const std::string &model_file, const std::string &weight_file,
                                     const QuantType &quant_type) {
   NoSupportOp::GetInstance()->SetFmkType("ONNX");
+  anf_root_graph_ = std::make_shared<FuncGraph>();
   auto status = InitOriginModel(model_file);
   if (RET_OK != status) {
     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
@@ -47,7 +48,6 @@ FuncGraphPtr OnnxModelParser::Parse(const std::string &model_file, const std::st
     return nullptr;
   }

-  anf_root_graph_ = std::make_shared<FuncGraph>();
   status = ConvertOnnxGraph(onnx_root_graph_, anf_root_graph_, &anf_nodes_map_, {}, "root_node");
   if (RET_OK != status) {
     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
@@ -77,6 +77,11 @@ STATUS OnnxModelParser::InitOriginModel(const std::string &model_file) {
   }
   OnnxNodeParser::set_opset_version(onnx_model_.opset_import().Get(0).version());
   onnx_root_graph_ = onnx_model_.graph();
+  if (OnnxNodeParser::opset_version() > 15) {
+    anf_root_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX)));
+  } else {
+    anf_root_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX_LOW_VERSION)));
+  }
   return RET_OK;
 }

 STATUS OnnxModelParser::ConvertOnnxGraph(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &anf_graph,
@@ -614,6 +619,9 @@ STATUS OnnxModelParser::SetTensorQuantParamFromNode(const std::string &tensor_na
                                                     std::vector<schema::QuantParamT> *quant_params) {
   quant_params->clear();
   auto quant_param = std::make_unique<schema::QuantParamT>();
+  if (OnnxNodeParser::opset_version() <= 15) {
+    quant_param->multiplier = 0;
+  }
   std::string quant_tensor_name = "scale_" + tensor_name;
   auto status = CopyTensorQuantParam(quant_tensor_name, quant_param.get(), true);
   if (status != RET_OK) {
diff --git a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc
index 5efa7fab56..bb73cd0332 100644
--- a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc
+++ b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc
@@ -366,6 +366,7 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin
     return nullptr;
   }
   anf_root_graph_->set_attr("graph_name", MakeValue("main_graph"));
+  anf_root_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_TF)));

   for (int i = 0; i < tf_root_graph_->node_size(); i++) {
     auto &node_def = tf_root_graph_->node(i);
@@ -441,6 +442,7 @@ STATUS TFModelParser::ConvertSubgraph() {
     FuncGraphPtr sub_func_graph = std::make_shared<FuncGraph>();
     sub_func_graph->set_attr("graph_name", MakeValue(sub_graph_name));
+    sub_func_graph->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_TF)));
     std::unordered_map<std::string, AnfNodePtr> anf_sub_node_map;
     // convert sub graph inputs
     std::vector<ParameterPtr> sub_graph_inputs;
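For ONNX the multiplier method is keyed off the opset: quant params parsed from models with opset_version <= 15 get multiplier = 0, selecting the single-precision derivation, and the graph is tagged FmkType_ONNX_LOW_VERSION; newer opsets keep the schema defaults. Creating anf_root_graph_ before InitOriginModel (rather than after it) is what makes the graph object available when the fmk attribute is set there.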
diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc
index 80d323de2e..f19cf8e0d6 100644
--- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc
+++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc
@@ -55,6 +55,7 @@ FuncGraphPtr TfliteModelParser::Parse(const std::string &model_file, const std::
     return nullptr;
   }
   func_graph_ = std::make_shared<FuncGraph>();
+  func_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_TFLITE)));

   auto status = ConvertGraphInputs();
   if (status != RET_OK) {
@@ -183,7 +184,7 @@ STATUS TfliteModelParser::ConvertOps() {
 }

 STATUS TfliteModelParser::SetTensorQuantParam(const tflite::TensorT *tflite_tensor,
-                                              std::vector<schema::QuantParamT> *quant_params) {
+                                              std::vector<schema::QuantParamT> *quant_params, int round_type) {
   if (tflite_tensor == nullptr) {
     MS_LOG(ERROR) << "tflite_tensor is null, set tensor quant params failed.";
     return RET_NULL_PTR;
@@ -221,6 +222,8 @@ STATUS TfliteModelParser::SetTensorQuantParam(const tflite::TensorT *tflite_tens
       quant_param->max = tflite_tensor->quantization->max[i];
     }
     quant_param->inited = true;
+    quant_param->roundType = round_type;
+    quant_param->multiplier = 1;
     quant_params->emplace_back(*std::move(quant_param));
   }
   return RET_OK;
@@ -236,6 +239,11 @@ STATUS TfliteModelParser::ConvertOpQuantParams(const tflite::OperatorT *op, lite
     MS_LOG(ERROR) << "primitive_c is null, get quant params failed.";
     return RET_NULL_PTR;
   }
+
+  int round_type = 1;
+  if (primitive_c->primitiveT()->value.type == PrimitiveType_Conv2D) {
+    round_type = 2;
+  }
   const auto &tflite_subgraph = tflite_model_->subgraphs.front();
   for (auto input_idx : op->inputs) {
     if (input_idx < 0) {
@@ -243,7 +251,7 @@ STATUS TfliteModelParser::ConvertOpQuantParams(const tflite::OperatorT *op, lite
     }
     const auto &input_tensor = tflite_subgraph->tensors[input_idx];
     std::vector<schema::QuantParamT> quant_params;
-    auto status = SetTensorQuantParam(input_tensor.get(), &quant_params);
+    auto status = SetTensorQuantParam(input_tensor.get(), &quant_params, round_type);
     if (status != RET_OK) {
       MS_LOG(ERROR) << "set input tensor quant param failed.";
       return status;
@@ -256,7 +264,7 @@ STATUS TfliteModelParser::ConvertOpQuantParams(const tflite::OperatorT *op, lite
     }
     const auto &output_tensor = tflite_subgraph->tensors.at(output_idx);
     std::vector<schema::QuantParamT> quant_params;
-    auto status = SetTensorQuantParam(output_tensor.get(), &quant_params);
+    auto status = SetTensorQuantParam(output_tensor.get(), &quant_params, round_type);
     if (status != RET_OK) {
       MS_LOG(ERROR) << "set output tensor quant param failed.";
       return status;
diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h
index 3c69b9aa05..646b5b41c3 100644
--- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h
+++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h
@@ -48,7 +48,8 @@ class TfliteModelParser : public ModelParser {
   STATUS ConvertOps();
   STATUS ConvertGraphInputs();
   STATUS ConvertGraphOutputs();
-  static STATUS SetTensorQuantParam(const tflite::TensorT *tflite_tensor, std::vector<schema::QuantParamT> *quant_params);
+  static STATUS SetTensorQuantParam(const tflite::TensorT *tflite_tensor, std::vector<schema::QuantParamT> *quant_params,
+                                    int round_type = 1);
 };
 }  // namespace mindspore::lite
 #endif  // LITE_TFLITE_MODEL_PARSER_H
diff --git a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc
index aa53b6c743..825fab53ca 100644
--- a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc
+++ b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc
@@ -595,6 +595,8 @@ STATUS PostTrainingQuantizer::DoQuantInput(double scale, int32_t zeropoint, stru
   quant_param.numBits = bit_num;
   quant_param.narrowRange = false;
   quant_param.inited = true;
+  quant_param.roundType = 1;
+  quant_param.multiplier = 1;
   std::vector<schema::QuantParamT> quant_params = {quant_param};
   lite_primitive->AddInputQuantParam(quant_params);
   return RET_OK;
@@ -612,6 +614,8 @@ STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct
   quant_param.numBits = bit_num;
   quant_param.narrowRange = false;
   quant_param.inited = true;
+  quant_param.roundType = 1;
+  quant_param.multiplier = 1;
   std::vector<schema::QuantParamT> quant_params = {quant_param};
   lite_primitive->AddOutputQuantParam(quant_params);
   return RET_OK;
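Taken together, the parser defaults are: TFLite tensors get multiplier = 1 (double precision) and roundType 2 for Conv2D but 1 for every other op; post-training quantization emits roundType = 1, multiplier = 1. Under the Conv2D settings an output value is produced roughly as below — a sketch of how the pieces compose, not the shipped kernel:

  #include <stdint.h>

  /* Rescale one int32 conv accumulator with the double-precision multiplier and
   * upward rounding (roundType = 2), then clamp into int8 around the output zero point. */
  static int8_t RequantizeUpward(int32_t acc, int32_t quant_multiplier, int left_shift,
                                 int right_shift, int32_t out_zp) {
    int32_t v = MultiplyByQuantizedMultiplierWithUpwardRounding(acc, quant_multiplier,
                                                                left_shift, right_shift);
    v += out_zp;
    if (v > 127) v = 127;
    if (v < -128) v = -128;
    return (int8_t)v;
  }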