From 0f2c78253eefb6f9cbf45d314c29672e99e85851 Mon Sep 17 00:00:00 2001 From: kai00 Date: Mon, 31 Aug 2020 20:33:56 +0800 Subject: [PATCH] ====weight quant====== --- mindspore/lite/src/model.cc | 5 +- mindspore/lite/src/ops/primitive_c.cc | 4 + mindspore/lite/src/ops/primitive_c.h | 4 + .../kernel/arm/base/convolution_base.cc | 42 +++++ .../kernel/arm/base/convolution_base.h | 1 + .../runtime/kernel/arm/fp32/convolution.cc | 12 ++ .../kernel/arm/fp32/convolution_depthwise.cc | 13 ++ .../lite/tools/anf_exporter/anf_exporter.cc | 5 +- .../lite/tools/converter/anf_transform.cc | 34 ++-- .../lite/tools/converter/converter_flags.cc | 2 + .../graph/weight_format_hardcode_pass.cc | 1 + .../graph/weight_format_transform_pass.cc | 2 +- .../tools/converter/quantizer/CMakeLists.txt | 1 + .../quantizer/post_training_quantizer.cc | 3 +- .../converter/quantizer/quantize_util.cc | 165 ----------------- .../tools/converter/quantizer/quantize_util.h | 169 +++++++++++++++++- .../converter/quantizer/weight_quantizer.cc | 148 +++++++++++++++ .../converter/quantizer/weight_quantizer.h | 53 ++++++ 18 files changed, 479 insertions(+), 185 deletions(-) create mode 100644 mindspore/lite/tools/converter/quantizer/weight_quantizer.cc create mode 100644 mindspore/lite/tools/converter/quantizer/weight_quantizer.h diff --git a/mindspore/lite/src/model.cc b/mindspore/lite/src/model.cc index bafb15556b..49b9b49203 100644 --- a/mindspore/lite/src/model.cc +++ b/mindspore/lite/src/model.cc @@ -103,8 +103,9 @@ int ModelImpl::BuildOps() { auto cNode = meta_graph_->nodes()->GetAs(i); auto name = cNode->name()->str(); auto srcPrim = cNode->primitive(); - - this->ops_[name] = PrimitiveC::UnPackFromSchemaPrimitive(const_cast(srcPrim)); + auto prim = PrimitiveC::UnPackFromSchemaPrimitive(const_cast(srcPrim)); + prim->SetQuantType(cNode->quantType()); + this->ops_[name] = prim; } return 0; } diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index 9ae43c7a86..09ba606d8c 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -688,6 +688,10 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitive(const schema::Primitive *primi } return nullptr; } +void PrimitiveC::SetQuantType(schema::QuantType quant_type) { + this->quant_type_ = quant_type; +} +schema::QuantType PrimitiveC::GetQuantType() const { return quant_type_;} #endif int PrimitiveC::Type() const { diff --git a/mindspore/lite/src/ops/primitive_c.h b/mindspore/lite/src/ops/primitive_c.h index d13f5c31a2..73ba1f4d87 100644 --- a/mindspore/lite/src/ops/primitive_c.h +++ b/mindspore/lite/src/ops/primitive_c.h @@ -145,6 +145,9 @@ class PrimitiveC { int Type() const; + void SetQuantType(schema::QuantType quant_type); + schema::QuantType GetQuantType() const; + protected: template ::value>> static PrimitiveC *NewPrimitiveC(const schema::Primitive *primitive) { @@ -194,6 +197,7 @@ class PrimitiveC { const schema::Primitive *primitive_ = nullptr; char *primitive_buf_ = nullptr; bool infer_flag_ = true; + schema::QuantType quant_type_{schema::QuantType_QUANT_NONE}; }; #endif } // namespace lite diff --git a/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc index 6e7388801b..c5485ca69b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc @@ -331,4 +331,46 @@ int ConvolutionBaseCPUKernel::SetQuantParam() { return RET_OK; } +int ConvolutionBaseCPUKernel::RestoreFilter(lite::tensor::Tensor *input_tensor) { + MS_ASSERT(input_tensor != nullptr); + if (input_tensor->GetQuantParams().empty()) { + MS_LOG(ERROR) << "no quant param"; + return RET_ERROR; + } + const auto* quant_data = static_cast(input_tensor->Data()); + auto* dequant_data = static_cast(malloc(input_tensor->DataSize() * sizeof(float))); + if (dequant_data == nullptr) { + MS_LOG(ERROR) << "malloc faile"; + return RET_ERROR; + } + + if (input_tensor->GetQuantParams().size() != kPerTensor) { + size_t channels = static_cast(input_tensor->Batch()); + if (input_tensor->GetQuantParams().size() != channels) { + MS_LOG(ERROR) << "Quant param not equal channel num " << input_tensor->GetQuantParams().size() << channels; + return RET_ERROR; + } + size_t per_channel_size = input_tensor->DataSize() / channels; + auto quant_param = input_tensor->GetQuantParams(); + for (size_t i = 0; i < channels; i++) { + auto param = quant_param.at(i); + auto scale = param.scale; + auto zero_point = param.zeroPoint; + for (size_t j = 0; j < per_channel_size; j++) { + dequant_data[per_channel_size * i + j] = static_cast( + (quant_data[per_channel_size * i + j] - zero_point) * scale); + } + } + } else { + auto quant_param = input_tensor->GetQuantParams(); + auto param = quant_param.front(); + auto scale = param.scale; + auto zero_point = param.zeroPoint; + for (int64_t j = 0; j < input_tensor->DataSize(); j++) { + dequant_data[j] = static_cast((quant_data[j] - zero_point) * scale); + } + } + input_tensor->SetData(dequant_data); + return RET_OK; +} } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.h b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.h index d39400eda8..9dd5ba5966 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.h +++ b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.h @@ -60,6 +60,7 @@ class ConvolutionBaseCPUKernel : public LiteKernel { int SetQuantMultiplier(); int CheckResizeValid(); void FreeQuantParam(); + static int RestoreFilter(lite::tensor::Tensor *input_tensor); protected: int tile_num_; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc index c5252fdc47..3375a12252 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc @@ -239,6 +239,12 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vectorData(); + if (primitive->GetQuantType() == schema::QuantType_WeightQuant) { + ConvolutionBaseCPUKernel::RestoreFilter(inputs.at(kWeightIndex)); + } + kernel::LiteKernel *kernel; if (kernel_h == 1 && kernel_w == 1) { kernel = new (std::nothrow) kernel::Convolution1x1CPUKernel(op_parameter, inputs, outputs, ctx, primitive); @@ -263,6 +269,12 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector(op_parameter->type_)); return nullptr; } + + if (primitive->GetQuantType() == schema::QuantType_WeightQuant) { + weight_tensor->FreeData(); + weight_tensor->SetData(restore_data); + } + return kernel; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.cc index 53ea4cf09f..7a09430880 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.cc @@ -131,6 +131,13 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vectorData(); + if (primitive->GetQuantType() == schema::QuantType_WeightQuant) { + ConvolutionBaseCPUKernel::RestoreFilter(inputs.at(kWeightIndex)); + } + auto conv_param = reinterpret_cast(opParameter); kernel::LiteKernel *kernel; if (conv_param->input_channel_ < 32) { @@ -149,6 +156,12 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector(opParameter->type_)); return nullptr; } + + if (primitive->GetQuantType() == schema::QuantType_WeightQuant) { + weight_tensor->FreeData(); + weight_tensor->SetData(restore_data); + } + return kernel; } diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc index 07f4178664..e2b48c1a75 100644 --- a/mindspore/lite/tools/anf_exporter/anf_exporter.cc +++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc @@ -64,7 +64,8 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr &me MS_ASSERT(dst_node != nullptr); // add quant param dst_node->quantType = primitive->GetQuantType(); - if (dst_node->quantType == schema::QuantType_PostTraining || dst_node->quantType == schema::QuantType_AwareTraining) { + if (dst_node->quantType == schema::QuantType_PostTraining || dst_node->quantType == schema::QuantType_AwareTraining + || dst_node->quantType == schema::QuantType_WeightQuant) { MS_LOG(DEBUG) << "node: " << dst_node->name << " add QuantParam"; // activation auto input_quant_params = primitive->GetInputQuantParams(); @@ -103,7 +104,7 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr &me } } else { for (auto output_quant_param : output_quant_params[0]) { - if (tensor_output->quantParams.empty()) { + if (tensor_output->quantParams.empty() && dst_node->quantType != schema::QuantType_WeightQuant) { std::unique_ptr output_quant_param_ptr = std::make_unique(output_quant_param); MS_LOG(DEBUG) << "[output]node: " << dst_node->name << " scale: " << output_quant_param_ptr->scale diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc index 7be4fb9a10..9f09cf4507 100644 --- a/mindspore/lite/tools/converter/anf_transform.cc +++ b/mindspore/lite/tools/converter/anf_transform.cc @@ -26,6 +26,7 @@ #include "tools/optimizer/fusion/constant_folding_fusion.h" #include "tools/converter/quantizer/post_training_quantizer.h" #include "tools/converter/quantizer/quant_cast.h" +#include "tools/converter/quantizer/weight_quantizer.h" using std::string; namespace mindspore { @@ -57,11 +58,20 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver FuncGraphPtr new_graph = optimizer->Optimize(old_graph); // quant - if (config != nullptr && config->quantType == schema::QuantType_PostTraining) { - this->mQuantizer = std::make_unique(new_graph, config->configFile, 8); - if (mQuantizer == nullptr) { - MS_LOG(ERROR) << "New PostTrainingQuantizer failed"; - return nullptr; + if (config != nullptr) { + if (config->quantType == schema::QuantType_PostTraining) { + this->mQuantizer = std::make_unique(new_graph, config->configFile, 8); + if (mQuantizer == nullptr) { + MS_LOG(ERROR) << "New PostTrainingQuantizer failed"; + return nullptr; + } + } else if (config->quantType == schema::QuantType_WeightQuant) { + this->mQuantizer = std::make_unique(new_graph, config->quantSize, + config->convWeightQuantChannelThreshold, config->bitNum); + if (mQuantizer == nullptr) { + MS_LOG(ERROR) << "New PostTrainingQuantizer failed"; + return nullptr; + } } } if (mQuantizer != nullptr) { @@ -71,12 +81,14 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver MS_LOG(ERROR) << "Quant failed " << status; return nullptr; } - quant::QuantCast quant_cast; - quant_cast.SetInputDataDType(kNumberTypeFloat32); - status = quant_cast.Run(new_graph); - if (status != RET_OK) { - MS_LOG(ERROR) << "add QuantCast error"; - return nullptr; + if (config->quantType == schema::QuantType_PostTraining) { + quant::QuantCast quant_cast; + quant_cast.SetInputDataDType(kNumberTypeFloat32); + status = quant_cast.Run(new_graph); + if (status != RET_OK) { + MS_LOG(ERROR) << "add QuantCast error"; + return nullptr; + } } } diff --git a/mindspore/lite/tools/converter/converter_flags.cc b/mindspore/lite/tools/converter/converter_flags.cc index bf88c5ddcc..02cd3f0ca0 100644 --- a/mindspore/lite/tools/converter/converter_flags.cc +++ b/mindspore/lite/tools/converter/converter_flags.cc @@ -36,6 +36,8 @@ Flags::Flags() { AddFlag(&Flags::stdDev, "stdDev", "Standard deviation value for aware-quantization", "128"); AddFlag(&Flags::mean, "mean", "Mean value for aware-quantization", "-0.5"); AddFlag(&Flags::quantSize, "quantSize", "Weight quantization size threshold", "0"); + AddFlag(&Flags::convWeightQuantChannelThreshold, "convWeightQuantChannelThreshold", + "convWeightQuantChannelThreshold", "16"); AddFlag(&Flags::configFile, "config_file", "Configuration for post-training.", ""); AddFlag(&Flags::formatTrans, "formatTrans", "whether transform format. true | false", "true"); } diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_hardcode_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_hardcode_pass.cc index 2279e6e1cb..c8d9812b43 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_hardcode_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_hardcode_pass.cc @@ -191,6 +191,7 @@ STATUS WeightFormatHardCodePass::HardCodeTFLITE(const std::unique_ptr &n switch (this->quantType) { case QuantType_AwareTraining: case QuantType_PostTraining: + case QuantType_WeightQuant: case QuantType_QUANT_NONE: { if (opType == schema::PrimitiveType_Conv2D) { weightTensor->format = schema::Format_KHWC; diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_transform_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_transform_pass.cc index 870a70fff4..c88814dee9 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_transform_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_transform_pass.cc @@ -31,7 +31,7 @@ void WeightFormatTransformPass::SetDstFormat(Format format) { this->dstFormat = STATUS WeightFormatTransformPass::Run(MetaGraphT *graph) { MS_ASSERT(graph != nullptr); - if (this->quantType == QuantType_AwareTraining) { + if (this->quantType == QuantType_AwareTraining || this->quantType == QuantType_WeightQuant) { auto status = QuantDataFormatTrans(graph); if (status != RET_OK) { MS_LOG(ERROR) << "QuantDataFormatTrans failed: " << status; diff --git a/mindspore/lite/tools/converter/quantizer/CMakeLists.txt b/mindspore/lite/tools/converter/quantizer/CMakeLists.txt index 59082b4413..d0f7e6c7d3 100644 --- a/mindspore/lite/tools/converter/quantizer/CMakeLists.txt +++ b/mindspore/lite/tools/converter/quantizer/CMakeLists.txt @@ -11,6 +11,7 @@ add_library(quantizer_mid OBJECT ${CMAKE_CURRENT_SOURCE_DIR}/general_bitpacking.cc ${CMAKE_CURRENT_SOURCE_DIR}/post_training_quantizer.cc ${CMAKE_CURRENT_SOURCE_DIR}/quant_cast.cc + ${CMAKE_CURRENT_SOURCE_DIR}/weight_quantizer.cc ) if(ENABLE_ASAN) diff --git a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc index 9d7a8b4478..3dbea54a94 100644 --- a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc +++ b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc @@ -530,7 +530,8 @@ STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr

(paramValue, primitive_c, QuantType_PostTraining, quant_max, + quant_min, bit_num, perchanel, depthwise); if (status != RET_OK) { MS_LOG(ERROR) << "QuantFilter failed: " << status; return status; diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.cc b/mindspore/lite/tools/converter/quantizer/quantize_util.cc index 4c262792bb..6e07cd0710 100644 --- a/mindspore/lite/tools/converter/quantizer/quantize_util.cc +++ b/mindspore/lite/tools/converter/quantizer/quantize_util.cc @@ -279,171 +279,6 @@ STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, doubl return RET_OK; } -STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr primitive_c, QuantType quantType, - int quant_max, int quant_min, size_t bitNum, bool per_channel, bool depth_wise) { - auto dims = weight->tensor_shape(); - if (per_channel) { - if (dims.size() != 4) { - MS_LOG(ERROR) << "weight dims size error: " << dims.size() << " Back to per layer."; - per_channel = false; - } else { - uint32_t channels = dims[0]; - if (channels == 0) { - MS_LOG(ERROR) << "channels is 0"; - return RET_ERROR; - } - } - } - - vector quant_params; - size_t elem_count = weight->tensor_shape_size(); - auto *raw_datas = static_cast(weight->tensor_addr()); - if (raw_datas == nullptr) { - MS_LOG(ERROR) << "rawDatas is nullptr"; - return RET_ERROR; - } - vector quant_datas(elem_count); - - if (per_channel) { - // notice: - // at now for tflite model, Conv2D's weight format is KHWC, so is DepthwiseConv2D - // if TransWeightFormat is done before PostTraingingQuantization, the DepthwiseCon2D's weight is CHWK - if (depth_wise) { - // channel at last - auto channels = dims[3]; - if (channels == 0) { - MS_LOG(ERROR) << "channels is zero"; - return RET_ERROR; - } - size_t one_filter_size = elem_count / channels; - - for (int i = 0; i < channels; i++) { - float min = FLT_MAX; - float max = -FLT_MAX; - // find min and max - for (size_t j = 0; j < one_filter_size; j++) { - auto index = i + j * channels; - if (index >= elem_count) { - MS_LOG(ERROR) << "over flow!"; - return RET_ERROR; - } - min = std::min(min, raw_datas[index]); - max = std::max(max, raw_datas[index]); - } - schema::QuantParamT quant_param; - STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bitNum); - if (status != RET_OK) { - MS_LOG(ERROR) << "CalQuantizationParams failed" << status; - return status; - } - quant_params.emplace_back(quant_param); - // do quantization - for (uint32_t j = 0; j < one_filter_size; j++) { - auto index = i + j * channels; - if (index >= elem_count) { - MS_LOG(ERROR) << "over flow!"; - return RET_ERROR; - } - float raw_data = raw_datas[index]; - auto quant_data = QuantizeData(raw_data, quant_param, quant_max, quant_min); - quant_datas[index] = quant_data; - } - } - auto ret = memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(), - elem_count * sizeof(int8_t)); - if (ret != EOK) { - MS_LOG(ERROR) << "memcpy error: " << ret; - return RET_ERROR; - } - weight->set_tensor_size(elem_count * sizeof(int8_t)); - } else { - // channel at first - auto channels = dims[0]; - if (channels == 0) { - MS_LOG(ERROR) << "channels is zero"; - return RET_ERROR; - } - size_t one_filter_size = elem_count / channels; - - for (int i = 0; i < channels; i++) { - float min = FLT_MAX; - float max = -FLT_MAX; - // find min and max - for (size_t j = 0; j < one_filter_size; j++) { - auto index = j + i * one_filter_size; - if (index >= elem_count) { - MS_LOG(ERROR) << "over flow!"; - return RET_ERROR; - } - min = std::min(min, raw_datas[index]); - max = std::max(max, raw_datas[index]); - } - schema::QuantParamT quant_param; - STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bitNum); - if (status != RET_OK) { - MS_LOG(ERROR) << "CalQuantizationParams failed" << status; - return status; - } - quant_params.emplace_back(quant_param); - // do quantization - for (uint32_t j = 0; j < one_filter_size; j++) { - auto index = j + i * one_filter_size; - if (index >= elem_count) { - MS_LOG(ERROR) << "over flow!"; - return RET_ERROR; - } - float raw_data = raw_datas[index]; - auto quant_data = QuantizeData(raw_data, quant_param, quant_max, quant_min); - quant_datas[index] = quant_data; - } - } - auto ret = - memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(), elem_count * sizeof(int8_t)); - if (ret != EOK) { - MS_LOG(ERROR) << "memcpy error: " << ret; - return RET_ERROR; - } - weight->set_tensor_size(elem_count * sizeof(int8_t)); - } - - } else { - // per layer - float min = FLT_MAX; - float max = -FLT_MIN; - for (uint32_t i = 0; i < elem_count; i++) { - // find max min - min = std::min(min, raw_datas[i]); - max = std::max(max, raw_datas[i]); - } - - schema::QuantParamT quant_param; - STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bitNum); - if (status != RET_OK) { - MS_LOG(ERROR) << "CalQuantizationParams failed" << status; - return status; - } - quant_params.emplace_back(quant_param); - // update data and datatype - for (uint32_t i = 0; i < elem_count; i++) { - float raw_data = raw_datas[i]; - auto quant_data = QuantizeData(raw_data, quant_param, quant_max, quant_min); - quant_datas[i] = quant_data; - } - auto ret = memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(), elem_count * sizeof(int8_t)); - if (ret != EOK) { - MS_LOG(ERROR) << "memcpy error: " << ret; - return RET_ERROR; - } - weight->set_tensor_size(elem_count * sizeof(int8_t)); - } - if (quant_params.empty()) { - MS_LOG(ERROR) << "quant_params empty"; - return RET_ERROR; - } - primitive_c->AddInputQuantParam(quant_params); - return RET_OK; -} - STATUS PostBitPack(float *weight, size_t shapeSize, size_t bitNum) { auto *rawDatas = reinterpret_cast(weight); vector qDatas(rawDatas, rawDatas + shapeSize); diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.h b/mindspore/lite/tools/converter/quantizer/quantize_util.h index 352f969c10..b955ab45f7 100644 --- a/mindspore/lite/tools/converter/quantizer/quantize_util.h +++ b/mindspore/lite/tools/converter/quantizer/quantize_util.h @@ -21,6 +21,8 @@ #include #include #include +#include +#include #include "tools/converter/quantizer/quantizer.h" #include "src/ops/primitive_c.h" #include "include/errorcode.h" @@ -117,10 +119,171 @@ T QuantizeData(float originData, const schema::QuantParamT &quantParam, int quan return static_cast(quant_data); }(); } - +template STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr primitive_c, QuantType quantType, - int quant_max, int quant_min, size_t bitNum = UINT8_QUANTIZATION, bool per_channel = false, - bool depth_wise = false); + int quant_max, int quant_min, size_t bitNum, bool per_channel, bool depth_wise) { + auto dims = weight->tensor_shape(); + if (per_channel) { + if (dims.size() != 4) { + MS_LOG(ERROR) << "weight dims size error: " << dims.size() << " Back to per layer."; + per_channel = false; + } else { + uint32_t channels = dims[0]; + if (channels == 0) { + MS_LOG(ERROR) << "channels is 0"; + return RET_ERROR; + } + } + } + + std::vector quant_params; + size_t elem_count = weight->tensor_shape_size(); + auto *raw_datas = static_cast(weight->tensor_addr()); + if (raw_datas == nullptr) { + MS_LOG(ERROR) << "rawDatas is nullptr"; + return RET_ERROR; + } + std::vector quant_datas(elem_count); + + if (per_channel) { + // notice: + // at now for tflite model, Conv2D's weight format is KHWC, so is DepthwiseConv2D + // if TransWeightFormat is done before PostTraingingQuantization, the DepthwiseCon2D's weight is CHWK + if (depth_wise) { + // channel at last + auto channels = dims[3]; + if (channels == 0) { + MS_LOG(ERROR) << "channels is zero"; + return RET_ERROR; + } + size_t one_filter_size = elem_count / channels; + + for (int i = 0; i < channels; i++) { + float min = FLT_MAX; + float max = -FLT_MAX; + // find min and max + for (size_t j = 0; j < one_filter_size; j++) { + auto index = i + j * channels; + if (index >= elem_count) { + MS_LOG(ERROR) << "over flow!"; + return RET_ERROR; + } + min = std::min(min, raw_datas[index]); + max = std::max(max, raw_datas[index]); + } + schema::QuantParamT quant_param; + STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bitNum); + if (status != RET_OK) { + MS_LOG(ERROR) << "CalQuantizationParams failed" << status; + return status; + } + quant_params.emplace_back(quant_param); + // do quantization + for (uint32_t j = 0; j < one_filter_size; j++) { + auto index = i + j * channels; + if (index >= elem_count) { + MS_LOG(ERROR) << "over flow!"; + return RET_ERROR; + } + float raw_data = raw_datas[index]; + auto quant_data = QuantizeData(raw_data, quant_param, quant_max, quant_min); + quant_datas[index] = quant_data; + } + } + auto ret = memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(), + elem_count * sizeof(T)); + if (ret != EOK) { + MS_LOG(ERROR) << "memcpy error: " << ret; + return RET_ERROR; + } + weight->set_tensor_size(elem_count * sizeof(T)); + } else { + // channel at first + auto channels = dims[0]; + if (channels == 0) { + MS_LOG(ERROR) << "channels is zero"; + return RET_ERROR; + } + size_t one_filter_size = elem_count / channels; + + for (int i = 0; i < channels; i++) { + float min = FLT_MAX; + float max = -FLT_MAX; + // find min and max + for (size_t j = 0; j < one_filter_size; j++) { + auto index = j + i * one_filter_size; + if (index >= elem_count) { + MS_LOG(ERROR) << "over flow!"; + return RET_ERROR; + } + min = std::min(min, raw_datas[index]); + max = std::max(max, raw_datas[index]); + } + schema::QuantParamT quant_param; + STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bitNum); + if (status != RET_OK) { + MS_LOG(ERROR) << "CalQuantizationParams failed" << status; + return status; + } + quant_params.emplace_back(quant_param); + // do quantization + for (uint32_t j = 0; j < one_filter_size; j++) { + auto index = j + i * one_filter_size; + if (index >= elem_count) { + MS_LOG(ERROR) << "over flow!"; + return RET_ERROR; + } + float raw_data = raw_datas[index]; + auto quant_data = QuantizeData(raw_data, quant_param, quant_max, quant_min); + quant_datas[index] = quant_data; + } + } + auto ret = + memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(), elem_count * sizeof(int8_t)); + if (ret != EOK) { + MS_LOG(ERROR) << "memcpy error: " << ret; + return RET_ERROR; + } + weight->set_tensor_size(elem_count * sizeof(T)); + } + + } else { + // per layer + float min = FLT_MAX; + float max = -FLT_MIN; + for (uint32_t i = 0; i < elem_count; i++) { + // find max min + min = std::min(min, raw_datas[i]); + max = std::max(max, raw_datas[i]); + } + + schema::QuantParamT quant_param; + STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bitNum); + if (status != RET_OK) { + MS_LOG(ERROR) << "CalQuantizationParams failed" << status; + return status; + } + quant_params.emplace_back(quant_param); + // update data and datatype + for (uint32_t i = 0; i < elem_count; i++) { + float raw_data = raw_datas[i]; + auto quant_data = QuantizeData(raw_data, quant_param, quant_max, quant_min); + quant_datas[i] = quant_data; + } + auto ret = memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(), elem_count * sizeof(int8_t)); + if (ret != EOK) { + MS_LOG(ERROR) << "memcpy error: " << ret; + return RET_ERROR; + } + weight->set_tensor_size(elem_count * sizeof(T)); + } + if (quant_params.empty()) { + MS_LOG(ERROR) << "quant_params empty"; + return RET_ERROR; + } + primitive_c->AddInputQuantParam(quant_params); + return RET_OK; +} STATUS PostBitPack(float *weights, size_t shapeSize, size_t bitNum = UINT8_QUANTIZATION); } // namespace quant diff --git a/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc b/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc new file mode 100644 index 0000000000..4badecb0e8 --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc @@ -0,0 +1,148 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/quantizer/weight_quantizer.h" +#include +#include +#include +#include "src/common/common.h" +#include "ir/dtype/type_id.h" + +using std::string; +using std::vector; + +namespace mindspore { +namespace lite { +namespace quant { + +WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const string &weightSize, + const std::string &convWeightChannelThreshold, const std::string &bitNum) + : Quantizer(graph) { + auto quantSize = static_cast(std::stoull(weightSize)); + this->bitNum = static_cast(std::stoull(bitNum)); + auto convQuantWeightChannelThreshold = static_cast(std::stoull(convWeightChannelThreshold)); + mStrategy.reset(new QuantStrategy(quantSize, convQuantWeightChannelThreshold)); +} + +STATUS WeightQuantizer::DoConvQuantize(const std::list &nodes) { + for (auto &cnode : nodes) { + if (!mStrategy->CanConvOpQuantized(cnode)) { + continue; + } + + auto primitive_c = GetValueNode>(cnode->input(0)); + if (primitive_c == nullptr) { + MS_LOG(ERROR) << "primitive_c is nullptr"; + return RET_ERROR; + } + + auto inputNode = cnode->input(2); + if (!inputNode->isa()) { + return RET_ERROR; + } + + auto paramNode = inputNode->cast(); + if (!paramNode->has_default()) { + return RET_ERROR; + } + + std::vector quant_params; + primitive_c->AddInputQuantParam(quant_params); + + auto op_type = (schema::PrimitiveType)primitive_c->Type(); + bool depthwise = op_type == schema::PrimitiveType_DepthwiseConv2D ? true : false; + + ParamValueLitePtr param_value = std::static_pointer_cast(paramNode->default_param()); + auto status = QuantFilter(param_value, primitive_c, QuantType_WeightQuant, 255, 0, + bitNum, true, depthwise); + if (status != RET_OK) { + MS_LOG(ERROR) << "QuantFilter failed : " << status; + return status; + } + param_value->set_tensor_type(kNumberTypeUInt8); + primitive_c->SetQuantType(schema::QuantType_WeightQuant); + } + + return RET_OK; +} + +STATUS WeightQuantizer::DoMulQuantize(const std::list &nodes) { + for (auto &node : nodes) { + if (!mStrategy->CanMulOpQuantized(node)) { + continue; + } + + ParamValueLitePtr param_value = nullptr; + for (size_t i = 1; i < node->size(); i++) { + auto inputNode = node->input(i); + if (inputNode->isa() == true) { + auto paramNode = inputNode->cast(); + if ((paramNode != nullptr) && (paramNode->has_default() == true)) { + param_value = std::static_pointer_cast(paramNode->default_param()); + if ((param_value == nullptr) || (param_value->tensor_size() == 0) + || (param_value->tensor_shape().size() != 4) + || (param_value->tensor_addr() == nullptr) + || (param_value->tensor_type() != mindspore::kNumberTypeFloat32)) { + param_value = nullptr; + continue; + } else { + break; + } + } + } + } + if (param_value == nullptr) { + MS_LOG(ERROR) << "No valid input param node !"; + return RET_ERROR;; + } + + auto primitive_c = GetValueNode>(node->input(0)); + if (primitive_c == nullptr) { + MS_LOG(ERROR) << "primitive_c is nullptr"; + return RET_ERROR; + } + + auto status = QuantFilter(param_value, primitive_c, QuantType_WeightQuant, 255, 0, bitNum, true, false); + if (status != RET_OK) { + MS_LOG(ERROR) << "QuantFilter failed : " << status; + return status; + } + param_value->set_tensor_type(kNumberTypeUInt8); + primitive_c->SetQuantType(schema::QuantType_WeightQuant); + } + + return RET_OK; +} + +STATUS WeightQuantizer::DoQuantize(FuncGraphPtr funcGraph) { + auto ret = RET_OK; + auto cnodes = funcGraph->GetOrderedCnodes(); + ret = DoConvQuantize(cnodes); + if (ret != RET_OK) { + MS_LOG(ERROR) << "DoConvQuantize failed :" << ret; + return ret; + } + ret = DoMulQuantize(cnodes); + if (ret != RET_OK) { + MS_LOG(ERROR) << "DoMulQuantize failed :" << ret; + return ret; + } + return ret; +} +} // namespace quant +} // namespace lite +} // namespace mindspore + diff --git a/mindspore/lite/tools/converter/quantizer/weight_quantizer.h b/mindspore/lite/tools/converter/quantizer/weight_quantizer.h new file mode 100644 index 0000000000..0726dd3df1 --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/weight_quantizer.h @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef WEIGHT_QUANTIZER_H +#define WEIGHT_QUANTIZER_H + +#include +#include +#include +#include "tools/converter/quantizer/quantizer.h" +#include "tools/converter/quantizer/quantize_util.h" +#include "ir/func_graph.h" +#include "ir/anf.h" +#include "include/model.h" +#include "base/base.h" +#include "abstract/dshape.h" + +namespace mindspore { +namespace lite { +namespace quant { +class WeightQuantizer : public Quantizer { + public: + WeightQuantizer(FuncGraphPtr graph, const std::string& weightSize, + const std::string& covWeightChannelThreshold, const std::string& bitNum); + + ~WeightQuantizer() = default; + + STATUS DoQuantize(FuncGraphPtr funcGraph) override; + STATUS DoConvQuantize(const std::list &nodes); + STATUS DoMulQuantize(const std::list &nodes); + + private: + std::unique_ptr mStrategy; + size_t bitNum; +}; +} // namespace quant +} // namespace lite +} // namespace mindspore +#endif +