From 252c13fedd51b23336928fede71677e3a95d0af7 Mon Sep 17 00:00:00 2001
From: xutianchun
Date: Fri, 14 Aug 2020 19:12:08 +0800
Subject: [PATCH] fix post quantization

---
 .../src/common/anf_exporter/anf_exporter.cc   | 30 +++++++++++--------
 .../quantizer/post_training_quantizer.cc      |  2 +-
 .../converter/quantizer/quantize_util.cc      | 26 +++++++++++-----
 3 files changed, 36 insertions(+), 22 deletions(-)

diff --git a/mindspore/lite/src/common/anf_exporter/anf_exporter.cc b/mindspore/lite/src/common/anf_exporter/anf_exporter.cc
index 31bd1b3222..3e6a10cf91 100644
--- a/mindspore/lite/src/common/anf_exporter/anf_exporter.cc
+++ b/mindspore/lite/src/common/anf_exporter/anf_exporter.cc
@@ -201,14 +201,16 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
       }
       auto activate_index = node->inputIndex[i];
       auto tensor_input = metaGraphT->allTensors[activate_index].get();
-      std::unique_ptr<schema::QuantParamT> input_quant_param =
-        std::make_unique<schema::QuantParamT>(input_quant_params[i]);
-      MS_LOG(DEBUG) << "[input]node: " << node->name << " scale: " << input_quant_param->scale
-                    << " zp: " << input_quant_param->zeroPoint;
-      tensor_input->quantParams.emplace_back(std::move(input_quant_param));
-      if (!(node_type == schema::PrimitiveType_QuantDTypeCast &&
-            primitiveT_value->GetPrimitiveT()->value.AsQuantDTypeCast()->srcT == kNumberTypeFloat32)) {
-        tensor_input->dataType = kNumberTypeInt8;
+      if (tensor_input->quantParams.empty()) {
+        std::unique_ptr<schema::QuantParamT> input_quant_param =
+          std::make_unique<schema::QuantParamT>(input_quant_params[i]);
+        MS_LOG(DEBUG) << "[input]node: " << node->name << " scale: " << input_quant_param->scale
+                      << " zp: " << input_quant_param->zeroPoint;
+        tensor_input->quantParams.emplace_back(std::move(input_quant_param));
+        if (!(node_type == schema::PrimitiveType_QuantDTypeCast &&
+              primitiveT_value->GetPrimitiveT()->value.AsQuantDTypeCast()->srcT == kNumberTypeFloat32)) {
+          tensor_input->dataType = kNumberTypeInt8;
+        }
       }
     }
 
@@ -219,11 +221,13 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
     if (output_quant_params.empty()) {
       MS_LOG(WARNING) << "node: " << node->name << " output quant params is empty";
     } else {
-      std::unique_ptr<schema::QuantParamT> output_quant_param =
-        std::make_unique<schema::QuantParamT>(output_quant_params[0]);
-      MS_LOG(DEBUG) << "[output]node: " << node->name << " scale: " << output_quant_param->scale
-                    << " zp: " << output_quant_param->zeroPoint;
-      tensor_output->quantParams.emplace_back(std::move(output_quant_param));
+      if (tensor_output->quantParams.empty()) {
+        std::unique_ptr<schema::QuantParamT> output_quant_param =
+          std::make_unique<schema::QuantParamT>(output_quant_params[0]);
+        MS_LOG(DEBUG) << "[output]node: " << node->name << " scale: " << output_quant_param->scale
+                      << " zp: " << output_quant_param->zeroPoint;
+        tensor_output->quantParams.emplace_back(std::move(output_quant_param));
+      }
     }
     if (!(node_type == schema::PrimitiveType_QuantDTypeCast &&
           primitiveT_value->GetPrimitiveT()->value.AsQuantDTypeCast()->dstT == kNumberTypeFloat32)) {
diff --git a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc
index 3e8840c4da..08fc4ce291 100644
--- a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc
+++ b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc
@@ -62,7 +62,7 @@ struct DivergInfo {
     this->bin_num = bins;
     this->bit_num = bits;
     histogram.resize(bin_num);
-    max = FLT_MIN;
+    max = -FLT_MAX;
     min = FLT_MAX;
     this->quant_max = quant_max;
     this->quant_min = quant_min;
diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.cc b/mindspore/lite/tools/converter/quantizer/quantize_util.cc
index e03cbe79e9..e7e37b41e7 100644
--- a/mindspore/lite/tools/converter/quantizer/quantize_util.cc
+++ b/mindspore/lite/tools/converter/quantizer/quantize_util.cc
@@ -313,7 +313,7 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_
       MS_LOG(ERROR) << "weight dims size error: " << dims.size() << " Back to per layer.";
       per_channel = false;
     } else {
-      uint32_t channels = dims[3];
+      uint32_t channels = dims[0];
       if (channels == 0) {
         MS_LOG(ERROR) << "channels is 0";
         return RET_ERROR;
@@ -325,7 +325,7 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_
     // at now for tflite model, Conv2D's weight format is KHWC, so is DepthwiseConv2D
     // if TransWeightFormat is done before PostTraingingQuantization, the DepthwiseCon2D's weight is CHWK
     size_t shapeSize = weightPtr->tensor_shape_size();
-    auto channels = dims[3];
+    auto channels = dims[0];
     size_t oneFilterSize = shapeSize / channels;
     auto *rawDatas = reinterpret_cast<float *>(weightPtr->tensor_addr());
     if (rawDatas == nullptr) {
@@ -334,15 +334,20 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_
     }
 
     float min = FLT_MAX;
-    float max = FLT_MIN;
+    float max = -FLT_MAX;
     weightPtr->quant_param().clear();
     vector<int8_t> qDatas(shapeSize);
 
     for (uint32_t i = 0; i < channels; i++) {
       // find min and max
       for (uint32_t j = 0; j < oneFilterSize; j++) {
-        min = std::min(min, rawDatas[i + j * oneFilterSize]);
-        max = std::max(max, rawDatas[i + j * oneFilterSize]);
+        auto index = j + i * channels;
+        if (index >= shapeSize) {
+          MS_LOG(ERROR) << "over flow!";
+          return RET_ERROR;
+        }
+        min = std::min(min, rawDatas[index]);
+        max = std::max(max, rawDatas[index]);
       }
       std::unique_ptr<AnfQuantParam> quantParam = std::unique_ptr<AnfQuantParam>(new AnfQuantParam);
       STATUS status = CalQuantizationParams(quantParam, min, max, false, quant_max, quant_min, bitNum);
@@ -350,11 +355,16 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_
         MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
         return status;
       }
-      // update data and datatype
+      // do quantization
      for (uint32_t j = 0; j < oneFilterSize; j++) {
-        float rawData = rawDatas[i + j * oneFilterSize];
+        auto index = j + i * channels;
+        if (index >= shapeSize) {
+          MS_LOG(ERROR) << "over flow!";
+          return RET_ERROR;
+        }
+        float rawData = rawDatas[index];
         auto qData = QuantizeData(rawData, quantParam.get(), quant_max, quant_min);
-        qDatas[i + j * oneFilterSize] = qData;
+        qDatas[index] = qData;
       }
       weightPtr->set_quant_param(quantParam);
     }
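Context for the min/max change in this patch: FLT_MIN is the smallest positive normalized float (about 1.18e-38), not the most negative value, so a running maximum seeded with FLT_MIN can never report a negative maximum when an all-negative filter is scanned; -FLT_MAX is the correct seed. Below is a minimal, self-contained C++ sketch of the per-channel range scan as this patch structures it; FindChannelRange is a hypothetical helper written for illustration (the real logic lives in QuantFilter in quantize_util.cc), and it mirrors the patch's `j + i * channels` indexing and bounds check.

#include <algorithm>
#include <cfloat>
#include <cstddef>
#include <iostream>
#include <vector>

// Hypothetical helper mirroring the patched loop in QuantFilter: seed max with
// -FLT_MAX (not FLT_MIN) and bounds-check every index before reading the raw
// weight buffer.
bool FindChannelRange(const std::vector<float> &raw, size_t channels,
                      size_t one_filter_size, size_t channel,
                      float *min_out, float *max_out) {
  float min = FLT_MAX;
  float max = -FLT_MAX;  // FLT_MIN here would silently break all-negative filters
  for (size_t j = 0; j < one_filter_size; j++) {
    size_t index = j + channel * channels;  // indexing scheme used by the patch
    if (index >= raw.size()) {              // guard against out-of-range access
      return false;
    }
    min = std::min(min, raw[index]);
    max = std::max(max, raw[index]);
  }
  *min_out = min;
  *max_out = max;
  return true;
}

int main() {
  // All-negative filter: with max seeded to FLT_MIN the reported maximum would
  // stay at FLT_MIN (> 0) instead of the true -0.5f.
  std::vector<float> weights = {-2.0f, -0.5f, -1.25f, -3.0f};
  float min = 0.0f, max = 0.0f;
  if (FindChannelRange(weights, /*channels=*/1, /*one_filter_size=*/4,
                       /*channel=*/0, &min, &max)) {
    std::cout << "min=" << min << " max=" << max << std::endl;  // min=-3 max=-0.5
  }
  return 0;
}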