From 561ebcca7ea8d6de4a2e32ec98531a1edb3e0c22 Mon Sep 17 00:00:00 2001 From: xutianchun Date: Wed, 23 Sep 2020 10:12:11 +0800 Subject: [PATCH] promote weight quantization precision --- mindspore/lite/schema/model.fbs | 2 + mindspore/lite/src/lite_kernel.cc | 22 ++++++---- mindspore/lite/src/lite_session.cc | 4 +- mindspore/lite/src/tensor.h | 2 + .../tools/converter/quantizer/quantize_util.h | 42 ++++++++++++++++++- 5 files changed, 61 insertions(+), 11 deletions(-) diff --git a/mindspore/lite/schema/model.fbs b/mindspore/lite/schema/model.fbs index fa885526e7..094d236c2a 100644 --- a/mindspore/lite/schema/model.fbs +++ b/mindspore/lite/schema/model.fbs @@ -32,6 +32,8 @@ table QuantParam { narrowRange: bool = true; numBits: int = 8; inited: bool = false; + var_corr: double = 1; + mean_corr: double = 0; } table Tensor { diff --git a/mindspore/lite/src/lite_kernel.cc b/mindspore/lite/src/lite_kernel.cc index 7e6752b26c..c20ce8fe9d 100644 --- a/mindspore/lite/src/lite_kernel.cc +++ b/mindspore/lite/src/lite_kernel.cc @@ -174,9 +174,9 @@ float *LiteKernelUtil::DequantWeight(lite::Tensor *input_tensor) { MS_LOG(ERROR) << "no quant param"; return nullptr; } - const auto *quant_data = static_cast(input_tensor->MutableData()); - auto *dequant_data = static_cast(malloc(input_tensor->ElementsNum() * sizeof(float))); - if (dequant_data == nullptr) { + const auto *quant_datas = static_cast(input_tensor->MutableData()); + auto *dequant_datas = static_cast(malloc(input_tensor->ElementsNum() * sizeof(float))); + if (dequant_datas == nullptr) { MS_LOG(ERROR) << "malloc faile"; return nullptr; } @@ -185,7 +185,7 @@ float *LiteKernelUtil::DequantWeight(lite::Tensor *input_tensor) { size_t channels = static_cast(input_tensor->Batch()); if (input_tensor->GetQuantParams().size() != channels) { MS_LOG(ERROR) << "Quant param not equal channel num " << input_tensor->GetQuantParams().size() << channels; - free(dequant_data); + free(dequant_datas); return nullptr; } size_t per_channel_size 
= input_tensor->ElementsNum() / channels; @@ -194,9 +194,15 @@ float *LiteKernelUtil::DequantWeight(lite::Tensor *input_tensor) { auto param = quant_param.at(i); auto scale = param.scale; auto zero_point = param.zeroPoint; + auto var_corr = param.var_corr; + auto mean_corr = param.mean_corr; + if (var_corr <= 0 || var_corr >= 10) {  // accept only the open interval (0, 10), matching the converter + MS_LOG(WARNING) << "unexpected var_corr: " << var_corr; + var_corr = 1; + } for (size_t j = 0; j < per_channel_size; j++) { - dequant_data[per_channel_size * i + j] = - static_cast((quant_data[per_channel_size * i + j] - zero_point) * scale); + auto dequant_data = (quant_datas[per_channel_size * i + j] - zero_point) * scale; + dequant_datas[per_channel_size * i + j] = static_cast(dequant_data * var_corr + mean_corr); } } } else { @@ -205,9 +211,9 @@ float *LiteKernelUtil::DequantWeight(lite::Tensor *input_tensor) { auto scale = param.scale; auto zero_point = param.zeroPoint; for (int64_t j = 0; j < input_tensor->ElementsNum(); j++) { - dequant_data[j] = static_cast((quant_data[j] - zero_point) * scale); + dequant_datas[j] = static_cast((quant_datas[j] - zero_point) * scale); } } - return dequant_data; + return dequant_datas; } } // namespace mindspore::kernel diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc index 9cf9836e2f..2b971837d9 100644 --- a/mindspore/lite/src/lite_session.cc +++ b/mindspore/lite/src/lite_session.cc @@ -106,6 +106,8 @@ int LiteSession::ConvertTensors(const lite::Model *model) { QuantArg quant_arg{}; quant_arg.scale = quant_params->Get(j)->scale(); quant_arg.zeroPoint = quant_params->Get(j)->zeroPoint(); + quant_arg.var_corr = quant_params->Get(j)->var_corr(); + quant_arg.mean_corr = quant_params->Get(j)->mean_corr(); dstTensor->AddQuantParam(quant_arg); } } @@ -351,7 +353,7 @@ int LiteSession::Init(Context *context) { } } #endif - executor = new(std::nothrow) Executor(); + executor = new (std::nothrow) Executor(); if (nullptr == executor) { MS_LOG(ERROR) << "New Executor failed";
is_running_.store(false); diff --git a/mindspore/lite/src/tensor.h b/mindspore/lite/src/tensor.h index 0fa96d3c03..c75f79f9bd 100644 --- a/mindspore/lite/src/tensor.h +++ b/mindspore/lite/src/tensor.h @@ -33,6 +33,8 @@ namespace lite { struct QuantArg { double scale; int32_t zeroPoint; + double var_corr{1}; + double mean_corr{0}; }; class Tensor : public mindspore::tensor::MSTensor { diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.h b/mindspore/lite/tools/converter/quantizer/quantize_util.h index eecf225acf..651da849a4 100644 --- a/mindspore/lite/tools/converter/quantizer/quantize_util.h +++ b/mindspore/lite/tools/converter/quantizer/quantize_util.h @@ -143,7 +143,7 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr primiti return RET_ERROR; } std::vector quant_datas(elem_count); - + std::vector dequant_datas(elem_count); if (per_channel) { // notice: assume Con2D\DepthwiseConv2D's weight format are same: KHWC // channel at first @@ -173,8 +173,9 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr primiti MS_LOG(ERROR) << "CalQuantizationParams failed" << status; return status; } - quant_params.emplace_back(quant_param); // do quantization + double average_dequant = 0; + double average_raw = 0; for (uint32_t j = 0; j < one_filter_size; j++) { auto index = j + i * one_filter_size; if (index >= elem_count) { @@ -184,7 +185,44 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr primiti float raw_data = raw_datas[index]; auto quant_data = QuantizeData(raw_data, quant_param, quant_max, quant_min); quant_datas[index] = quant_data; + + if (quantType == QuantType_WeightQuant) { + float dequant_data = quant_param.scale * (quant_data - quant_param.zeroPoint); + dequant_datas[index] = dequant_data; + average_dequant += dequant_data; + average_raw += raw_data; + } } + if (quantType == QuantType_WeightQuant) { + // mean + average_dequant = average_dequant / one_filter_size; + average_raw = average_raw / one_filter_size; + 
// std: population standard deviation of the raw and dequantized weights + double variance_dequant = 0; + double variance_raw = 0; + for (uint32_t j = 0; j < one_filter_size; j++) { + auto index = j + i * one_filter_size; + if (index >= elem_count) { + MS_LOG(ERROR) << "over flow!"; + return RET_ERROR; + } + variance_dequant += (dequant_datas[index] - average_dequant) * (dequant_datas[index] - average_dequant); + variance_raw += (raw_datas[index] - average_raw) * (raw_datas[index] - average_raw); + } + variance_dequant = std::sqrt(variance_dequant / one_filter_size); + variance_raw = std::sqrt(variance_raw / one_filter_size); + quant_param.var_corr = 1; + if (variance_raw != 0 && variance_dequant != 0) { + auto temp_var_corr = variance_raw / variance_dequant; + if (temp_var_corr > 0 && temp_var_corr < 10) { + quant_param.var_corr = temp_var_corr; + } else { + MS_LOG(WARNING) << "unexpected var_corr: " << temp_var_corr; + } + } + quant_param.mean_corr = average_raw - average_dequant * quant_param.var_corr; + } + quant_params.emplace_back(quant_param); } auto ret = memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(), elem_count * sizeof(int8_t)); if (ret != EOK) {