From a9b8b39cac5caf5cea99620aebd7a24cb87eb4b7 Mon Sep 17 00:00:00 2001 From: wsc Date: Wed, 9 Sep 2020 17:02:40 +0800 Subject: [PATCH] Dequantize weight parameters for quantized tflite model with the quantType of None. --- mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc | 4 ++-- .../lite/src/runtime/kernel/arm/fp32/convolution_depthwise.cc | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc index 04cae95820..b9f3a5896a 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc @@ -236,7 +236,7 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector & auto *weight_tensor = inputs.at(kWeightIndex); auto *restore_data = weight_tensor->MutableData(); - if (primitive->GetQuantType() == schema::QuantType_WeightQuant) { + if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { ConvolutionBaseCPUKernel::RestoreFilter(inputs.at(kWeightIndex)); } @@ -265,7 +265,7 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector & return nullptr; } - if (primitive->GetQuantType() == schema::QuantType_WeightQuant) { + if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { weight_tensor->FreeData(); weight_tensor->SetData(restore_data); } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.cc index e9ac09a5b1..b24e05cf9b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise.cc @@ -133,7 +133,7 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector auto *weight_tensor = inputs.at(kWeightIndex); auto *restore_data = weight_tensor->MutableData(); - if (primitive->GetQuantType() == schema::QuantType_WeightQuant) { + if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { ConvolutionBaseCPUKernel::RestoreFilter(inputs.at(kWeightIndex)); } @@ -156,7 +156,7 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector return nullptr; } - if (primitive->GetQuantType() == schema::QuantType_WeightQuant) { + if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { weight_tensor->FreeData(); weight_tensor->SetData(restore_data); }