!8212 [MSLITE] add set data type in dequant weight

Merge pull request !8212 from ghzl/fix-dequant-bug-set-type
pull/8212/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 57a8911eb0
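
Every hunk below repeats the same pattern: the kernel creator saves the weight tensor's original data pointer (and, with this change, its original data type) before dequantizing, then restores both on every exit path once the temporary dequantized buffer is freed, so the tensor is left consistent for later use. The following is a minimal, self-contained C++ sketch of that save/restore flow; Tensor, DataType, DequantWeight and CreateKernelWithDequant are hypothetical stand-ins for illustration only, not the real mindspore::lite API.

#include <cstddef>
#include <cstdlib>

// Hypothetical stand-in for the tensor interface used in the hunks below
// (data_c, data_type, set_data, set_data_type, FreeData). It is NOT the real
// mindspore::lite::Tensor; only the call sequence mirrors the change.
enum class DataType { kInt8, kFloat32 };

struct Tensor {
  void *data = nullptr;
  DataType type = DataType::kInt8;
  void *data_c() const { return data; }
  DataType data_type() const { return type; }
  void set_data(void *d) { data = d; }
  void set_data_type(DataType t) { type = t; }
  void FreeData() {
    std::free(data);
    data = nullptr;
  }
};

// Fake dequantizer: expands int8 weights into a newly malloc'ed float buffer.
static float *DequantWeight(Tensor *w, std::size_t count) {
  auto *src = static_cast<const signed char *>(w->data_c());
  auto *dst = static_cast<float *>(std::malloc(count * sizeof(float)));
  if (src == nullptr || dst == nullptr) {
    std::free(dst);
    return nullptr;
  }
  for (std::size_t i = 0; i < count; ++i) {
    dst[i] = static_cast<float>(src[i]) * 0.1f;  // placeholder scale
  }
  return dst;
}

// Save the data pointer AND the data type, dequantize, then restore both on
// every exit path, which is what each hunk below adds per kernel creator.
static bool CreateKernelWithDequant(Tensor *weight_tensor, std::size_t count) {
  auto *restore_data = weight_tensor->data_c();    // original quantized buffer
  auto restore_type = weight_tensor->data_type();  // original type (the new part of the fix)
  bool dequant_flag = restore_data != nullptr && restore_type == DataType::kInt8;
  if (dequant_flag) {
    auto *dequant_weight = DequantWeight(weight_tensor, count);
    if (dequant_weight == nullptr) {
      return false;
    }
    weight_tensor->set_data(dequant_weight);           // tensor temporarily holds float data
    weight_tensor->set_data_type(DataType::kFloat32);  // sketch's way of modeling the type change
  }
  bool kernel_ok = true;  // stand-in for creating and initializing the real kernel
  if (dequant_flag) {
    weight_tensor->FreeData();                   // drop the temporary float buffer
    weight_tensor->set_data(restore_data);       // put the original buffer back
    weight_tensor->set_data_type(restore_type);  // and the original type
  }
  return kernel_ok;
}

int main() {
  signed char raw[4] = {1, 2, 3, 4};
  Tensor weight;
  weight.set_data(raw);  // stack buffer; FreeData only ever frees the malloc'ed dequant buffer here
  weight.set_data_type(DataType::kInt8);
  return CreateKernelWithDequant(&weight, 4) ? 0 : 1;
}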

@@ -42,6 +42,7 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vector<lite::T
   auto *weight_tensor = inputs.at(kWeightIndex);
   // data of second tensor of fc may be nullptr
   auto *restore_data = weight_tensor->data_c();
+  auto restore_type = weight_tensor->data_type();
   bool dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited &&
                       restore_data != nullptr;
   if (dequant_flag) {
@@ -58,6 +59,7 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vector<lite::T
     if (dequant_flag) {
       weight_tensor->FreeData();
       weight_tensor->set_data(restore_data);
+      weight_tensor->set_data_type(restore_type);
     }
     free(opParameter);
     return nullptr;
@@ -70,12 +72,14 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vector<lite::T
     if (dequant_flag) {
       weight_tensor->FreeData();
       weight_tensor->set_data(restore_data);
+      weight_tensor->set_data_type(restore_type);
     }
     return nullptr;
   }
   if (dequant_flag) {
     weight_tensor->FreeData();
     weight_tensor->set_data(restore_data);
+    weight_tensor->set_data_type(restore_type);
   }
   return kernel;
 }

@@ -139,7 +139,8 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *>
   MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D);
   auto *weight_tensor = inputs.at(kWeightIndex);
-  auto *restore_data = weight_tensor->MutableData();
+  auto *restore_data = weight_tensor->data_c();
+  auto restore_type = weight_tensor->data_type();
   bool dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited &&
                       restore_data != nullptr;
   if (dequant_flag) {
@@ -166,6 +167,7 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *>
     if (dequant_flag) {
       weight_tensor->FreeData();
       weight_tensor->set_data(restore_data);
+      weight_tensor->set_data_type(restore_type);
     }
     free(opParameter);
     return nullptr;
@@ -178,12 +180,14 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *>
     if (dequant_flag) {
       weight_tensor->FreeData();
       weight_tensor->set_data(restore_data);
+      weight_tensor->set_data_type(restore_type);
     }
     return nullptr;
   }
   if (dequant_flag) {
     weight_tensor->FreeData();
     weight_tensor->set_data(restore_data);
+    weight_tensor->set_data_type(restore_type);
   }
   return kernel;
 }

@@ -181,7 +181,8 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> &
   MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D);
   auto *weight_tensor = inputs.at(kWeightIndex);
-  auto *restore_data = weight_tensor->MutableData();
+  auto *restore_data = weight_tensor->data_c();
+  auto restore_type = weight_tensor->data_type();
   bool dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited &&
                       restore_data != nullptr;
   if (dequant_flag) {
@@ -225,6 +226,7 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> &
     if (dequant_flag) {
       weight_tensor->FreeData();
       weight_tensor->set_data(restore_data);
+      weight_tensor->set_data_type(restore_type);
     }
     free(opParameter);
     return nullptr;
@@ -237,12 +239,14 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> &
     if (dequant_flag) {
       weight_tensor->FreeData();
       weight_tensor->set_data(restore_data);
+      weight_tensor->set_data_type(restore_type);
     }
     return nullptr;
   }
   if (dequant_flag) {
     weight_tensor->FreeData();
     weight_tensor->set_data(restore_data);
+    weight_tensor->set_data_type(restore_type);
   }
   return kernel;
 }

@@ -203,7 +203,8 @@ kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector<lite::Tensor
   MS_ASSERT(desc.type == schema::PrimitiveType_DeDepthwiseConv2D);
   auto *weight_tensor = inputs.at(kWeightIndex);
-  auto *restore_data = weight_tensor->MutableData();
+  auto *restore_data = weight_tensor->data_c();
+  auto restore_type = weight_tensor->data_type();
   auto dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited &&
                       restore_data != nullptr;
   if (dequant_flag) {
@@ -223,6 +224,7 @@ kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector<lite::Tensor
     if (dequant_flag) {
       weight_tensor->FreeData();
       weight_tensor->set_data(restore_data);
+      weight_tensor->set_data_type(restore_type);
     }
     free(opParameter);
     return nullptr;
@@ -235,12 +237,14 @@ kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector<lite::Tensor
     if (dequant_flag) {
       weight_tensor->FreeData();
       weight_tensor->set_data(restore_data);
+      weight_tensor->set_data_type(restore_type);
     }
     return nullptr;
   }
   if (dequant_flag) {
     weight_tensor->FreeData();
     weight_tensor->set_data(restore_data);
+    weight_tensor->set_data_type(restore_type);
   }
   return kernel;
 }

@@ -215,7 +215,8 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *>
   MS_ASSERT(desc.type == schema::PrimitiveType_DeConv2D);
   auto *weight_tensor = inputs.at(kWeightIndex);
-  auto *restore_data = weight_tensor->MutableData();
+  auto *restore_data = weight_tensor->data_c();
+  auto restore_type = weight_tensor->data_type();
   auto dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited &&
                       restore_data != nullptr;
   if (dequant_flag) {
@@ -243,6 +244,7 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *>
     if (dequant_flag) {
       weight_tensor->FreeData();
       weight_tensor->set_data(restore_data);
+      weight_tensor->set_data_type(restore_type);
     }
     free(opParameter);
     return nullptr;
@@ -255,12 +257,14 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *>
     if (dequant_flag) {
       weight_tensor->FreeData();
       weight_tensor->set_data(restore_data);
+      weight_tensor->set_data_type(restore_type);
     }
     return nullptr;
   }
   if (dequant_flag) {
     weight_tensor->FreeData();
     weight_tensor->set_data(restore_data);
+    weight_tensor->set_data_type(restore_type);
   }
   return kernel;
 }

@@ -237,6 +237,7 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vector<lite::T
   auto *weight_tensor = inputs.at(kWeightIndex);
   // data of second tensor of fc may be nullptr
   auto *restore_data = weight_tensor->data_c();
+  auto restore_type = weight_tensor->data_type();
   bool dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited &&
                       restore_data != nullptr;
   if (dequant_flag) {
@@ -255,6 +256,7 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vector<lite::T
     if (dequant_flag) {
       weight_tensor->FreeData();
       weight_tensor->set_data(restore_data);
+      weight_tensor->set_data_type(restore_type);
     }
     free(opParameter);
     return nullptr;
@@ -267,12 +269,14 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vector<lite::T
     if (dequant_flag) {
       weight_tensor->FreeData();
       weight_tensor->set_data(restore_data);
+      weight_tensor->set_data_type(restore_type);
     }
     return nullptr;
   }
   if (dequant_flag) {
     weight_tensor->FreeData();
     weight_tensor->set_data(restore_data);
+    weight_tensor->set_data_type(restore_type);
   }
   return kernel;
 }

@@ -251,6 +251,7 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector<lite::Tensor *>
                                                const mindspore::lite::PrimitiveC *primitive) {
   auto *weight_tensor = inputs.at(kWeightIndex);
   auto *restore_data = weight_tensor->data_c();
+  auto restore_type = weight_tensor->data_type();
   bool dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited &&
                       restore_data != nullptr;
   if (dequant_flag) {
@@ -268,6 +269,7 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector<lite::Tensor *>
     if (dequant_flag) {
       weight_tensor->FreeData();
       weight_tensor->set_data(restore_data);
+      weight_tensor->set_data_type(restore_type);
     }
     free(opParameter);
     return nullptr;
@@ -280,12 +282,14 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector<lite::Tensor *>
     if (dequant_flag) {
       weight_tensor->FreeData();
       weight_tensor->set_data(restore_data);
+      weight_tensor->set_data_type(restore_type);
     }
     return nullptr;
   }
   if (dequant_flag) {
     weight_tensor->FreeData();
     weight_tensor->set_data(restore_data);
+    weight_tensor->set_data_type(restore_type);
   }
   return kernel;
 }

@@ -303,7 +303,8 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &
   }
   auto *weight_tensor = inputs.at(kWeightIndex);
-  auto *restore_data = weight_tensor->MutableData();
+  auto *restore_data = weight_tensor->data_c();
+  auto restore_type = weight_tensor->data_type();
   bool dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited &&
                       restore_data != nullptr;
   if (dequant_flag) {
@@ -328,6 +329,7 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &
     if (dequant_flag) {
       weight_tensor->FreeData();
       weight_tensor->set_data(restore_data);
+      weight_tensor->set_data_type(restore_type);
     }
     free(op_parameter);
     return nullptr;
@@ -340,6 +342,7 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &
     if (dequant_flag) {
       weight_tensor->FreeData();
       weight_tensor->set_data(restore_data);
+      weight_tensor->set_data_type(restore_type);
     }
     return nullptr;
   }
@@ -347,6 +350,7 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &
   if (dequant_flag) {
     weight_tensor->FreeData();
     weight_tensor->set_data(restore_data);
+    weight_tensor->set_data_type(restore_type);
   }
   return kernel;

@@ -123,8 +123,8 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *>
   MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D);
   auto *weight_tensor = inputs.at(kWeightIndex);
-  auto *restore_data = weight_tensor->MutableData();
+  auto *restore_data = weight_tensor->data_c();
+  auto restore_type = weight_tensor->data_type();
   if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
     auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
     if (dequant_weight == nullptr) {
@@ -147,6 +147,7 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *>
     if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
       weight_tensor->FreeData();
       weight_tensor->set_data(restore_data);
+      weight_tensor->set_data_type(restore_type);
     }
     free(opParameter);
     return nullptr;
@@ -159,6 +160,7 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *>
     if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
       weight_tensor->FreeData();
       weight_tensor->set_data(restore_data);
+      weight_tensor->set_data_type(restore_type);
     }
     return nullptr;
   }
@@ -166,6 +168,7 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *>
   if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
     weight_tensor->FreeData();
     weight_tensor->set_data(restore_data);
+    weight_tensor->set_data_type(restore_type);
   }
   return kernel;

@@ -233,7 +233,8 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *>
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_DeConv2D);
   auto *weight_tensor = inputs.at(kWeightIndex);
-  auto *restore_data = weight_tensor->MutableData();
+  auto *restore_data = weight_tensor->data_c();
+  auto restore_type = weight_tensor->data_type();
   bool dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited &&
                       restore_data != nullptr;
   if (dequant_flag) {
@@ -260,6 +261,7 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *>
     if (dequant_flag) {
       weight_tensor->FreeData();
       weight_tensor->set_data(restore_data);
+      weight_tensor->set_data_type(restore_type);
     }
     free(opParameter);
     return nullptr;
@@ -272,6 +274,7 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *>
     if (dequant_flag) {
       weight_tensor->FreeData();
       weight_tensor->set_data(restore_data);
+      weight_tensor->set_data_type(restore_type);
     }
     return nullptr;
   }
@@ -279,6 +282,7 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *>
   if (dequant_flag) {
     weight_tensor->FreeData();
     weight_tensor->set_data(restore_data);
+    weight_tensor->set_data_type(restore_type);
   }
   return kernel;

@@ -195,7 +195,8 @@ kernel::LiteKernel *CpuDeconvDwFp32KernelCreator(const std::vector<lite::Tensor
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_DeDepthwiseConv2D);
   auto *weight_tensor = inputs.at(kWeightIndex);
-  auto *restore_data = weight_tensor->MutableData();
+  auto *restore_data = weight_tensor->data_c();
+  auto restore_type = weight_tensor->data_type();
   bool dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited &&
                       restore_data != nullptr;
   if (dequant_flag) {
@@ -214,6 +215,7 @@ kernel::LiteKernel *CpuDeconvDwFp32KernelCreator(const std::vector<lite::Tensor
     if (dequant_flag) {
       weight_tensor->FreeData();
       weight_tensor->set_data(restore_data);
+      weight_tensor->set_data_type(restore_type);
     }
     free(opParameter);
     return nullptr;
@@ -226,12 +228,14 @@ kernel::LiteKernel *CpuDeconvDwFp32KernelCreator(const std::vector<lite::Tensor
     if (dequant_flag) {
       weight_tensor->FreeData();
       weight_tensor->set_data(restore_data);
+      weight_tensor->set_data_type(restore_type);
     }
     return nullptr;
   }
   if (dequant_flag) {
     weight_tensor->FreeData();
     weight_tensor->set_data(restore_data);
+    weight_tensor->set_data_type(restore_type);
   }
   return kernel;
 }

@@ -343,6 +343,7 @@ kernel::LiteKernel *CpuMatmulFp32KernelCreator(const std::vector<lite::Tensor *>
   auto *weight_tensor = inputs.at(kWeightIndex);
   auto *restore_data = weight_tensor->data_c();
+  auto restore_type = weight_tensor->data_type();
   bool dequant_flag = !weight_tensor->GetQuantParams().empty() && weight_tensor->GetQuantParams().front().inited &&
                       restore_data != nullptr;
   if (dequant_flag) {
@@ -361,6 +362,7 @@ kernel::LiteKernel *CpuMatmulFp32KernelCreator(const std::vector<lite::Tensor *>
     if (dequant_flag) {
       weight_tensor->FreeData();
       weight_tensor->set_data(restore_data);
+      weight_tensor->set_data_type(restore_type);
     }
     free(opParameter);
     return nullptr;
@@ -373,6 +375,7 @@ kernel::LiteKernel *CpuMatmulFp32KernelCreator(const std::vector<lite::Tensor *>
     if (dequant_flag) {
       weight_tensor->FreeData();
       weight_tensor->set_data(restore_data);
+      weight_tensor->set_data_type(restore_type);
     }
     return nullptr;
   }
@@ -380,6 +383,7 @@ kernel::LiteKernel *CpuMatmulFp32KernelCreator(const std::vector<lite::Tensor *>
   if (dequant_flag) {
     weight_tensor->FreeData();
     weight_tensor->set_data(restore_data);
+    weight_tensor->set_data_type(restore_type);
   }
   return kernel;
