Dequantize weight parameters for quantized tflite model with the quantType of None.

pull/5957/head
wsc 5 years ago
parent 483b364d92
commit a9b8b39cac

@ -236,7 +236,7 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &
auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->MutableData();
if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
ConvolutionBaseCPUKernel::RestoreFilter(inputs.at(kWeightIndex));
}
@ -265,7 +265,7 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &
return nullptr;
}
if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
weight_tensor->FreeData();
weight_tensor->SetData(restore_data);
}

@ -133,7 +133,7 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *>
auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->MutableData();
if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
ConvolutionBaseCPUKernel::RestoreFilter(inputs.at(kWeightIndex));
}
@ -156,7 +156,7 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *>
return nullptr;
}
if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
weight_tensor->FreeData();
weight_tensor->SetData(restore_data);
}

Loading…
Cancel
Save