From 636dd66839217db97751059391fdb1f94e7db1cb Mon Sep 17 00:00:00 2001
From: kai00
Date: Wed, 2 Sep 2020 19:37:02 +0800
Subject: [PATCH] mat mul weight quant

---
 .../kernel/arm/base/convolution_base.cc       |  4 ++
 .../kernel/arm/base/fullconnection_base.cc    | 54 +++++++++++++++++
 .../kernel/arm/base/fullconnection_base.h     |  1 +
 .../runtime/kernel/arm/base/matmul_base.cc    | 59 +++++++++++++++++++
 .../src/runtime/kernel/arm/base/matmul_base.h |  1 +
 .../kernel/arm/fp32/arithmetic_self.cc        | 54 +++++++++++++++++
 .../runtime/kernel/arm/fp32/arithmetic_self.h |  1 +
 .../lite/src/runtime/kernel/arm/fp32/scale.cc | 54 +++++++++++++++++
 .../lite/src/runtime/kernel/arm/fp32/scale.h  |  1 +
 .../lite/tools/converter/converter_flags.cc   |  1 +
 .../converter/quantizer/weight_quantizer.cc   | 44 +++++++++++---
 11 files changed, 265 insertions(+), 9 deletions(-)

diff --git a/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc
index c5485ca69b..84addcaa2f 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc
@@ -333,6 +333,10 @@ int ConvolutionBaseCPUKernel::SetQuantParam() {
 }
 int ConvolutionBaseCPUKernel::RestoreFilter(lite::tensor::Tensor *input_tensor) {
   MS_ASSERT(input_tensor != nullptr);
+  if (input_tensor->data_type() != kNumberTypeUInt8) {
+    MS_LOG(ERROR) << "conv weight input type error " << input_tensor->data_type();
+    return RET_ERROR;
+  }
   if (input_tensor->GetQuantParams().empty()) {
     MS_LOG(ERROR) << "no quant param";
     return RET_ERROR;
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.cc
index b33a03dceb..62dec5ad59 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.cc
@@ -53,7 +53,52 @@ kernel::LiteKernel *CpuFullConnectionInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
+int RestoreFullconnectWeight(lite::tensor::Tensor *input_tensor) {
+  MS_ASSERT(input_tensor != nullptr);
+  if (input_tensor->data_type() != kNumberTypeUInt8) {
+    MS_LOG(ERROR) << "full connect input type error " << input_tensor->data_type();
+    return RET_ERROR;
+  }
+  if (input_tensor->GetQuantParams().empty()) {
+    MS_LOG(ERROR) << "no quant param";
+    return RET_ERROR;
+  }
+  const auto *quant_data = static_cast<const uint8_t *>(input_tensor->Data());
+  auto *dequant_data = static_cast<float *>(malloc(input_tensor->DataSize() * sizeof(float)));
+  if (dequant_data == nullptr) {
+    MS_LOG(ERROR) << "malloc failed";
+    return RET_ERROR;
+  }
+  if (input_tensor->GetQuantParams().size() != kPerTensor) {
+    size_t channels = static_cast<size_t>(input_tensor->Batch());
+    if (input_tensor->GetQuantParams().size() != channels) {
+      MS_LOG(ERROR) << "Quant param not equal channel num " << input_tensor->GetQuantParams().size() << " " << channels;
+      return RET_ERROR;
+    }
+    size_t per_channel_size = input_tensor->DataSize() / channels;
+    auto quant_param = input_tensor->GetQuantParams();
+    for (size_t i = 0; i < channels; i++) {
+      auto param = quant_param.at(i);
+      auto scale = param.scale;
+      auto zero_point = param.zeroPoint;
+      for (size_t j = 0; j < per_channel_size; j++) {
+        dequant_data[per_channel_size * i + j] = static_cast<float>(
+          (quant_data[per_channel_size * i + j] - zero_point) * scale);
+      }
+    }
+  } else {
+    auto quant_param = input_tensor->GetQuantParams();
+    auto param = quant_param.front();
+    auto scale = param.scale;
+    auto zero_point = param.zeroPoint;
+    for (size_t j = 0; j < input_tensor->DataSize(); j++) {
+      dequant_data[j] = static_cast<float>((quant_data[j] - zero_point) * scale);
+    }
+  }
+  input_tensor->SetData(dequant_data);
+  return RET_OK;
+}
 kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                        const std::vector<lite::tensor::Tensor *> &outputs,
                                                        OpParameter *opParameter, const lite::Context *ctx,
@@ -61,6 +106,11 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
+  auto *weight_tensor = inputs.at(kWeightIndex);
+  auto *restore_data = weight_tensor->Data();
+  if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
+    RestoreFullconnectWeight(inputs.at(kWeightIndex));
+  }
   auto kernel = new (std::nothrow) FullconnectionCPUKernel(opParameter, inputs, outputs, ctx, primitive);
   if (!kernel) {
     MS_LOG(ERROR) << "kernel is nullptr.";
@@ -73,6 +123,10 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                   << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
     return nullptr;
   }
+  if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
+    weight_tensor->FreeData();
+    weight_tensor->SetData(restore_data);
+  }
   return kernel;
 }
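All four Restore*Weight helpers added by this patch (full connection above, matmul, mul, and scale below) implement the same affine dequantization, x = (q - zeroPoint) * scale, with either a single quant param for the whole tensor (the kPerTensor case) or one param per output channel. A standalone sketch of that logic, using illustrative types rather than the lite::tensor::Tensor API:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Mirrors the scale/zeroPoint fields used in the patch (names illustrative).
struct QuantParam {
  float scale;
  int32_t zero_point;
};

// Dequantize `count` uint8 values. A single param means the whole tensor
// shares scale/zero_point; `channels` params treat the buffer as
// [channels, count / channels] with one param per outer slice. These are
// exactly the two branches of the Restore*Weight helpers.
std::vector<float> Dequantize(const uint8_t *quant_data, size_t count,
                              const std::vector<QuantParam> &params) {
  std::vector<float> dequant(count);
  const size_t channels = params.size();
  const size_t per_channel_size = count / channels;
  for (size_t i = 0; i < channels; i++) {
    const float scale = params[i].scale;
    const int32_t zero_point = params[i].zero_point;
    for (size_t j = 0; j < per_channel_size; j++) {
      dequant[per_channel_size * i + j] =
          static_cast<float>(quant_data[per_channel_size * i + j] - zero_point) * scale;
    }
  }
  return dequant;
}
```

The per-tensor case is just channels == 1, which is why this sketch can fold both branches of the helpers into a single loop.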
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.h b/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.h
index 15151ee72e..9707f19d0a 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.h
+++ b/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.h
@@ -23,6 +23,7 @@
 #include "nnacl/matmul_parameter.h"
 
 using mindspore::lite::Context;
+static constexpr int kPerTensor = 1;
 
 namespace mindspore::kernel {
 class FullconnectionBaseCPUKernel : public LiteKernel {
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/matmul_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/matmul_base.cc
index 01b719e1b1..7ab281959e 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/matmul_base.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/base/matmul_base.cc
@@ -26,12 +26,65 @@ using mindspore::lite::RET_OK;
 using mindspore::schema::PrimitiveType_MatMul;
 
 namespace mindspore::kernel {
+int RestoreMatmulWeight(lite::tensor::Tensor *input_tensor) {
+  MS_ASSERT(input_tensor != nullptr);
+  if (input_tensor->data_type() != kNumberTypeUInt8) {
+    MS_LOG(ERROR) << "mat mul input type error " << input_tensor->data_type();
+    return RET_ERROR;
+  }
+  if (input_tensor->GetQuantParams().empty()) {
+    MS_LOG(ERROR) << "no quant param";
+    return RET_ERROR;
+  }
+  const auto *quant_data = static_cast<const uint8_t *>(input_tensor->Data());
+  auto *dequant_data = static_cast<float *>(malloc(input_tensor->DataSize() * sizeof(float)));
+  if (dequant_data == nullptr) {
+    MS_LOG(ERROR) << "malloc failed";
+    return RET_ERROR;
+  }
+
+  if (input_tensor->GetQuantParams().size() != kPerTensor) {
+    size_t channels = static_cast<size_t>(input_tensor->Batch());
+    if (input_tensor->GetQuantParams().size() != channels) {
+      MS_LOG(ERROR) << "Quant param not equal channel num " << input_tensor->GetQuantParams().size() << " " << channels;
+      return RET_ERROR;
+    }
+    size_t per_channel_size = input_tensor->DataSize() / channels;
+    auto quant_param = input_tensor->GetQuantParams();
+    for (size_t i = 0; i < channels; i++) {
+      auto param = quant_param.at(i);
+      auto scale = param.scale;
+      auto zero_point = param.zeroPoint;
+      for (size_t j = 0; j < per_channel_size; j++) {
+        dequant_data[per_channel_size * i + j] = static_cast<float>(
+          (quant_data[per_channel_size * i + j] - zero_point) * scale);
+      }
+    }
+  } else {
+    auto quant_param = input_tensor->GetQuantParams();
+    auto param = quant_param.front();
+    auto scale = param.scale;
+    auto zero_point = param.zeroPoint;
+    for (size_t j = 0; j < input_tensor->DataSize(); j++) {
+      dequant_data[j] = static_cast<float>((quant_data[j] - zero_point) * scale);
+    }
+  }
+  input_tensor->SetData(dequant_data);
+  return RET_OK;
+}
 kernel::LiteKernel *CpuMatmulKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                            const std::vector<lite::tensor::Tensor *> &outputs,
                                            OpParameter *opParameter, const lite::Context *ctx,
                                            const kernel::KernelKey &desc,
                                            const mindspore::lite::PrimitiveC *primitive) {
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_Concat);
+
+  auto *weight_tensor = inputs.at(kWeightIndex);
+  auto *restore_data = weight_tensor->Data();
+  if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
+    RestoreMatmulWeight(inputs.at(kWeightIndex));
+  }
+
   auto input_tensor = inputs.at(kInputIndex);
   auto data_type = input_tensor->data_type();
   kernel::LiteKernel *kernel = nullptr;
@@ -51,6 +104,12 @@ kernel::LiteKernel *CpuMatmulKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                   << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
     return nullptr;
   }
+
+  if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
+    weight_tensor->FreeData();
+    weight_tensor->SetData(restore_data);
+  }
+
   return kernel;
 }
diff --git a/mindspore/lite/src/runtime/kernel/arm/base/matmul_base.h b/mindspore/lite/src/runtime/kernel/arm/base/matmul_base.h
index 1a51529a99..dd33c90131 100644
--- a/mindspore/lite/src/runtime/kernel/arm/base/matmul_base.h
+++ b/mindspore/lite/src/runtime/kernel/arm/base/matmul_base.h
@@ -23,6 +23,7 @@
 #include "nnacl/matmul_parameter.h"
 
 using mindspore::lite::Context;
+static constexpr int kPerTensor = 1;
 
 namespace mindspore::kernel {
 class MatmulBaseCPUKernel : public LiteKernel {
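The FullConnection and MatMul creators (and the Scale creator below) wrap kernel construction in the same pointer-swap lifecycle: stash the quantized data pointer, point the weight tensor at the malloc'ed float buffer so kernel initialization packs float weights, then free that temporary buffer and hand the original pointer back. A sketch of the pattern in isolation, with hypothetical stand-in types rather than the real Tensor and kernel classes:

```cpp
#include <cstdlib>

// Hypothetical stand-ins for lite::tensor::Tensor and the kernel type; only
// the pointer-swap lifecycle is the point here.
struct Tensor {
  void *data = nullptr;
  void *Data() { return data; }
  void SetData(void *d) { data = d; }
  void FreeData() { std::free(data); data = nullptr; }
};

struct Kernel {
  explicit Kernel(Tensor * /*weight*/) {}  // would pack float weights in Init()
};

Kernel *CreateKernel(Tensor *weight, void *dequant_data, bool weight_quant) {
  void *restore_data = weight->Data();  // original uint8 weights
  if (weight_quant) {
    weight->SetData(dequant_data);      // construction sees float data
  }
  auto *kernel = new Kernel(weight);
  if (weight_quant) {
    weight->FreeData();                 // frees the temporary float buffer
    weight->SetData(restore_data);      // tensor references uint8 data again
  }
  return kernel;
}
```

Swapping the original pointer back means only one float copy of the weights is ever live, and the tensor's uint8 data stays consistent with its quant metadata for any later use.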
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self.cc
index 75d568b609..a524f9dd87 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self.cc
@@ -69,8 +69,58 @@ int ArithmeticSelfCPUKernel::DoArithmeticSelf(int task_id) {
   }
   return RET_OK;
 }
+int RestoreMulWeight(lite::tensor::Tensor *input_tensor) {
+  MS_ASSERT(input_tensor != nullptr);
+  if (input_tensor->data_type() != kNumberTypeUInt8) {
+    MS_LOG(ERROR) << "mul weight input type error " << input_tensor->data_type();
+    return RET_ERROR;
+  }
+  if (input_tensor->GetQuantParams().empty()) {
+    MS_LOG(ERROR) << "no quant param";
+    return RET_ERROR;
+  }
+  const auto *quant_data = static_cast<const uint8_t *>(input_tensor->Data());
+  auto *dequant_data = static_cast<float *>(malloc(input_tensor->DataSize() * sizeof(float)));
+  if (dequant_data == nullptr) {
+    MS_LOG(ERROR) << "malloc failed";
+    return RET_ERROR;
+  }
+  if (input_tensor->GetQuantParams().size() != kPerTensor) {
+    size_t channels = static_cast<size_t>(input_tensor->Batch());
+    if (input_tensor->GetQuantParams().size() != channels) {
+      MS_LOG(ERROR) << "Quant param not equal channel num " << input_tensor->GetQuantParams().size() << " " << channels;
+      return RET_ERROR;
+    }
+    size_t per_channel_size = input_tensor->DataSize() / channels;
+    auto quant_param = input_tensor->GetQuantParams();
+    for (size_t i = 0; i < channels; i++) {
+      auto param = quant_param.at(i);
+      auto scale = param.scale;
+      auto zero_point = param.zeroPoint;
+      for (size_t j = 0; j < per_channel_size; j++) {
+        dequant_data[per_channel_size * i + j] = static_cast<float>(
+          (quant_data[per_channel_size * i + j] - zero_point) * scale);
+      }
+    }
+  } else {
+    auto quant_param = input_tensor->GetQuantParams();
+    auto param = quant_param.front();
+    auto scale = param.scale;
+    auto zero_point = param.zeroPoint;
+    for (size_t j = 0; j < input_tensor->DataSize(); j++) {
+      dequant_data[j] = static_cast<float>((quant_data[j] - zero_point) * scale);
+    }
+  }
+
+  input_tensor->SetData(dequant_data);
+  return RET_OK;
+}
 int ArithmeticSelfCPUKernel::Run() {
+  void *restore_data = nullptr;
+  if (primitive_->GetQuantType() == schema::QuantType_WeightQuant) {
+    restore_data = in_tensors_[1]->Data();
+    RestoreMulWeight(in_tensors_[1]);
+  }
   auto ret = Prepare();
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Prepare fail!ret: " << ret;
@@ -85,6 +135,10 @@ int ArithmeticSelfCPUKernel::Run() {
     MS_LOG(ERROR) << "ArithmeticSelfRun error error_code[" << ret << "]";
     return ret;
   }
+  if (primitive_->GetQuantType() == schema::QuantType_WeightQuant) {
+    in_tensors_[1]->FreeData();
+    in_tensors_[1]->SetData(restore_data);
+  }
   return RET_OK;
 }
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self.h b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self.h
index 18fbd93d6d..f63034337f 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self.h
@@ -37,6 +37,7 @@ using mindspore::schema::PrimitiveType_Rsqrt;
 using mindspore::schema::PrimitiveType_Sin;
 using mindspore::schema::PrimitiveType_Sqrt;
 using mindspore::schema::PrimitiveType_Square;
+static constexpr int kPerTensor = 1;
 
 namespace mindspore::kernel {
 class ArithmeticSelfCPUKernel : public LiteKernel {
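Unlike the creator-side pattern above, ArithmeticSelfCPUKernel performs the swap around Run(): the Mul weight is dequantized at the top of every invocation and the original uint8 buffer is swapped back after the parallel run finishes, so the cost is paid per inference rather than once at kernel creation. As a worked example of the restore arithmetic: with scale = 0.25 and zeroPoint = 128, a stored byte of 140 comes back as (140 - 128) * 0.25 = 3.0. In the per-channel branch the channel count is taken from Batch(), the outermost dimension, so a [16, 64] weight carries 16 quant params and each one restores per_channel_size = 1024 / 16 = 64 values.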
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/scale.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/scale.cc
index aee1ff01ee..ccdbb6080a 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/scale.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/scale.cc
@@ -169,13 +169,63 @@ int ScaleCPUKernel::Run() {
   }
   return RET_OK;
 }
+int RestoreScaleWeight(lite::tensor::Tensor *input_tensor) {
+  MS_ASSERT(input_tensor != nullptr);
+  if (input_tensor->data_type() != kNumberTypeUInt8) {
+    MS_LOG(ERROR) << "scale weight input type error " << input_tensor->data_type();
+    return RET_ERROR;
+  }
+  if (input_tensor->GetQuantParams().empty()) {
+    MS_LOG(ERROR) << "no quant param";
+    return RET_ERROR;
+  }
+  const auto *quant_data = static_cast<const uint8_t *>(input_tensor->Data());
+  auto *dequant_data = static_cast<float *>(malloc(input_tensor->DataSize() * sizeof(float)));
+  if (dequant_data == nullptr) {
+    MS_LOG(ERROR) << "malloc failed";
+    return RET_ERROR;
+  }
+  if (input_tensor->GetQuantParams().size() != kPerTensor) {
+    size_t channels = static_cast<size_t>(input_tensor->Batch());
+    if (input_tensor->GetQuantParams().size() != channels) {
+      MS_LOG(ERROR) << "Quant param not equal channel num " << input_tensor->GetQuantParams().size() << " " << channels;
+      return RET_ERROR;
+    }
+    size_t per_channel_size = input_tensor->DataSize() / channels;
+    auto quant_param = input_tensor->GetQuantParams();
+    for (size_t i = 0; i < channels; i++) {
+      auto param = quant_param.at(i);
+      auto scale = param.scale;
+      auto zero_point = param.zeroPoint;
+      for (size_t j = 0; j < per_channel_size; j++) {
+        dequant_data[per_channel_size * i + j] = static_cast<float>(
+          (quant_data[per_channel_size * i + j] - zero_point) * scale);
+      }
+    }
+  } else {
+    auto quant_param = input_tensor->GetQuantParams();
+    auto param = quant_param.front();
+    auto scale = param.scale;
+    auto zero_point = param.zeroPoint;
+    for (size_t j = 0; j < input_tensor->DataSize(); j++) {
+      dequant_data[j] = static_cast<float>((quant_data[j] - zero_point) * scale);
+    }
+  }
+  input_tensor->SetData(dequant_data);
+  return RET_OK;
+}
 kernel::LiteKernel *CpuScaleFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                               const std::vector<lite::tensor::Tensor *> &outputs,
                                               OpParameter *opParameter, const lite::Context *ctx,
                                               const kernel::KernelKey &desc,
                                               const mindspore::lite::PrimitiveC *primitive) {
   MS_ASSERT(desc.type == schema::PrimitiveType_Scale);
+  auto *weight_tensor = inputs.at(kWeightIndex);
+  auto *restore_data = weight_tensor->Data();
+  if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
+    RestoreScaleWeight(inputs.at(kWeightIndex));
+  }
   if (opParameter == nullptr) {
     MS_LOG(ERROR) << "opParameter is nullptr";
     return nullptr;
   }
@@ -193,6 +243,10 @@ kernel::LiteKernel *CpuScaleFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
+  if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
+    weight_tensor->FreeData();
+    weight_tensor->SetData(restore_data);
+  }
   return kernel;
 }
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/scale.h b/mindspore/lite/src/runtime/kernel/arm/fp32/scale.h
index 8071f16bdd..42b1daba5d 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/scale.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/scale.h
@@ -21,6 +21,7 @@
 #include "src/lite_kernel.h"
 #include "nnacl/fp32/scale.h"
 
+static constexpr int kPerTensor = 1;
 namespace mindspore::kernel {
 class ScaleCPUKernel : public LiteKernel {
diff --git a/mindspore/lite/tools/converter/converter_flags.cc b/mindspore/lite/tools/converter/converter_flags.cc
index 02cd3f0ca0..c2ddb9bbc9 100644
--- a/mindspore/lite/tools/converter/converter_flags.cc
+++ b/mindspore/lite/tools/converter/converter_flags.cc
@@ -35,6 +35,7 @@ Flags::Flags() {
   AddFlag(&Flags::inputInferenceTypeIn, "inputInferenceType", "Input inference data type. FLOAT | INT8", "FLOAT");
   AddFlag(&Flags::stdDev, "stdDev", "Standard deviation value for aware-quantization", "128");
   AddFlag(&Flags::mean, "mean", "Mean value for aware-quantization", "-0.5");
+  AddFlag(&Flags::bitNum, "bitNum", "Weight quantization bitNum", "8");
   AddFlag(&Flags::quantSize, "quantSize", "Weight quantization size threshold", "0");
   AddFlag(&Flags::convWeightQuantChannelThreshold, "convWeightQuantChannelThreshold", "convWeightQuantChannelThreshold", "16");
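The new bitNum flag (default "8") exposes the weight-quantization bit width that WeightQuantizer passes through to QuantFilter; note that the QuantFilter call sites below also hardcode bounds of 255 and 0, which match the unsigned 8-bit range implied by the default.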
diff --git a/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc b/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc
index 4badecb0e8..ac06dbf9a4 100644
--- a/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc
+++ b/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc
@@ -49,13 +49,13 @@ STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) {
     return RET_ERROR;
   }
 
-  auto inputNode = cnode->input(2);
-  if (!inputNode->isa<Parameter>()) {
+  auto input_node = cnode->input(2);
+  if (!input_node->isa<Parameter>()) {
     return RET_ERROR;
   }
 
-  auto paramNode = inputNode->cast<ParameterPtr>();
-  if (!paramNode->has_default()) {
+  auto param_node = input_node->cast<ParameterPtr>();
+  if (!param_node->has_default()) {
     return RET_ERROR;
   }
@@ -65,14 +65,26 @@ STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) {
   auto op_type = (schema::PrimitiveType)primitive_c->Type();
   bool depthwise = op_type == schema::PrimitiveType_DepthwiseConv2D ? true : false;
 
-  ParamValueLitePtr param_value = std::static_pointer_cast<ParamValueLite>(paramNode->default_param());
+  ParamValueLitePtr param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param());
   auto status = QuantFilter(param_value, primitive_c, QuantType_WeightQuant, 255, 0, bitNum, true, depthwise);
   if (status != RET_OK) {
     MS_LOG(ERROR) << "QuantFilter failed : " << status;
     return status;
   }
+  // set dtype
   param_value->set_tensor_type(kNumberTypeUInt8);
+  auto abstractBase = param_node->abstract();
+  if (abstractBase == nullptr) {
+    MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name();
+    return RET_ERROR;
+  }
+  if (!utils::isa<abstract::AbstractTensorPtr>(abstractBase)) {
+    MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << param_node->name();
+    return RET_ERROR;
+  }
+  auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase);
+  abstractTensor->element()->set_type(TypeIdToType(kNumberTypeUInt8));
   primitive_c->SetQuantType(schema::QuantType_WeightQuant);
 }
@@ -86,14 +98,14 @@ STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) {
   }
 
   ParamValueLitePtr param_value = nullptr;
+  ParameterPtr param_node = nullptr;
   for (size_t i = 1; i < node->size(); i++) {
     auto inputNode = node->input(i);
     if (inputNode->isa<Parameter>() == true) {
-      auto paramNode = inputNode->cast<ParameterPtr>();
-      if ((paramNode != nullptr) && (paramNode->has_default() == true)) {
-        param_value = std::static_pointer_cast<ParamValueLite>(paramNode->default_param());
+      param_node = inputNode->cast<ParameterPtr>();
+      if ((param_node != nullptr) && (param_node->has_default() == true)) {
+        param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param());
         if ((param_value == nullptr) || (param_value->tensor_size() == 0)
-            || (param_value->tensor_shape().size() != 4)
             || (param_value->tensor_addr() == nullptr)
             || (param_value->tensor_type() != mindspore::kNumberTypeFloat32)) {
           param_value = nullptr;
@@ -115,12 +127,26 @@ STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) {
     return RET_ERROR;
   }
 
+  std::vector<schema::QuantParamT> quant_params;
+  primitive_c->AddInputQuantParam(quant_params);
   auto status = QuantFilter(param_value, primitive_c, QuantType_WeightQuant, 255, 0, bitNum, true, false);
   if (status != RET_OK) {
     MS_LOG(ERROR) << "QuantFilter failed : " << status;
     return status;
   }
   param_value->set_tensor_type(kNumberTypeUInt8);
+  // set dtype
+  auto abstractBase = param_node->abstract();
+  if (abstractBase == nullptr) {
+    MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name();
+    return RET_ERROR;
+  }
+  if (!utils::isa<abstract::AbstractTensorPtr>(abstractBase)) {
+    MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << param_node->name();
+    return RET_ERROR;
+  }
+  auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase);
+  abstractTensor->element()->set_type(TypeIdToType(kNumberTypeUInt8));
   primitive_c->SetQuantType(schema::QuantType_WeightQuant);
 }
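QuantFilter itself is outside this patch, but it is what produces the uint8 weights and scale/zeroPoint params that the runtime Restore*Weight helpers invert. For orientation, a sketch of the affine quantization one would expect for the quant_max = 255, quant_min = 0 arguments used at these call sites (an assumption about QuantFilter's behavior, not its actual code):

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>

struct QuantParam {
  float scale;
  int32_t zero_point;
};

// Affine-quantize `count` floats (count > 0) into [quant_min, quant_max];
// the inverse of the runtime dequantization. Assumed behavior, for
// orientation only.
QuantParam QuantizeSlice(const float *data, size_t count, uint8_t *out,
                         int32_t quant_min = 0, int32_t quant_max = 255) {
  // Widen the range to include 0 so zero-valued weights stay exact.
  float min_v = std::min(0.0f, *std::min_element(data, data + count));
  float max_v = std::max(0.0f, *std::max_element(data, data + count));
  float scale = (max_v - min_v) / static_cast<float>(quant_max - quant_min);
  if (scale == 0.0f) scale = 1.0f;  // degenerate all-zero slice
  const int32_t zero_point =
      quant_min - static_cast<int32_t>(std::round(min_v / scale));
  for (size_t i = 0; i < count; i++) {
    const int32_t q =
        zero_point + static_cast<int32_t>(std::round(data[i] / scale));
    out[i] = static_cast<uint8_t>(std::min(quant_max, std::max(quant_min, q)));
  }
  return {scale, zero_point};
}
```

Applied per output channel this yields the one-param-per-channel layout the runtime helpers check against Batch(); applied once per tensor it yields the kPerTensor case.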