From eac18afa6e598fd08ad6870bcd92682d1bfa216d Mon Sep 17 00:00:00 2001 From: xutianchun Date: Mon, 2 Nov 2020 17:03:50 +0800 Subject: [PATCH] change the schema of quant_params to reduce model size --- mindspore/lite/schema/model.fbs | 6 +++--- mindspore/lite/src/lite_session.cc | 14 ++++++++------ .../lite/src/runtime/kernel/arm/base/dequant.h | 3 ++- mindspore/lite/src/tensor.cc | 4 ++++ mindspore/lite/src/tensor.h | 9 +++++++-- .../tools/converter/quantizer/quantize_util.cc | 1 - .../lite/tools/converter/quantizer/quantize_util.h | 8 ++++---- 7 files changed, 28 insertions(+), 17 deletions(-) diff --git a/mindspore/lite/schema/model.fbs b/mindspore/lite/schema/model.fbs index 98bc3f5e81..fe92596f69 100644 --- a/mindspore/lite/schema/model.fbs +++ b/mindspore/lite/schema/model.fbs @@ -32,10 +32,9 @@ table QuantParam { narrowRange: bool = true; numBits: int = 8; inited: bool = false; - varCorr: double = 1; - meanCorr: double = 0; + varCorr: float = 1; + meanCorr: float = 0; dstDtype: int = 32; - clusters: [float]; } table Tensor { @@ -49,6 +48,7 @@ table Tensor { offset: int; data: [ubyte]; quantParams: [QuantParam]; + quantClusters: [float]; } union PrimitiveType { diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc index 5658efd3cb..82ec0d8147 100644 --- a/mindspore/lite/src/lite_session.cc +++ b/mindspore/lite/src/lite_session.cc @@ -107,15 +107,17 @@ int LiteSession::ConvertTensors(const lite::Model *model) { quant_arg.var_corr = quant_params->Get(j)->varCorr(); quant_arg.mean_corr = quant_params->Get(j)->meanCorr(); quant_arg.inited = quant_params->Get(j)->inited(); - auto quant_clusters = quant_params->Get(j)->clusters(); - if (quant_clusters != nullptr) { - for (size_t k = 0; k < quant_clusters->size(); k++) { - quant_arg.clusters.emplace_back(quant_clusters->Get(k)); - } - } dstTensor->AddQuantParam(quant_arg); } } + auto quant_clusters = srcTensor->quantClusters(); + if (quant_clusters != nullptr) { + std::vector clusters; + for (size_t j = 0; j < quant_clusters->size(); j++) { + clusters.push_back(quant_clusters->Get(j)); + } + dstTensor->SetQuantClusters(clusters); + } this->tensors_.emplace_back(dstTensor); } diff --git a/mindspore/lite/src/runtime/kernel/arm/base/dequant.h b/mindspore/lite/src/runtime/kernel/arm/base/dequant.h index 934ee09cb5..f27f70c970 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/dequant.h +++ b/mindspore/lite/src/runtime/kernel/arm/base/dequant.h @@ -79,11 +79,12 @@ class DequantUtil { } } else { auto quant_param = input_tensor->GetQuantParams(); + auto quant_clusters = input_tensor->GetQuantClusters(); auto param = quant_param.front(); auto scale = param.scale; auto zero_point = param.zeroPoint; for (int64_t j = 0; j < input_tensor->ElementsNum(); j++) { - if (param.clusters.size() != 0) { + if (!quant_clusters.empty()) { int8_t index = quant_datas[j]; if (index > INT8_MAX || index < INT8_MIN) { MS_LOG(ERROR) << "KMeans param quant is error."; diff --git a/mindspore/lite/src/tensor.cc b/mindspore/lite/src/tensor.cc index 88c1bf9591..e78e5ec969 100644 --- a/mindspore/lite/src/tensor.cc +++ b/mindspore/lite/src/tensor.cc @@ -367,6 +367,10 @@ void Tensor::AddQuantParam(const QuantArg &quant_arg) { this->quant_params_.push std::vector Tensor::GetQuantParams() const { return this->quant_params_; } +std::vector Tensor::GetQuantClusters() const { return this->quant_clusters_; } + +void Tensor::SetQuantClusters(const std::vector &clusters) { this->quant_clusters_ = clusters; } + std::vector TensorVectorCast(const std::vector &src) { std::vector target(src.size()); std::transform(src.begin(), src.end(), target.begin(), [](Tensor *t) { return dynamic_cast(t); }); diff --git a/mindspore/lite/src/tensor.h b/mindspore/lite/src/tensor.h index 93bfe75ec3..370c1ef600 100644 --- a/mindspore/lite/src/tensor.h +++ b/mindspore/lite/src/tensor.h @@ -33,8 +33,8 @@ namespace lite { struct QuantArg { double scale; int32_t zeroPoint; - double var_corr{1}; - double mean_corr{0}; + float var_corr{1}; + float mean_corr{0}; bool inited; std::vector clusters{}; }; @@ -119,6 +119,10 @@ class Tensor : public mindspore::tensor::MSTensor { std::vector GetQuantParams() const; + std::vector GetQuantClusters() const; + + void SetQuantClusters(const std::vector &clusters); + bool IsConst(); bool IsScalar(); @@ -138,6 +142,7 @@ class Tensor : public mindspore::tensor::MSTensor { Category category_; size_t ref_count_ = 0; std::vector quant_params_; + std::vector quant_clusters_; mindspore::lite::Allocator *allocator_ = nullptr; }; diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.cc b/mindspore/lite/tools/converter/quantizer/quantize_util.cc index d9cad086c5..fbe16f5839 100644 --- a/mindspore/lite/tools/converter/quantizer/quantize_util.cc +++ b/mindspore/lite/tools/converter/quantizer/quantize_util.cc @@ -449,7 +449,6 @@ std::vector KMeans(float *data, size_t elem_count, size_t k, size_t epoc error = error_cur; } // update data - quantParam->clusters = clusters; return clusters_index; } diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.h b/mindspore/lite/tools/converter/quantizer/quantize_util.h index 2afe92f07f..3e6dd76c6b 100644 --- a/mindspore/lite/tools/converter/quantizer/quantize_util.h +++ b/mindspore/lite/tools/converter/quantizer/quantize_util.h @@ -130,7 +130,7 @@ T QuantizeData(float originData, const schema::QuantParamT &quantParam, int quan } template STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr primitive_c, QuantType quantType, - int quant_max, int quant_min, size_t bitNum, bool per_channel) { + int quant_max, int quant_min, size_t bitNum, bool per_channel, bool k_means = false) { auto dims = weight->tensor_shape(); auto op_type = (schema::PrimitiveType)primitive_c->Type(); if (per_channel) { @@ -208,7 +208,7 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr primiti average_raw += raw_data; } } - if (quantType == QuantType_WeightQuant && quant_param.clusters.size() == 0) { + if (quantType == QuantType_WeightQuant && !k_means) { // mean average_dequant = average_dequant / one_filter_size; average_raw = average_raw / one_filter_size; @@ -256,7 +256,7 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr primiti } schema::QuantParamT quant_param; - if (quant_param.clusters.size() == 0) { + if (!k_means) { STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bitNum); if (status != RET_OK) { MS_LOG(ERROR) << "CalQuantizationParams failed" << status; @@ -267,7 +267,7 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr primiti // update data and datatype for (uint32_t i = 0; i < elem_count; i++) { float raw_data = raw_datas[i]; - if (quant_param.clusters.size() == 0) { + if (!k_means) { auto quant_data = QuantizeData(raw_data, quant_param, quant_max, quant_min); quant_datas[i] = quant_data; }