|
|
|
@ -26,56 +26,6 @@ using mindspore::lite::RET_OK;
|
|
|
|
|
using mindspore::schema::PrimitiveType_MatMul;
|
|
|
|
|
|
|
|
|
|
namespace mindspore::kernel {
|
|
|
|
|
// Dequantizes a weight-quantized matmul weight tensor to float.
//
// The tensor must have data type kNumberTypeUInt8 and carry at least one quant
// param. Every element is converted as (q - zeroPoint) * scale into a freshly
// malloc'ed float buffer, which is then handed to the tensor via SetData().
//
// If the number of quant params equals kPerTensor, the first param is applied
// to every element. Otherwise one param is expected per channel, where the
// channel count is input_tensor->Batch() and the elements are split into
// `channels` contiguous, equally sized runs.
//
// Returns RET_OK on success. Returns RET_ERROR (leaving the tensor untouched)
// on a wrong data type, missing quant params, null tensor data, a quant-param /
// channel count mismatch, or allocation failure.
//
// NOTE(review): the previous uint8 buffer is not released here; presumably the
// caller keeps it and restores it later -- confirm against the call site.
int RestoreMatmulWeight(lite::Tensor *input_tensor) {
  MS_ASSERT(input_tensor != nullptr);
  if (input_tensor->data_type() != kNumberTypeUInt8) {
    MS_LOG(ERROR) << "mat mul input type error" << input_tensor->data_type();
    return RET_ERROR;
  }

  // Bind once instead of re-querying the tensor for every use.
  const auto &quant_params = input_tensor->GetQuantParams();
  if (quant_params.empty()) {
    MS_LOG(ERROR) << "no quant param";
    return RET_ERROR;
  }

  const auto *quant_data = static_cast<const uint8_t *>(input_tensor->MutableData());
  if (quant_data == nullptr) {
    MS_LOG(ERROR) << "input_tensor MutableData is nullptr.";
    return RET_ERROR;
  }

  // Validate the per-channel layout BEFORE allocating, so that no error path
  // owns (and can leak) the output buffer.
  const bool per_channel = quant_params.size() != kPerTensor;
  const size_t channels = per_channel ? static_cast<size_t>(input_tensor->Batch()) : 1;
  if (per_channel && quant_params.size() != channels) {
    MS_LOG(ERROR) << "Quant param not equal channel num " << quant_params.size() << " vs " << channels;
    return RET_ERROR;
  }

  const auto element_count = input_tensor->ElementsNum();
  auto *dequant_data = static_cast<float *>(malloc(element_count * sizeof(float)));
  if (dequant_data == nullptr) {
    MS_LOG(ERROR) << "malloc failed";
    return RET_ERROR;
  }

  if (per_channel) {
    // channels is non-zero here: quant_params is non-empty and its size equals channels.
    const size_t per_channel_size = element_count / channels;
    for (size_t i = 0; i < channels; i++) {
      const auto &param = quant_params.at(i);
      const auto scale = param.scale;
      const auto zero_point = param.zeroPoint;
      const size_t offset = per_channel_size * i;
      for (size_t j = 0; j < per_channel_size; j++) {
        dequant_data[offset + j] = static_cast<float>((quant_data[offset + j] - zero_point) * scale);
      }
    }
  } else {
    const auto &param = quant_params.front();
    const auto scale = param.scale;
    const auto zero_point = param.zeroPoint;
    for (int64_t j = 0; j < element_count; j++) {
      dequant_data[j] = static_cast<float>((quant_data[j] - zero_point) * scale);
    }
  }

  // The tensor now points at the float buffer allocated above.
  input_tensor->SetData(dequant_data);
  return RET_OK;
}
|
|
|
|
|
kernel::LiteKernel *CpuMatmulKernelCreator(const std::vector<lite::Tensor *> &inputs,
|
|
|
|
|
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
|
|
|
|
|
const lite::Context *ctx, const kernel::KernelKey &desc,
|
|
|
|
@ -89,8 +39,13 @@ kernel::LiteKernel *CpuMatmulKernelCreator(const std::vector<lite::Tensor *> &in
|
|
|
|
|
MS_LOG(ERROR) << "weight_tensor MutableData is nullptr.";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
|
|
|
|
|
RestoreMatmulWeight(inputs.at(kWeightIndex));
|
|
|
|
|
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
|
|
|
|
|
auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor);
|
|
|
|
|
if (dequant_weight == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "dequant data is nullptr.";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
weight_tensor->SetData(dequant_weight);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto input_tensor = inputs.at(kInputIndex);
|
|
|
|
@ -103,6 +58,10 @@ kernel::LiteKernel *CpuMatmulKernelCreator(const std::vector<lite::Tensor *> &in
|
|
|
|
|
}
|
|
|
|
|
if (kernel == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "kernel is nullptr.";
|
|
|
|
|
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
|
|
|
|
|
weight_tensor->FreeData();
|
|
|
|
|
weight_tensor->SetData(restore_data);
|
|
|
|
|
}
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
auto ret = kernel->Init();
|
|
|
|
@ -110,10 +69,14 @@ kernel::LiteKernel *CpuMatmulKernelCreator(const std::vector<lite::Tensor *> &in
|
|
|
|
|
delete kernel;
|
|
|
|
|
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
|
|
|
|
|
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
|
|
|
|
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
|
|
|
|
|
weight_tensor->FreeData();
|
|
|
|
|
weight_tensor->SetData(restore_data);
|
|
|
|
|
}
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
|
|
|
|
|
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
|
|
|
|
|
weight_tensor->FreeData();
|
|
|
|
|
weight_tensor->SetData(restore_data);
|
|
|
|
|
}
|
|
|
|
|