|
|
|
@ -24,120 +24,49 @@
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace lite {
|
|
|
|
|
|
|
|
|
|
STATUS HuffmanEncode::GetParamValueLitePtr(const std::shared_ptr<AnfNode> &input_node, ParamValueLitePtr *param_value) {
|
|
|
|
|
if (!input_node->isa<Parameter>()) {
|
|
|
|
|
return RET_CONTINUE;
|
|
|
|
|
}
|
|
|
|
|
auto abstract_base = input_node->abstract();
|
|
|
|
|
if (abstract_base == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << input_node->fullname_with_scope();
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
|
|
|
|
|
MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << input_node->fullname_with_scope();
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
|
|
|
|
|
if (abstract_tensor->element() == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "abstract tensor element is nullptr, " << input_node->fullname_with_scope();
|
|
|
|
|
STATUS HuffmanEncode::DoHuffmanEncode(const ParamValueLitePtr &weight, const std::shared_ptr<PrimitiveC> &primitive_c,
|
|
|
|
|
void *quant_datas, const size_t &bit_num) {
|
|
|
|
|
if (quant_datas == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "quant data is nullptr";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
auto tensor_type = abstract_tensor->element()->GetTypeTrack();
|
|
|
|
|
MS_ASSERT(tensor_type != nullptr);
|
|
|
|
|
auto tensor_type_id = tensor_type->type_id();
|
|
|
|
|
if (tensor_type_id != kNumberTypeInt8) {
|
|
|
|
|
return RET_CONTINUE;
|
|
|
|
|
}
|
|
|
|
|
auto param_node = input_node->cast<ParameterPtr>();
|
|
|
|
|
if (param_node == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "parameter node is nullptr, " << input_node->fullname_with_scope();
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
if (!param_node->has_default()) {
|
|
|
|
|
MS_LOG(WARNING) << "param_node don't have default: " << input_node->fullname_with_scope();
|
|
|
|
|
return RET_CONTINUE;
|
|
|
|
|
}
|
|
|
|
|
*param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param());
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
STATUS HuffmanEncode::DoHuffmanEncode(const FuncGraphPtr &func_graph, const int &bit_num) {
|
|
|
|
|
auto cnodes = func_graph->GetOrderedCnodes();
|
|
|
|
|
for (auto &cnode : cnodes) {
|
|
|
|
|
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
|
|
|
|
|
if (primitive_c == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "primitive_c is nullptr: " << cnode->fullname_with_scope();
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
if (primitive_c->quant_type() != schema::QuantType_WeightQuant) {
|
|
|
|
|
continue;
|
|
|
|
|
auto *raw_datas = static_cast<int8_t *>(quant_datas);
|
|
|
|
|
size_t elem_count = weight->tensor_shape_size();
|
|
|
|
|
size_t packed_size = elem_count * bit_num;
|
|
|
|
|
|
|
|
|
|
HuffmanPriorityQueue pq;
|
|
|
|
|
auto status = GetHuffmanPriorityQueue(raw_datas, elem_count, &pq);
|
|
|
|
|
if (status != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "GetHuffmanPriorityQueue failed";
|
|
|
|
|
return status;
|
|
|
|
|
}
|
|
|
|
|
status = BuildHuffmanTree(&pq);
|
|
|
|
|
if (status != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "BuildHuffmanTree failed";
|
|
|
|
|
return status;
|
|
|
|
|
}
|
|
|
|
|
status = DoHuffmanCompress(raw_datas, elem_count);
|
|
|
|
|
if (status != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "DoHuffmanCompress failed";
|
|
|
|
|
return status;
|
|
|
|
|
}
|
|
|
|
|
size_t ch_size = huffman_encoded_str_.length();
|
|
|
|
|
if (ch_size < packed_size) {
|
|
|
|
|
auto encode_data = new (std::nothrow) char[ch_size];
|
|
|
|
|
if (encode_data == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "new char[] failed.";
|
|
|
|
|
return RET_MEMORY_FAILED;
|
|
|
|
|
}
|
|
|
|
|
for (size_t i = 1; i < cnode->inputs().size(); i++) {
|
|
|
|
|
auto input_node = cnode->input(i);
|
|
|
|
|
ParamValueLitePtr param_value;
|
|
|
|
|
auto status = GetParamValueLitePtr(input_node, ¶m_value);
|
|
|
|
|
if (status == RET_CONTINUE) {
|
|
|
|
|
continue;
|
|
|
|
|
} else if (status == RET_ERROR) {
|
|
|
|
|
MS_LOG(ERROR) << "Get param value lite ptr failed. " << cnode->fullname_with_scope();
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
size_t elem_count = param_value->tensor_shape_size();
|
|
|
|
|
size_t packed_size = param_value->tensor_size();
|
|
|
|
|
auto *raw_datas = static_cast<int8_t *>(param_value->tensor_addr());
|
|
|
|
|
if (raw_datas == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "rawDatas is nullptr";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
if (bit_num < 8 && bit_num > 0) {
|
|
|
|
|
auto dst_data = new (std::nothrow) int8_t[elem_count];
|
|
|
|
|
if (dst_data == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "new int8_t[] failed";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
DequantUtil::UnpackUtil<int8_t, uint8_t>(raw_datas, packed_size, bit_num, dst_data);
|
|
|
|
|
if (memcpy_s(raw_datas, elem_count, dst_data, elem_count) != EOK) {
|
|
|
|
|
MS_LOG(ERROR) << "memcpy_s failed.";
|
|
|
|
|
return RET_MEMORY_FAILED;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
HuffmanPriorityQueue pq;
|
|
|
|
|
status = GetHuffmanPriorityQueue(raw_datas, elem_count, &pq);
|
|
|
|
|
if (status != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "GetHuffmanPriorityQueue failed";
|
|
|
|
|
return status;
|
|
|
|
|
}
|
|
|
|
|
status = BuildHuffmanTree(&pq);
|
|
|
|
|
if (status != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "BuildHuffmanTree failed";
|
|
|
|
|
return status;
|
|
|
|
|
}
|
|
|
|
|
status = DoHuffmanCompress(raw_datas, elem_count);
|
|
|
|
|
if (status != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "DoHuffmanCompress failed";
|
|
|
|
|
return status;
|
|
|
|
|
}
|
|
|
|
|
size_t ch_size = huffman_encoded_str_.length();
|
|
|
|
|
if (ch_size < packed_size) {
|
|
|
|
|
auto encode_data = new (std::nothrow) char[ch_size];
|
|
|
|
|
if (encode_data == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "new char[] failed.";
|
|
|
|
|
delete[] raw_datas;
|
|
|
|
|
return RET_MEMORY_FAILED;
|
|
|
|
|
}
|
|
|
|
|
delete[] raw_datas;
|
|
|
|
|
if (memcpy_s(encode_data, ch_size, huffman_encoded_str_.c_str(), ch_size) != EOK) {
|
|
|
|
|
MS_LOG(ERROR) << "memcpy_s failed.";
|
|
|
|
|
delete[] encode_data;
|
|
|
|
|
return RET_MEMORY_FAILED;
|
|
|
|
|
}
|
|
|
|
|
param_value->SetTensorData(encode_data, ch_size);
|
|
|
|
|
primitive_c->SetEnableHuffmanCode(true);
|
|
|
|
|
}
|
|
|
|
|
huffman_encoded_str_.clear();
|
|
|
|
|
huffman_table_.clear();
|
|
|
|
|
if (memcpy_s(encode_data, ch_size, huffman_encoded_str_.c_str(), ch_size) != EOK) {
|
|
|
|
|
MS_LOG(ERROR) << "memcpy_s failed.";
|
|
|
|
|
delete[] encode_data;
|
|
|
|
|
return RET_MEMORY_FAILED;
|
|
|
|
|
}
|
|
|
|
|
weight->SetTensorData(encode_data, ch_size);
|
|
|
|
|
primitive_c->set_enable_huffman_code(true);
|
|
|
|
|
}
|
|
|
|
|
huffman_encoded_str_.clear();
|
|
|
|
|
huffman_table_.clear();
|
|
|
|
|
return RET_SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|