remove weight quant judge

pull/6893/head
kai00 4 years ago
parent 1b07551bd4
commit 33de4a3e65

@ -35,9 +35,7 @@ kernel::LiteKernel *CpuMatmulKernelCreator(const std::vector<lite::Tensor *> &in
auto *weight_tensor = inputs.at(kWeightIndex); auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->data_c(); auto *restore_data = weight_tensor->data_c();
auto is_const_quant_weight = auto is_const_quant_weight = (restore_data != nullptr) && (weight_tensor->data_type() == kNumberTypeInt8);
(restore_data != nullptr) &&
(weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant);
if (is_const_quant_weight) { if (is_const_quant_weight) {
auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor); auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) { if (dequant_weight == nullptr) {

@ -145,7 +145,7 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *>
auto *weight_tensor = inputs.at(kWeightIndex); auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->MutableData(); auto *restore_data = weight_tensor->MutableData();
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { if (weight_tensor->data_type() == kNumberTypeInt8) {
auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor); auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) { if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr."; MS_LOG(ERROR) << "dequant data is nullptr.";
@ -165,7 +165,7 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *>
} }
if (kernel == nullptr) { if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr."; MS_LOG(ERROR) << "kernel is nullptr.";
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { if (weight_tensor->data_type() == kNumberTypeInt8) {
weight_tensor->FreeData(); weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8); weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data); weight_tensor->SetData(restore_data);
@ -177,14 +177,14 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *>
delete kernel; delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { if (weight_tensor->data_type() == kNumberTypeInt8) {
weight_tensor->FreeData(); weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8); weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data); weight_tensor->SetData(restore_data);
} }
return nullptr; return nullptr;
} }
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { if (weight_tensor->data_type() == kNumberTypeInt8) {
weight_tensor->FreeData(); weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8); weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data); weight_tensor->SetData(restore_data);

@ -189,7 +189,7 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> &
auto *weight_tensor = inputs.at(kWeightIndex); auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->MutableData(); auto *restore_data = weight_tensor->MutableData();
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { if (weight_tensor->data_type() == kNumberTypeInt8) {
auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor); auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) { if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr."; MS_LOG(ERROR) << "dequant data is nullptr.";
@ -224,7 +224,7 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> &
} }
if (kernel == nullptr) { if (kernel == nullptr) {
MS_LOG(DEBUG) << "Create conv fp16 kernel failed."; MS_LOG(DEBUG) << "Create conv fp16 kernel failed.";
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { if (weight_tensor->data_type() == kNumberTypeInt8) {
weight_tensor->FreeData(); weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8); weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data); weight_tensor->SetData(restore_data);
@ -236,14 +236,14 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> &
delete kernel; delete kernel;
MS_LOG(INFO) << "Init fp16 kernel failed, name: " << opParameter->name_ MS_LOG(INFO) << "Init fp16 kernel failed, name: " << opParameter->name_
<< ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { if (weight_tensor->data_type() == kNumberTypeInt8) {
weight_tensor->FreeData(); weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8); weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data); weight_tensor->SetData(restore_data);
} }
return nullptr; return nullptr;
} }
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { if (weight_tensor->data_type() == kNumberTypeInt8) {
weight_tensor->FreeData(); weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8); weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data); weight_tensor->SetData(restore_data);

@ -204,7 +204,7 @@ kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector<lite::Tensor
auto *weight_tensor = inputs.at(kWeightIndex); auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->MutableData(); auto *restore_data = weight_tensor->MutableData();
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { if (weight_tensor->data_type() == kNumberTypeInt8) {
auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor); auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) { if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr."; MS_LOG(ERROR) << "dequant data is nullptr.";
@ -217,7 +217,7 @@ kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector<lite::Tensor
auto kernel = new (std::nothrow) DeconvolutionDepthwiseFp16CPUKernel(opParameter, inputs, outputs, ctx, primitive); auto kernel = new (std::nothrow) DeconvolutionDepthwiseFp16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) { if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr."; MS_LOG(ERROR) << "kernel is nullptr.";
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { if (weight_tensor->data_type() == kNumberTypeInt8) {
weight_tensor->FreeData(); weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8); weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data); weight_tensor->SetData(restore_data);
@ -229,14 +229,14 @@ kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector<lite::Tensor
delete kernel; delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { if (weight_tensor->data_type() == kNumberTypeInt8) {
weight_tensor->FreeData(); weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8); weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data); weight_tensor->SetData(restore_data);
} }
return nullptr; return nullptr;
} }
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { if (weight_tensor->data_type() == kNumberTypeInt8) {
weight_tensor->FreeData(); weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8); weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data); weight_tensor->SetData(restore_data);

@ -211,7 +211,7 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *>
auto *weight_tensor = inputs.at(kWeightIndex); auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->MutableData(); auto *restore_data = weight_tensor->MutableData();
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { if (weight_tensor->data_type() == kNumberTypeInt8) {
auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor); auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) { if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr."; MS_LOG(ERROR) << "dequant data is nullptr.";
@ -224,7 +224,7 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *>
auto kernel = new (std::nothrow) DeConvolutionFp16CPUKernel(opParameter, inputs, outputs, ctx, primitive); auto kernel = new (std::nothrow) DeConvolutionFp16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) { if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr."; MS_LOG(ERROR) << "kernel is nullptr.";
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { if (weight_tensor->data_type() == kNumberTypeInt8) {
weight_tensor->FreeData(); weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8); weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data); weight_tensor->SetData(restore_data);
@ -236,14 +236,14 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *>
delete kernel; delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { if (weight_tensor->data_type() == kNumberTypeInt8) {
weight_tensor->FreeData(); weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8); weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data); weight_tensor->SetData(restore_data);
} }
return nullptr; return nullptr;
} }
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { if (weight_tensor->data_type() == kNumberTypeInt8) {
weight_tensor->FreeData(); weight_tensor->FreeData();
weight_tensor->set_data_type(kNumberTypeInt8); weight_tensor->set_data_type(kNumberTypeInt8);
weight_tensor->SetData(restore_data); weight_tensor->SetData(restore_data);

@ -186,7 +186,7 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &
auto *weight_tensor = inputs.at(kWeightIndex); auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->MutableData(); auto *restore_data = weight_tensor->MutableData();
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { if (weight_tensor->data_type() == kNumberTypeInt8) {
auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor); auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) { if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr."; MS_LOG(ERROR) << "dequant data is nullptr.";
@ -206,7 +206,7 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &
} }
if (kernel == nullptr) { if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr."; MS_LOG(ERROR) << "kernel is nullptr.";
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { if (weight_tensor->data_type() == kNumberTypeInt8) {
weight_tensor->FreeData(); weight_tensor->FreeData();
weight_tensor->SetData(restore_data); weight_tensor->SetData(restore_data);
} }
@ -217,14 +217,14 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &
delete kernel; delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << op_parameter->name_ << ", type: " MS_LOG(ERROR) << "Init kernel failed, name: " << op_parameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(op_parameter->type_)); << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(op_parameter->type_));
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { if (weight_tensor->data_type() == kNumberTypeInt8) {
weight_tensor->FreeData(); weight_tensor->FreeData();
weight_tensor->SetData(restore_data); weight_tensor->SetData(restore_data);
} }
return nullptr; return nullptr;
} }
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { if (weight_tensor->data_type() == kNumberTypeInt8) {
weight_tensor->FreeData(); weight_tensor->FreeData();
weight_tensor->SetData(restore_data); weight_tensor->SetData(restore_data);
} }

@ -133,7 +133,7 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *>
auto *weight_tensor = inputs.at(kWeightIndex); auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->MutableData(); auto *restore_data = weight_tensor->MutableData();
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { if (weight_tensor->data_type() == kNumberTypeInt8) {
auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor); auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) { if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr."; MS_LOG(ERROR) << "dequant data is nullptr.";
@ -151,7 +151,7 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *>
} }
if (kernel == nullptr) { if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr."; MS_LOG(ERROR) << "kernel is nullptr.";
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { if (weight_tensor->data_type() == kNumberTypeInt8) {
weight_tensor->FreeData(); weight_tensor->FreeData();
weight_tensor->SetData(restore_data); weight_tensor->SetData(restore_data);
} }
@ -162,14 +162,14 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *>
delete kernel; delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { if (weight_tensor->data_type() == kNumberTypeInt8) {
weight_tensor->FreeData(); weight_tensor->FreeData();
weight_tensor->SetData(restore_data); weight_tensor->SetData(restore_data);
} }
return nullptr; return nullptr;
} }
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { if (weight_tensor->data_type() == kNumberTypeInt8) {
weight_tensor->FreeData(); weight_tensor->FreeData();
weight_tensor->SetData(restore_data); weight_tensor->SetData(restore_data);
} }

@ -237,7 +237,7 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *>
MS_ASSERT(desc.type == schema::PrimitiveType_DeConv2D); MS_ASSERT(desc.type == schema::PrimitiveType_DeConv2D);
auto *weight_tensor = inputs.at(kWeightIndex); auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->MutableData(); auto *restore_data = weight_tensor->MutableData();
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { if (weight_tensor->data_type() == kNumberTypeInt8) {
auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor); auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) { if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr."; MS_LOG(ERROR) << "dequant data is nullptr.";
@ -248,7 +248,7 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *>
auto kernel = new (std::nothrow) kernel::DeConvolutionCPUKernel(opParameter, inputs, outputs, ctx, primitive); auto kernel = new (std::nothrow) kernel::DeConvolutionCPUKernel(opParameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) { if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr."; MS_LOG(ERROR) << "kernel is nullptr.";
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { if (weight_tensor->data_type() == kNumberTypeInt8) {
weight_tensor->FreeData(); weight_tensor->FreeData();
weight_tensor->SetData(restore_data); weight_tensor->SetData(restore_data);
} }
@ -259,14 +259,14 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *>
delete kernel; delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { if (weight_tensor->data_type() == kNumberTypeInt8) {
weight_tensor->FreeData(); weight_tensor->FreeData();
weight_tensor->SetData(restore_data); weight_tensor->SetData(restore_data);
} }
return nullptr; return nullptr;
} }
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { if (weight_tensor->data_type() == kNumberTypeInt8) {
weight_tensor->FreeData(); weight_tensor->FreeData();
weight_tensor->SetData(restore_data); weight_tensor->SetData(restore_data);
} }

@ -201,7 +201,7 @@ kernel::LiteKernel *CpuDeconvDwFp32KernelCreator(const std::vector<lite::Tensor
MS_ASSERT(desc.type == schema::PrimitiveType_DeDepthwiseConv2D); MS_ASSERT(desc.type == schema::PrimitiveType_DeDepthwiseConv2D);
auto *weight_tensor = inputs.at(kWeightIndex); auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->MutableData(); auto *restore_data = weight_tensor->MutableData();
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { if (weight_tensor->data_type() == kNumberTypeInt8) {
auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor); auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) { if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr."; MS_LOG(ERROR) << "dequant data is nullptr.";
@ -213,7 +213,7 @@ kernel::LiteKernel *CpuDeconvDwFp32KernelCreator(const std::vector<lite::Tensor
new (std::nothrow) kernel::DeconvolutionDepthwiseCPUKernel(opParameter, inputs, outputs, ctx, primitive); new (std::nothrow) kernel::DeconvolutionDepthwiseCPUKernel(opParameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) { if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr."; MS_LOG(ERROR) << "kernel is nullptr.";
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { if (weight_tensor->data_type() == kNumberTypeInt8) {
weight_tensor->FreeData(); weight_tensor->FreeData();
weight_tensor->SetData(restore_data); weight_tensor->SetData(restore_data);
} }
@ -224,13 +224,13 @@ kernel::LiteKernel *CpuDeconvDwFp32KernelCreator(const std::vector<lite::Tensor
delete kernel; delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { if (weight_tensor->data_type() == kNumberTypeInt8) {
weight_tensor->FreeData(); weight_tensor->FreeData();
weight_tensor->SetData(restore_data); weight_tensor->SetData(restore_data);
} }
return nullptr; return nullptr;
} }
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) { if (weight_tensor->data_type() == kNumberTypeInt8) {
weight_tensor->FreeData(); weight_tensor->FreeData();
weight_tensor->SetData(restore_data); weight_tensor->SetData(restore_data);
} }

Loading…
Cancel
Save