@@ -21,105 +21,135 @@
 #include "nnacl/matmul_parameter.h"
 
 namespace mindspore::lite {
-float *DequantUtil::DequantWeight(lite::Tensor *input_tensor, bool channel_first) {
+int DequantUtil::DequantWeight(lite::Tensor *input_tensor, bool channel_first, TypeId dst_data_type) {
   MS_ASSERT(input_tensor != nullptr);
   if (input_tensor->data_type() != kNumberTypeInt8 && input_tensor->data_type() != kNumberTypeInt16) {
     MS_LOG(ERROR) << "Conv weight input type error." << input_tensor->data_type();
-    return nullptr;
+    return RET_ERROR;
   }
   if (input_tensor->quant_params().empty()) {
     MS_LOG(ERROR) << "No quant param.";
-    return nullptr;
+    return RET_ERROR;
   }
-  if (input_tensor->data_type() == kNumberTypeInt16) {
-    return DequantData<int16_t>(input_tensor, channel_first);
+  if (input_tensor->data_type() == kNumberTypeInt16 && dst_data_type == kNumberTypeFloat32) {
+    auto new_const_data = DequantData<int16_t, float>(input_tensor, channel_first);
+    input_tensor->set_data(new_const_data);
+    input_tensor->set_own_data(true);
+    input_tensor->set_data_type(dst_data_type);
+  } else if (input_tensor->data_type() == kNumberTypeInt16 && dst_data_type == kNumberTypeFloat16) {
+#if defined(ENABLE_ARM64) && defined(ENABLE_FP16)
+    auto new_const_data = DequantData<int16_t, float16_t>(input_tensor, channel_first);
+    input_tensor->set_data(new_const_data);
+    input_tensor->set_own_data(true);
+    input_tensor->set_data_type(dst_data_type);
+#else
+    MS_LOG(ERROR) << "Float16 is not supported";
+    return RET_NOT_SUPPORT;
+#endif
+  } else if (input_tensor->data_type() == kNumberTypeInt8 && dst_data_type == kNumberTypeFloat32) {
+    auto new_const_data = DequantData<int8_t, float>(input_tensor, channel_first);
+    input_tensor->set_data(new_const_data);
+    input_tensor->set_own_data(true);
+    input_tensor->set_data_type(dst_data_type);
+  } else if (input_tensor->data_type() == kNumberTypeInt8 && dst_data_type == kNumberTypeFloat16) {
+#if defined(ENABLE_ARM64) && defined(ENABLE_FP16)
+    auto new_const_data = DequantData<int8_t, float16_t>(input_tensor, channel_first);
+    input_tensor->set_data(new_const_data);
+    input_tensor->set_own_data(true);
+    input_tensor->set_data_type(dst_data_type);
+#else
+    MS_LOG(ERROR) << "Float16 is not supported";
+    return RET_NOT_SUPPORT;
+#endif
   } else {
-    return DequantData<int8_t>(input_tensor, channel_first);
+    MS_LOG(ERROR) << "Unsupported dequant from data_type(" << (input_tensor->data_type()) << ") to data_type("
+                  << dst_data_type << ")";
+    return RET_NOT_SUPPORT;
   }
+  return RET_OK;
 }
 
-int DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_int_data) {
-  MS_ASSERT(input_tensor != nullptr);
-  MS_ASSERT(unpack_int_data != nullptr);
-  auto quant_params = input_tensor->quantParams();
-  if (quant_params == nullptr) {
-    MS_LOG(ERROR) << "low bits quantparams is empty.";
-    return RET_ERROR;
-  }
-  auto enable_huffman_code = input_tensor->enableHuffmanCode();
-  if (enable_huffman_code) {
-    std::string encode_str(input_tensor->data()->begin(), input_tensor->data()->end());
-    auto huffman_decode = std::make_unique<lite::HuffmanDecode>();
-    auto ret = huffman_decode->DoHuffmanDecode(encode_str, unpack_int_data);
-    if (ret != RET_OK) {
-      MS_LOG(ERROR) << "DoHuffmanDecode failed.";
-      return ret;
-    }
-    return RET_OK;
-  }
-  int origin_bit = quant_params->Get(0)->numBits();
-  if (origin_bit < 8 && origin_bit > 0) {
-    UnPackUtil<int8_t, uint8_t>(input_tensor, origin_bit, unpack_int_data);
-  } else if (origin_bit < 16 && origin_bit > 8) {
-    UnPackUtil<int16_t, uint16_t>(input_tensor, origin_bit, unpack_int_data);
-  }
-  return RET_OK;
-}
+int DequantUtil::DecodeHuffmanCode(const schema::Tensor &src_tensor, lite::Tensor *dst_tensor) {
+  MS_ASSERT(dst_tensor != nullptr);
+  if (!dst_tensor->IsConst() || !src_tensor.enableHuffmanCode()) {
+    return RET_NO_CHANGE;
+  }
+  auto data = reinterpret_cast<const char *>(src_tensor.data()->data());
+  MS_ASSERT(data != nullptr);
+  std::string encode_str(data, src_tensor.data()->size());
+  dst_tensor->set_data(nullptr);
+  auto ret = dst_tensor->MallocData();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Malloc tensor data failed";
+    return RET_NULL_PTR;
+  }
+  auto dst_data = dst_tensor->data_c();
+  MS_ASSERT(dst_data != nullptr);
+  ret = HuffmanDecode::DoHuffmanDecode(encode_str, dst_data);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "DoHuffmanDecode failed.";
+    return ret;
+  }
+  return RET_OK;
+}
 
-std::map<Tensor *, std::pair<TypeId, void *>> DequantUtil::DequantTensor(OpParameter *op_param,
-                                                                         const std::vector<Tensor *> &in_tensors,
-                                                                         TypeId data_type, bool need_restore) {
-  std::map<Tensor *, std::pair<TypeId, void *>> tensor_origin_data;
-  if (data_type == TypeId::kNumberTypeFloat32 || data_type == TypeId::kNumberTypeFloat16) {
-    auto input_i = 0;
-    for (auto weight_tensor : in_tensors) {
-      MS_ASSERT(weight_tensor != nullptr);
-      input_i++;
-      auto channel_first = true;
-      if (op_param->type_ == schema::PrimitiveType_MatMul && weight_tensor->shape().size() == 2) {
-        auto param = reinterpret_cast<MatMulParameter *>(op_param);
-        if (input_i == 1) {
-          channel_first = !param->a_transpose_;
-        } else if (input_i == 2) {
-          channel_first = param->b_transpose_;
-        } else {
-          MS_LOG(WARNING) << "unexpected input_i";
-        }
-      }
-
-      auto *restore_data = weight_tensor->data_c();
-      auto restore_type = weight_tensor->data_type();
-      bool dequant_flag = !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited &&
-                          restore_data != nullptr &&
-                          (restore_type == kNumberTypeInt8 || restore_type == kNumberTypeInt16);
-      if (dequant_flag) {
-        auto *dequant_weight = DequantUtil::DequantWeight(weight_tensor, channel_first);
-        if (dequant_weight == nullptr) {
-          MS_LOG(ERROR) << "dequant data is nullptr.";
-          return tensor_origin_data;
-        }
-        if (need_restore) {
-          tensor_origin_data[weight_tensor] = {restore_type, restore_data};
-        } else {
-          weight_tensor->FreeData();
-        }
-        weight_tensor->set_data(dequant_weight);
-        weight_tensor->set_data_type(kNumberTypeFloat32);
-      }
-    }
-  }
-  return tensor_origin_data;
-}
+int DequantUtil::UnPackToInt(const schema::Tensor &src_tensor, lite::Tensor *dst_tensor) {
+  MS_ASSERT(dst_tensor != nullptr);
+  if (!dst_tensor->IsConst()) {
+    return RET_NO_CHANGE;
+  }
+  auto quant_params = src_tensor.quantParams();
+  if (quant_params == nullptr || quant_params->size() == 0) {
+    return RET_NO_CHANGE;
+  }
+  auto quant_param = quant_params->Get(0);
+  if (quant_param == nullptr || !quant_param->inited()) {
+    return RET_NO_CHANGE;
+  }
+  auto dst_data = dst_tensor->data_c();
+  if (dst_data != nullptr) {
+    MS_LOG(ERROR) << "lite Tensor has already malloced data";
+    return RET_ERROR;
+  }
+  auto ret = dst_tensor->MallocData();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Malloc tensor data failed";
+    return RET_NULL_PTR;
+  }
+  dst_data = dst_tensor->data_c();
+  int origin_bit = quant_param->numBits();
+  if (origin_bit < 8 && origin_bit > 0) {
+    UnPackUtil<int8_t, uint8_t>(&src_tensor, origin_bit, dst_data);
+    return RET_OK;
+  } else if (origin_bit < 16 && origin_bit > 8) {
+    UnPackUtil<int16_t, uint16_t>(&src_tensor, origin_bit, dst_data);
+    return RET_OK;
+  } else {
+    MS_LOG(ERROR) << "Unsupported bit number: " << origin_bit;
+    return RET_NOT_SUPPORT;
+  }
+}
 
-void DequantUtil::RestoreTensorData(const std::map<Tensor *, std::pair<TypeId, void *>> &tensor_origin_data_map) {
-  for (auto &kv : tensor_origin_data_map) {
-    auto *tensor = kv.first;
-    auto type_id = kv.second.first;
-    auto data = kv.second.second;
-    tensor->FreeData();
-    tensor->set_data_type(type_id);
-    tensor->set_data(data);
-  }
-}
+Tensor *DequantUtil::DequantTensor(Tensor *tensor, TypeId data_type, bool channel_first, TypeId dst_data_type) {
+  MS_ASSERT(tensor != nullptr);
+  Tensor *restore_tensor = nullptr;
+  if (!tensor->IsConst() || !(data_type == TypeId::kNumberTypeFloat32 || data_type == TypeId::kNumberTypeFloat16)) {
+    return nullptr;
+  }
+  auto restore_type = tensor->data_type();
+  bool need_dequant = !tensor->quant_params().empty() && tensor->quant_params().front().inited &&
+                      (restore_type == kNumberTypeInt8 || restore_type == kNumberTypeInt16);
+  if (!need_dequant) {
+    return nullptr;
+  }
+  restore_tensor = Tensor::CopyTensor(*tensor, false);
+  restore_tensor->set_data(tensor->data_c());
+  restore_tensor->set_own_data(tensor->own_data());
+  auto ret = DequantUtil::DequantWeight(tensor, channel_first, dst_data_type);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Dequant data failed: " << ret;
+    return nullptr;
+  }
+  return restore_tensor;
+}
 
 } // namespace mindspore::lite
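
Note on usage (not part of the patch): a minimal caller-side sketch of how the new Tensor *DequantTensor(...) pairs with a restore step, using only calls visible in the diff above. The variable names (weight_tensor, data_type, dst_data_type) and the exact restore ordering are illustrative assumptions, not code from this change.

  // Hypothetical caller flow. DequantTensor rewrites weight_tensor in place to
  // dst_data_type and returns a restore tensor that aliases the original
  // quantized buffer (CopyTensor(*tensor, false) copies metadata only).
  Tensor *restore_tensor =
      DequantUtil::DequantTensor(weight_tensor, data_type, /*channel_first=*/true, dst_data_type);
  // ... create the kernel with the dequantized weight_tensor ...
  if (restore_tensor != nullptr) {
    weight_tensor->FreeData();                                 // drop the dequantized copy (owned after set_own_data(true))
    weight_tensor->set_data(restore_tensor->data_c());         // put the original quantized buffer back
    weight_tensor->set_own_data(restore_tensor->own_data());   // restore the original ownership flag
    restore_tensor->set_data(nullptr);                         // detach so delete does not free the aliased buffer
    delete restore_tensor;
  }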