From 70d575c76a95f59364b4f82604c20c0d81a8cfcd Mon Sep 17 00:00:00 2001 From: jianghui58 Date: Thu, 21 Jan 2021 22:56:49 +0800 Subject: [PATCH] weight quant support huffman code --- mindspore/lite/schema/model.fbs | 1 + mindspore/lite/src/CMakeLists.txt | 1 + mindspore/lite/src/huffman_decode.cc | 168 +++++++++ mindspore/lite/src/huffman_decode.h | 77 +++++ mindspore/lite/src/lite_session.cc | 24 ++ mindspore/lite/src/ops/primitive_c.cc | 4 + mindspore/lite/src/ops/primitive_c.h | 5 + mindspore/lite/src/tensor.cc | 4 + mindspore/lite/src/tensor.h | 5 + mindspore/lite/test/CMakeLists.txt | 1 + .../lite/test/models_tflite_weightquant.cfg | 3 +- mindspore/lite/test/run_benchmark_nets.sh | 39 ++- .../lite/tools/anf_exporter/anf_exporter.cc | 320 ++++++++++-------- .../lite/tools/anf_exporter/anf_exporter.h | 22 +- mindspore/lite/tools/converter/CMakeLists.txt | 1 + .../lite/tools/converter/anf_transform.cc | 21 ++ .../lite/tools/converter/anf_transform.h | 2 + .../lite/tools/converter/converter_flags.cc | 150 +++++--- .../lite/tools/converter/converter_flags.h | 12 + .../tools/converter/quantizer/CMakeLists.txt | 1 + .../converter/quantizer/huffman_encode.cc | 281 +++++++++++++++ .../converter/quantizer/huffman_encode.h | 77 +++++ 22 files changed, 1016 insertions(+), 203 deletions(-) create mode 100644 mindspore/lite/src/huffman_decode.cc create mode 100644 mindspore/lite/src/huffman_decode.h create mode 100644 mindspore/lite/tools/converter/quantizer/huffman_encode.cc create mode 100644 mindspore/lite/tools/converter/quantizer/huffman_encode.h diff --git a/mindspore/lite/schema/model.fbs b/mindspore/lite/schema/model.fbs index e37bc7741c..0b24e00ca9 100644 --- a/mindspore/lite/schema/model.fbs +++ b/mindspore/lite/schema/model.fbs @@ -57,6 +57,7 @@ table Tensor { quantParams: [QuantParam]; quantClusters: [float]; name: string; + enableHuffmanCode: bool = false; } union PrimitiveType { diff --git a/mindspore/lite/src/CMakeLists.txt b/mindspore/lite/src/CMakeLists.txt 
index 69ef57da7a..d610f0df09 100644 --- a/mindspore/lite/src/CMakeLists.txt +++ b/mindspore/lite/src/CMakeLists.txt @@ -37,6 +37,7 @@ set(LITE_SRC ${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc ${CMAKE_CURRENT_SOURCE_DIR}/errorcode.cc ${CMAKE_CURRENT_SOURCE_DIR}/dequant.cc + ${CMAKE_CURRENT_SOURCE_DIR}/huffman_decode.cc ) if(SUPPORT_GPU) diff --git a/mindspore/lite/src/huffman_decode.cc b/mindspore/lite/src/huffman_decode.cc new file mode 100644 index 0000000000..d4176770df --- /dev/null +++ b/mindspore/lite/src/huffman_decode.cc @@ -0,0 +1,168 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "src/huffman_decode.h"
+
+namespace mindspore {
+namespace lite {
+
+STATUS huffman_decode::DoHuffmanDecode(const std::string &input_str, void *decoded_data) {
+  if (decoded_data == nullptr) {
+    MS_LOG(ERROR) << "decoded_data is nullptr.";
+    return RET_ERROR;
+  }
+
+  int status;
+  std::string huffman_decoded_str = "";
+
+  auto key_pos = input_str.find_first_of('#');
+  auto code_pos = input_str.find_first_of('#', key_pos + 1);
+  auto key = input_str.substr(0, key_pos);
+  auto code = input_str.substr(key_pos + 1, code_pos - key_pos - 1);
+  auto encoded_data = input_str.substr(code_pos + 1);
+
+  auto root = new (std::nothrow) HuffmanNode();
+  if (root == nullptr) {
+    MS_LOG(ERROR) << "new HuffmanNode failed.";
+    return RET_MEMORY_FAILED;
+  }
+  root->left = nullptr;
+  root->right = nullptr;
+  root->parent = nullptr;
+
+  status = RebuildHuffmanTree(key, code, root);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "Rebuild huffman tree failed.";
+    delete root;
+    return status;
+  }
+
+  status = DoHuffmanDecompress(root, encoded_data, &huffman_decoded_str);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "DoHuffmanDecompress failed.";
+    delete root;
+    return status;
+  }
+
+  size_t len = huffman_decoded_str.length();
+  memcpy(decoded_data, huffman_decoded_str.c_str(), len);
+
+  delete root;
+  return RET_OK;
+}
+
+STATUS huffman_decode::RebuildHuffmanTree(std::string keys, std::string codes, const HuffmanNodePtr &root) {
+  HuffmanNodePtr cur_node, tmp_node, new_node;
+
+  auto huffman_keys = Str2Vec(std::move(keys));
+  auto huffman_codes = Str2Vec(std::move(codes));
+
+  for (size_t i = 0; i < huffman_codes.size(); ++i) {
+    auto key = stoi(huffman_keys[i]);
+    auto code = huffman_codes[i];
+    auto code_len = code.length();
+    cur_node = root;
+    for (size_t j = 0; j < code_len; ++j) {
+      if (code[j] == '0') {
+        tmp_node = cur_node->left;
+      } else if (code[j] == '1') {
+        tmp_node = cur_node->right;
+      } else {
+        MS_LOG(ERROR) << "find huffman code is not 0 or 1";
+        return RET_ERROR;
+      }
+
+      if (tmp_node == nullptr) {
+        new_node = new (std::nothrow) HuffmanNode();
+        if (new_node == nullptr) {
+          MS_LOG(ERROR) << "new HuffmanNode failed.";
+          return RET_MEMORY_FAILED;
+        }
+        this->huffman_nodes_.push_back(new_node);
+        new_node->left = nullptr;
+        new_node->right = nullptr;
+        new_node->parent = cur_node;
+
+        if (j == code_len - 1) {
+          new_node->key = key;
+          new_node->code = code;
+        }
+
+        if (code[j] == '0') {
+          cur_node->left = new_node;
+        } else {
+          cur_node->right = new_node;
+        }
+
+        tmp_node = new_node;
+      } else if (j == code_len - 1) {
+        MS_LOG(ERROR) << "the huffman code is incomplete.";
+        return RET_ERROR;
+      } else if (tmp_node->left == nullptr && tmp_node->right == nullptr) {
+        MS_LOG(ERROR) << "the huffman code is incomplete";
+        return RET_ERROR;
+      }
+      cur_node = tmp_node;
+    }
+  }
+  return RET_OK;
+}
+
+STATUS huffman_decode::DoHuffmanDecompress(HuffmanNodePtr root, std::string encoded_data, std::string *decoded_str) {
+  HuffmanNodePtr cur_node = root;
+  bool pseudo_eof = false;
+  size_t pos = 0;
+  unsigned char flag;
+
+  decoded_str->clear();
+  while (pos < encoded_data.length()) {
+    auto u_char = static_cast<unsigned char>(encoded_data[pos]);
+    flag = 0x80;
+    for (size_t i = 0; i < 8; ++i) {  // traverse the 8 bit num, to find the leaf node
+      if (u_char & flag) {
+        cur_node = cur_node->right;
+      } else {
+        cur_node = cur_node->left;
+      }
+      if (cur_node->left == nullptr && cur_node->right == nullptr) {
+        auto key = cur_node->key;
+        if (key == PSEUDO_EOF) {
+          pseudo_eof = true;
+          break;
+        } else {
+          *decoded_str += static_cast<char>(cur_node->key);
+          cur_node = root;
+        }
+      }
+      flag = flag >> 1;
+    }
+    pos++;
+    if (pseudo_eof) {
+      break;
+    }
+  }
+  return RET_OK;
+}
+
+huffman_decode::~huffman_decode() {
+  for (auto &node : this->huffman_nodes_) {
+    delete node;
+  }
+  this->huffman_nodes_.resize(0);
+}
+
+}  // namespace lite
+}  // namespace mindspore
diff --git a/mindspore/lite/src/huffman_decode.h b/mindspore/lite/src/huffman_decode.h
new file mode 100644
index 0000000000..dec0182d7e
--- /dev/null
+++ b/mindspore/lite/src/huffman_decode.h
@@ -0,0 +1,77 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_MINDSPORE_LITE_SRC_HUFFMAN_DECODE_H_
+#define MINDSPORE_LITE_MINDSPORE_LITE_SRC_HUFFMAN_DECODE_H_
+
+#include <cstring>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "include/errorcode.h"
+#include "src/common/log_adapter.h"
+
+namespace mindspore {
+namespace lite {
+
+const int PSEUDO_EOF = 128;
+
+struct HuffmanNode {
+  int key;
+  unsigned int freq;
+  std::string code;
+  HuffmanNode *left, *right, *parent;
+};
+using HuffmanNodePtr = HuffmanNode *;
+
+class huffman_decode {
+ public:
+  huffman_decode() = default;
+
+  ~huffman_decode();
+
+  STATUS DoHuffmanDecode(const std::string &input_str, void *decoded_data);
+
+ private:
+  std::vector<HuffmanNodePtr> huffman_nodes_;
+  STATUS RebuildHuffmanTree(std::string key, std::string code, const HuffmanNodePtr &root);
+
+  STATUS DoHuffmanDecompress(HuffmanNodePtr root, std::string encoded_data, std::string *decoded_str);
+
+  std::vector<std::string> Str2Vec(std::string s) {
+    size_t i = 0;
+    std::vector<std::string> vec;
+    while (i < s.length()) {
+      size_t j = i;
+      while (j < s.length() && s[j] != ' ') {
+        j++;
+      }
+      if (j != i) {
+        vec.push_back(s.substr(i, j - i));
+        i = j + 1;
+      } else {
+        i = j;
+      }
+    }
+    return vec;
+  }
+};
+
+}  // namespace lite
+}  // namespace mindspore
+
+#endif  // 
MINDSPORE_LITE_MINDSPORE_LITE_SRC_HUFFMAN_DECODE_H_ diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc index 6348521714..dad924459a 100644 --- a/mindspore/lite/src/lite_session.cc +++ b/mindspore/lite/src/lite_session.cc @@ -28,6 +28,7 @@ #include "src/kernel_registry.h" #include "src/lite_model.h" #include "src/dequant.h" +#include "src/huffman_decode.h" #if SUPPORT_NPU #include "src/runtime/agent/npu/npu_manager.h" #include "src/runtime/agent/npu/optimizer/npu_pass_manager.h" @@ -74,6 +75,7 @@ void LiteSession::ConvertTensorsQuantParam(const schema::Tensor *src_tensor, lit dst_tensor->AddQuantParam(quant_arg); } } + dst_tensor->SetEnableHuffmanCode(src_tensor->enableHuffmanCode()); auto quant_clusters = src_tensor->quantClusters(); if (quant_clusters != nullptr) { std::vector clusters; @@ -94,6 +96,13 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde int org_size = dst_tensor->Size(); return (pack_size != org_size) && (data_type == kNumberTypeInt8 || data_type == kNumberTypeInt16); }; + auto NeedHuffmanDecode = [&src_tensor, &dst_tensor]() -> bool { + auto data_type = src_tensor->dataType(); + auto enable_huffman_code = src_tensor->enableHuffmanCode(); + int pack_size = src_tensor->data()->size(); + int org_size = dst_tensor->Size(); + return (pack_size != org_size) && (data_type == kNumberTypeInt8) && enable_huffman_code; + }; auto src_category = TensorCategory(src_tensor); if ((src_category == Tensor::Category::CONST_TENSOR || src_category == Tensor::Category::CONST_SCALAR) && src_tensor->data() != nullptr && src_tensor->data()->size() > 0) { @@ -107,6 +116,21 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde return RET_ERROR; } } else { + if (NeedHuffmanDecode()) { + auto dst_data = dst_tensor->MutableData(); + if (dst_data == nullptr) { + MS_LOG(ERROR) << "Data from tensor is nullptr"; + return RET_NULL_PTR; + } + std::string 
encode_str(src_tensor->data()->begin(), src_tensor->data()->end()); + auto huffman_decode = std::make_unique(); + auto ret = huffman_decode->DoHuffmanDecode(encode_str, dst_data); + if (ret != RET_OK) { + MS_LOG(ERROR) << "DoHuffmanDecode failed."; + return ret; + } + copyed_tensor_idxes_.emplace_back(tensor_index); + } if (WeightTensorNeedCopy(model, tensor_index)) { auto dst_data = dst_tensor->MutableData(); if (dst_data == nullptr) { diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index b6b242297a..3ab2230e6c 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -450,6 +450,10 @@ void PrimitiveC::set_quant_type(const schema::QuantType &quant_type) { this->qua schema::QuantType PrimitiveC::quant_type() const { return quant_type_; } +bool PrimitiveC::IsEnableHuffmanCode() const { return enableHuffmanCode; } + +void PrimitiveC::SetEnableHuffmanCode(bool enableHuffmanCode) { this->enableHuffmanCode = enableHuffmanCode; } + std::shared_ptr GetReturnPrim() { auto return_primitiveT = new (std::nothrow) schema::PrimitiveT; if (return_primitiveT == nullptr) { diff --git a/mindspore/lite/src/ops/primitive_c.h b/mindspore/lite/src/ops/primitive_c.h index d7e5446e55..abf41457cc 100644 --- a/mindspore/lite/src/ops/primitive_c.h +++ b/mindspore/lite/src/ops/primitive_c.h @@ -123,6 +123,10 @@ class PrimitiveC : public mindspore::Primitive { schema::QuantType quant_type() const; + bool IsEnableHuffmanCode() const; + + void SetEnableHuffmanCode(bool enableHuffmanCode); + virtual int InferShape(std::vector inputs, std::vector outputs); bool infer_flag() const; @@ -154,6 +158,7 @@ class PrimitiveC : public mindspore::Primitive { schema::QuantType quant_type_{schema::QuantType_QUANT_NONE}; bool infer_flag_ = true; int op_type_ = OP_TYPE_NOT_SET; + bool enableHuffmanCode = false; }; std::shared_ptr GetReturnPrim(); diff --git a/mindspore/lite/src/tensor.cc b/mindspore/lite/src/tensor.cc index 
a91893dfa8..a51ca99f99 100644 --- a/mindspore/lite/src/tensor.cc +++ b/mindspore/lite/src/tensor.cc @@ -367,6 +367,10 @@ std::vector Tensor::quant_clusters() const { return this->quant_clusters_ void Tensor::set_quant_clusters(const std::vector &clusters) { this->quant_clusters_ = clusters; } +bool Tensor::IsEnableHuffmanCode() const { return enableHuffmanCode; } + +void Tensor::SetEnableHuffmanCode(bool enableHuffmanCode) { this->enableHuffmanCode = enableHuffmanCode; } + std::vector TensorVectorCast(const std::vector &src) { std::vector target(src.size()); std::transform(src.begin(), src.end(), target.begin(), [](Tensor *t) { return dynamic_cast(t); }); diff --git a/mindspore/lite/src/tensor.h b/mindspore/lite/src/tensor.h index 555641920d..4b126597df 100644 --- a/mindspore/lite/src/tensor.h +++ b/mindspore/lite/src/tensor.h @@ -149,6 +149,10 @@ class Tensor : public mindspore::tensor::MSTensor { void set_quant_clusters(const std::vector &clusters); + bool IsEnableHuffmanCode() const; + + void SetEnableHuffmanCode(bool enableHuffmanCode); + virtual bool IsConst() const { return (this->category_ == CONST_TENSOR || this->category_ == CONST_SCALAR) && this->data_ != nullptr; } @@ -198,6 +202,7 @@ class Tensor : public mindspore::tensor::MSTensor { std::vector quant_clusters_; mindspore::lite::Allocator *allocator_ = nullptr; Tensor *root_tensor_ = nullptr; + bool enableHuffmanCode = false; }; inline size_t DataTypeSize(const TypeId type) { diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index 8e118b2177..9cf5ba5cfe 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -132,6 +132,7 @@ set(TEST_LITE_SRC ${LITE_DIR}/src/lite_kernel.cc ${LITE_DIR}/src/lite_session.cc ${LITE_DIR}/src/dequant.cc + ${LITE_DIR}/src/huffman_decode.cc ${LITE_DIR}/src/sub_graph_kernel.cc ${LITE_DIR}/src/lite_model.cc ${LITE_DIR}/src/scheduler.cc diff --git a/mindspore/lite/test/models_tflite_weightquant.cfg 
b/mindspore/lite/test/models_tflite_weightquant.cfg index dfc9d8397a..c8d4a164ff 100644 --- a/mindspore/lite/test/models_tflite_weightquant.cfg +++ b/mindspore/lite/test/models_tflite_weightquant.cfg @@ -1 +1,2 @@ -ml_face_openclose.tflite +ml_face_openclose.tflite 0.5 +hiai_ghostnet.tflite 5 diff --git a/mindspore/lite/test/run_benchmark_nets.sh b/mindspore/lite/test/run_benchmark_nets.sh index 1bcc2a8be2..e45107da7e 100755 --- a/mindspore/lite/test/run_benchmark_nets.sh +++ b/mindspore/lite/test/run_benchmark_nets.sh @@ -221,13 +221,14 @@ function Run_Converter() { # Convert tflite weightquant models: while read line; do - model_name=${line} - if [[ $model_name == \#* ]]; then + weight_quant_line_info=${line} + if [[ $weight_quant_line_info == \#* ]]; then continue fi + model_name=`echo ${weight_quant_line_info}|awk -F ' ' '{print $1}'` echo ${model_name} >> "${run_converter_log_file}" - echo './converter_lite --fmk=TFLITE --modelFile='${models_path}'/'${model_name}' --outputFile='${ms_models_path}'/'${model_name}'--quantType=WeightQuant --bitNum=8 --quantWeightSize=500 --quantWeightChannel=16' >> "${run_converter_log_file}" - ./converter_lite --fmk=TFLITE --modelFile=$models_path/${model_name} --outputFile=${ms_models_path}/${model_name}_weightquant --quantType=WeightQuant --bitNum=8 --quantWeightSize=500 --quantWeightChannel=16 + echo './converter_lite --fmk=TFLITE --modelFile='${models_path}'/'${model_name}' --outputFile='${ms_models_path}'/'${model_name}'--quantType=WeightQuant --bitNum=8 --quantWeightChannel=0 --enableHuffmanCode=true' >> "${run_converter_log_file}" + ./converter_lite --fmk=TFLITE --modelFile=$models_path/${model_name} --outputFile=${ms_models_path}/${model_name}_weightquant --quantType=WeightQuant --bitNum=8 --quantWeightChannel=0 --enableHuffmanCode=true if [ $? 
= 0 ]; then converter_result='converter weight_quant '${model_name}' pass';echo ${converter_result} >> ${run_converter_result_file} else @@ -515,15 +516,17 @@ function Run_x86() { # Run tflite weight quantization converted models: while read line; do - model_name=${line} - if [[ $model_name == \#* ]]; then + weight_quant_line_info=${line} + if [[ $weight_quant_line_info == \#* ]]; then continue fi + model_name=`echo ${weight_quant_line_info}|awk -F ' ' '{print $1}'` + accuracy_limit=`echo ${weight_quant_line_info}|awk -F ' ' '{print $2}'` echo ${model_name} >> "${run_x86_log_file}" echo 'cd '${x86_path}'/mindspore-lite-'${version}'-inference-linux-x64' >> "${run_x86_log_file}" cd ${x86_path}/mindspore-lite-${version}-inference-linux-x64 || return 1 - echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./benchmark/benchmark --modelFile='${ms_models_path}'/'${model_name}'.ms --inDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/input/'${model_name}'.ms.bin --benchmarkDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/output/'${model_name}'.ms.out' >> "${run_x86_log_file}" - export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./benchmark/benchmark --modelFile=${ms_models_path}/${model_name}_weightquant.ms --inDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/input/${model_name}.ms.bin --benchmarkDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/output/${model_name}.ms.out >> "${run_x86_log_file}" + echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./benchmark/benchmark --modelFile='${ms_models_path}'/'${model_name}'.ms --inDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/input/'${model_name}'.ms.bin 
--benchmarkDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/output/'${model_name}'.ms.out --accuracyThreshold=${accuracy_limit}' >> "${run_x86_log_file}" + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./benchmark/benchmark --modelFile=${ms_models_path}/${model_name}_weightquant.ms --inDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/input/${model_name}.ms.bin --benchmarkDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/output/${model_name}.ms.out --accuracyThreshold=${accuracy_limit}>> "${run_x86_log_file}" if [ $? = 0 ]; then run_result='x86: '${model_name}' pass'; echo ${run_result} >> ${run_benchmark_result_file} else @@ -781,15 +784,17 @@ function Run_x86_sse() { # Run tflite weight quantization converted models: while read line; do - model_name=${line} - if [[ $model_name == \#* ]]; then + weight_quant_line_info=${line} + if [[ $weight_quant_line_info == \#* ]]; then continue fi + model_name=`echo ${weight_quant_line_info}|awk -F ' ' '{print $1}'` + accuracy_limit=`echo ${weight_quant_line_info}|awk -F ' ' '{print $2}'` echo ${model_name} >> "${run_x86_sse_log_file}" echo 'cd '${x86_path}'/mindspore-lite-'${version}'-inference-linux-x64-sse' >> "${run_x86_sse_log_file}" cd ${x86_path}/mindspore-lite-${version}-inference-linux-x64-sse || return 1 - echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./benchmark/benchmark --modelFile='${ms_models_path}'/'${model_name}'.ms --inDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/input/'${model_name}'.ms.bin --benchmarkDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/output/'${model_name}'.ms.out' >> "${run_x86_sse_log_file}" - export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./benchmark/benchmark 
--modelFile=${ms_models_path}/${model_name}_weightquant.ms --inDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/input/${model_name}.ms.bin --benchmarkDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/output/${model_name}.ms.out >> "${run_x86_sse_log_file}" + echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./benchmark/benchmark --modelFile='${ms_models_path}'/'${model_name}'.ms --inDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/input/'${model_name}'.ms.bin --benchmarkDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/output/'${model_name}'.ms.out --accuracyThreshold=${accuracy_limit}' >> "${run_x86_sse_log_file}" + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./benchmark/benchmark --modelFile=${ms_models_path}/${model_name}_weightquant.ms --inDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/input/${model_name}.ms.bin --benchmarkDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/output/${model_name}.ms.out --accuracyThreshold=${accuracy_limit} >> "${run_x86_sse_log_file}" if [ $? 
= 0 ]; then run_result='x86_sse: '${model_name}' pass'; echo ${run_result} >> ${run_benchmark_result_file} else @@ -1047,15 +1052,17 @@ function Run_x86_avx() { # Run tflite weight quantization converted models: while read line; do - model_name=${line} - if [[ $model_name == \#* ]]; then + weight_quant_line_info=${line} + if [[ $weight_quant_line_info == \#* ]]; then continue fi + model_name=`echo ${weight_quant_line_info}|awk -F ' ' '{print $1}'` + accuracy_limit=`echo ${weight_quant_line_info}|awk -F ' ' '{print $2}'` echo ${model_name} >> "${run_x86_avx_log_file}" echo 'cd '${x86_path}'/mindspore-lite-'${version}'-inference-linux-x64-avx' >> "${run_x86_avx_log_file}" cd ${x86_path}/mindspore-lite-${version}-inference-linux-x64-avx || return 1 - echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./benchmark/benchmark --modelFile='${ms_models_path}'/'${model_name}'.ms --inDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/input/'${model_name}'.ms.bin --benchmarkDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/output/'${model_name}'.ms.out' >> "${run_x86_avx_log_file}" - export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./benchmark/benchmark --modelFile=${ms_models_path}/${model_name}_weightquant.ms --inDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/input/${model_name}.ms.bin --benchmarkDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/output/${model_name}.ms.out >> "${run_x86_avx_log_file}" + echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./benchmark/benchmark --modelFile='${ms_models_path}'/'${model_name}'.ms --inDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/input/'${model_name}'.ms.bin 
--benchmarkDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/output/'${model_name}'.ms.out --accuracyThreshold=${accuracy_limit}' >> "${run_x86_avx_log_file}" + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./benchmark/benchmark --modelFile=${ms_models_path}/${model_name}_weightquant.ms --inDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/input/${model_name}.ms.bin --benchmarkDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/output/${model_name}.ms.out --accuracyThreshold=${accuracy_limit} >> "${run_x86_avx_log_file}" if [ $? = 0 ]; then run_result='x86_avx: '${model_name}' pass'; echo ${run_result} >> ${run_benchmark_result_file} else diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc index 1ce2081858..ba3a8bd206 100644 --- a/mindspore/lite/tools/anf_exporter/anf_exporter.cc +++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc @@ -251,19 +251,10 @@ int AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const size_t subgrap return RET_OK; } -int AnfExporter::ExportSubgraph(const FuncGraphPtr &func_graph, const std::unique_ptr &meta_graphT, - const size_t &subgraph_index, bool keep_graph, bool copy_primitive, - const std::shared_ptr &partial_anode) { +int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr &meta_graphT, + const size_t &subgraph_index, const bool &keep_graph, const bool ©_primitive, + const std::unique_ptr &sub_graphT) { int ret = RET_OK; - meta_graphT->subGraph.emplace_back(std::make_unique()); - auto &sub_graphT = meta_graphT->subGraph.at(subgraph_index); - auto subgraph_name = func_graph->get_attr("graph_name"); - MS_ASSERT(subgraph_name != nullptr); - sub_graphT->name = GetValue(subgraph_name); - auto fmk = func_graph->get_attr("fmk"); - MS_ASSERT(fmk != nullptr); - meta_graphT->fmkType = GetValue(fmk); - auto cnodes = 
func_graph->GetOrderedCnodes(); for (const auto &cnode : cnodes) { auto primitive_c = GetValueNode>(cnode->input(0)); @@ -357,6 +348,23 @@ int AnfExporter::ExportSubgraph(const FuncGraphPtr &func_graph, const std::uniqu meta_graphT->nodes.push_back(std::move(node)); meta_graphT->subGraph.at(subgraph_index)->nodeIndices.push_back(node_idx++); } + return ret; +} + +int AnfExporter::ExportSubgraph(const FuncGraphPtr &func_graph, const std::unique_ptr &meta_graphT, + const size_t &subgraph_index, bool keep_graph, bool copy_primitive, + const std::shared_ptr &partial_anode) { + int ret = RET_OK; + meta_graphT->subGraph.emplace_back(std::make_unique()); + auto &sub_graphT = meta_graphT->subGraph.at(subgraph_index); + auto subgraph_name = func_graph->get_attr("graph_name"); + MS_ASSERT(subgraph_name != nullptr); + sub_graphT->name = GetValue(subgraph_name); + auto fmk = func_graph->get_attr("fmk"); + MS_ASSERT(fmk != nullptr); + meta_graphT->fmkType = GetValue(fmk); + + ret = Anf2Fb(func_graph, meta_graphT, subgraph_index, keep_graph, copy_primitive, sub_graphT); if (ret != RET_OK) { ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret); return ret; @@ -454,6 +462,7 @@ int AnfExporter::ConvertInputCNode(const std::shared_ptr &input_anode, } int AnfExporter::ConvertInputParameter(const std::shared_ptr &input_anode, + const std::shared_ptr &primitive_c, const std::unique_ptr &meta_graphT, schema::CNodeT *output_cnode) { auto paramNode = input_anode->cast(); @@ -499,156 +508,182 @@ int AnfExporter::ConvertInputParameter(const std::shared_ptr &input_ano } paramTensor->name = input_name; + if (primitive_c->IsEnableHuffmanCode() && paramTensor->dataType == kNumberTypeInt8) { + paramTensor->enableHuffmanCode = true; + } node_id_map_[input_name] = meta_graphT->allTensors.size(); output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); meta_graphT->allTensors.emplace_back(std::move(paramTensor)); return RET_OK; } +int AnfExporter::ProcessTensor(const ValueNodePtr 
&valueNode, std::unique_ptr *paramTensor, + const std::shared_ptr &value, schema::CNodeT *output_cnode, + const std::unique_ptr &meta_graphT) { + int ret; + auto valueAbstract = valueNode->abstract(); + auto abstractTensor = utils::cast(valueAbstract); + if (abstractTensor == nullptr || abstractTensor->element() == nullptr) { + MS_LOG(ERROR) << "abstractTensor or abstractTensor->element() is nullptr"; + return RET_ERROR; + } + auto typePtr = abstractTensor->element()->GetTypeTrack(); + (*paramTensor)->dataType = typePtr->type_id(); + auto shape_vector = utils::cast(abstractTensor->BuildShape())->shape(); + std::vector dims; + (void)std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(dims), + [](const int64_t &value) { return static_cast(value); }); + (*paramTensor)->dims = dims; +#ifdef SUPPORT_TRAIN + if ((*paramTensor)->dims.size() == 0) (*paramTensor)->dims = {1}; +#endif + (*paramTensor)->nodeType = schema::NodeType::NodeType_ValueNode; + auto data = value->cast(); + (*paramTensor)->data.resize(data->Size()); + ret = memcpy_s((*paramTensor)->data.data(), data->Size(), data->data_c(), data->Size()); + if (ret != EOK) { + MS_LOG(ERROR) << "memcpy_s error."; + return RET_ERROR; + } + node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size(); + output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); + meta_graphT->allTensors.emplace_back(std::move(*paramTensor)); + return ret; +} +int AnfExporter::ProcessInt32OrInt64Imm(const ValueNodePtr &valueNode, std::unique_ptr *paramTensor, + const std::shared_ptr &value, schema::CNodeT *output_cnode, + const std::unique_ptr &meta_graphT) { + int ret; + // data of int64 is converted to int32 here. 
+ (*paramTensor)->dataType = kNumberTypeInt32; + (*paramTensor)->dims = {1}; + (*paramTensor)->nodeType = schema::NodeType::NodeType_ValueNode; + int real_data = CastToInt(value).front(); + (*paramTensor)->data.resize(sizeof(int32_t)); + ret = memcpy_s((*paramTensor)->data.data(), sizeof(int32_t), &real_data, sizeof(int32_t)); + if (ret != EOK) { + MS_LOG(ERROR) << "memcpy_s error."; + return RET_ERROR; + } + node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size(); + output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); + meta_graphT->allTensors.emplace_back(std::move(*paramTensor)); + return ret; +} +void AnfExporter::ProcessBoolImm(const ValueNodePtr &valueNode, std::unique_ptr *paramTensor, + const std::shared_ptr &value, schema::CNodeT *output_cnode, + const std::unique_ptr &meta_graphT) { + auto valueAbstract = valueNode->abstract(); + auto abstractScalar = utils::cast(valueAbstract); + auto typePtr = abstractScalar->GetTypeTrack(); + (*paramTensor)->dataType = typePtr->type_id(); + (*paramTensor)->dims = {1}; + (*paramTensor)->nodeType = schema::NodeType_ValueNode; + auto data = value->cast(); + (*paramTensor)->data.emplace_back(data->value()); + node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size(); + output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); + meta_graphT->allTensors.emplace_back(std::move(*paramTensor)); +} +void AnfExporter::ProcessInt(const ValueNodePtr &valueNode, std::unique_ptr *paramTensor, + schema::CNodeT *output_cnode, const std::unique_ptr &meta_graphT) { + (*paramTensor)->dataType = kNumberTypeInt32; + (*paramTensor)->dims = {1}; + (*paramTensor)->nodeType = schema::NodeType_ValueNode; + (*paramTensor)->data.emplace_back(kNumberTypeInt32); + node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size(); + output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); + 
meta_graphT->allTensors.emplace_back(std::move(*paramTensor)); +} +int AnfExporter::ProcessValueSequence(const ValueNodePtr &valueNode, std::unique_ptr *paramTensor, + const std::shared_ptr &value, schema::CNodeT *output_cnode, + const std::unique_ptr &meta_graphT) { + int ret = RET_OK; + auto valueAbstract = valueNode->abstract(); + auto abstractSequnce = utils::cast(valueAbstract); + if (abstractSequnce->isa()) { + auto abstractTuple = utils::cast(valueAbstract); + auto x_shape_data = abstractTuple->elements(); + std::vector shape; + for (std::size_t i = 0; i < abstractTuple->size(); ++i) { + auto value_track = x_shape_data[i]->GetValueTrack(); + MS_ASSERT(value_track != nullptr); + if (value_track->isa()) { + shape.push_back((GetValue(value_track))); + } else if (value_track->isa()) { + shape.push_back((GetValue(value_track))); + } else { + MS_LOG(ERROR) << "Value type is ValueSequence is not integer, it is " << value_track->ToString() << "."; + return RET_ERROR; + } + } + (*paramTensor)->dataType = kNumberTypeInt32; + (*paramTensor)->dims = {static_cast(shape.size())}; + (*paramTensor)->nodeType = schema::NodeType_ValueNode; + (*paramTensor)->data.resize(shape.size() * sizeof(int)); + ret = memcpy_s((*paramTensor)->data.data(), shape.size() * sizeof(int32_t), shape.data(), + shape.size() * sizeof(int32_t)); + if (ret != RET_OK) { + MS_LOG(ERROR) << "memcpy_s data into paramTensor failed."; + return RET_ERROR; + } + node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size(); + output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); + meta_graphT->allTensors.emplace_back(std::move(*paramTensor)); + } + return ret; +} +int AnfExporter::ProcessParamValueLite(const ValueNodePtr &valueNode, std::unique_ptr *paramTensor, + const std::shared_ptr &value, schema::CNodeT *output_cnode, + const std::unique_ptr &meta_graphT) { + int ret; + auto valueLite = std::dynamic_pointer_cast(value); + 
(*paramTensor)->data.resize(valueLite->tensor_size()); + (*paramTensor)->format = schema::Format(valueLite->format()); + (*paramTensor)->dataType = valueLite->tensor_type(); + (*paramTensor)->dims = valueLite->tensor_shape(); +#ifdef SUPPORT_TRAIN + if ((*paramTensor)->dims.size() == 0) { + (*paramTensor)->dims = {1}; + } +#endif + ret = memcpy_s((*paramTensor)->data.data(), valueLite->tensor_size() * sizeof(uint8_t), valueLite->tensor_addr(), + valueLite->tensor_size()); + if (ret != EOK) { + MS_LOG(ERROR) << "memcpy_s data into tensor failed."; + return RET_ERROR; + } + node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size(); + output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); + meta_graphT->allTensors.emplace_back(std::move(*paramTensor)); + return ret; +} + int AnfExporter::ConvertInputValueNode(const std::shared_ptr &input_anode, const std::unique_ptr &meta_graphT, schema::CNodeT *output_cnode) { auto valueNode = input_anode->cast(); auto paramTensor = std::make_unique(); auto value = valueNode->value(); + int ret = RET_OK; #ifdef SUPPORT_TRAIN paramTensor->name = valueNode->fullname_with_scope(); #endif if (value->isa()) { - auto valueAbstract = valueNode->abstract(); - auto abstractTensor = utils::cast(valueAbstract); - if (abstractTensor == nullptr || abstractTensor->element() == nullptr) { - MS_LOG(ERROR) << "abstractTensor or abstractTensor->element() is nullptr"; - return RET_ERROR; - } - auto typePtr = abstractTensor->element()->GetTypeTrack(); - paramTensor->dataType = typePtr->type_id(); - auto shape_vector = utils::cast(abstractTensor->BuildShape())->shape(); - std::vector dims; - (void)std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(dims), - [](const int64_t &value) { return static_cast(value); }); - paramTensor->dims = dims; -#ifdef SUPPORT_TRAIN - if (paramTensor->dims.size() == 0) paramTensor->dims = {1}; -#endif - paramTensor->nodeType = 
schema::NodeType::NodeType_ValueNode; - auto data = value->cast(); - paramTensor->data.resize(data->Size()); - auto ret = memcpy_s(paramTensor->data.data(), data->Size(), data->data_c(), data->Size()); - if (ret != EOK) { - MS_LOG(ERROR) << "memcpy_s error."; - return RET_ERROR; - } - node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size(); - output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); - meta_graphT->allTensors.emplace_back(std::move(paramTensor)); + ret = ProcessTensor(valueNode, ¶mTensor, value, output_cnode, meta_graphT); } else if (value->isa() || value->isa()) { - auto valueAbstract = valueNode->abstract(); - auto abstractScalar = utils::cast(valueAbstract); - auto typePtr = abstractScalar->GetTypeTrack(); - // data of int64 is converted to int32 here. - paramTensor->dataType = kNumberTypeInt32; - paramTensor->dims = {1}; - paramTensor->nodeType = schema::NodeType::NodeType_ValueNode; - int real_data = CastToInt(value).front(); - paramTensor->data.resize(sizeof(int32_t)); - auto ret = memcpy_s(paramTensor->data.data(), sizeof(int32_t), &real_data, sizeof(int32_t)); - if (ret != EOK) { - MS_LOG(ERROR) << "memcpy_s error."; - return RET_ERROR; - } - node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size(); - output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); - meta_graphT->allTensors.emplace_back(std::move(paramTensor)); + ret = ProcessInt32OrInt64Imm(valueNode, ¶mTensor, value, output_cnode, meta_graphT); } else if (value->isa()) { - auto valueAbstract = valueNode->abstract(); - auto abstractScalar = utils::cast(valueAbstract); - auto typePtr = abstractScalar->GetTypeTrack(); - paramTensor->dataType = typePtr->type_id(); - paramTensor->dims = {1}; - paramTensor->nodeType = schema::NodeType_ValueNode; - auto data = value->cast(); - paramTensor->data.emplace_back(data->value()); - node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size(); - 
output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); - meta_graphT->allTensors.emplace_back(std::move(paramTensor)); + ProcessBoolImm(valueNode, ¶mTensor, value, output_cnode, meta_graphT); } else if (value->isa()) { - paramTensor->dataType = kNumberTypeInt32; - paramTensor->dims = {1}; - paramTensor->nodeType = schema::NodeType_ValueNode; - paramTensor->data.emplace_back(kNumberTypeInt32); - node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size(); - output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); - meta_graphT->allTensors.emplace_back(std::move(paramTensor)); + ProcessInt(valueNode, ¶mTensor, output_cnode, meta_graphT); } else if (value->isa()) { - auto valueAbstract = valueNode->abstract(); - auto abstractSequnce = utils::cast(valueAbstract); - if (abstractSequnce->isa()) { - auto abstractTuple = utils::cast(valueAbstract); - auto x_shape_data = abstractTuple->elements(); - std::vector shape; - for (std::size_t i = 0; i < abstractTuple->size(); ++i) { - auto value_track = x_shape_data[i]->GetValueTrack(); - MS_ASSERT(value_track != nullptr); - if (value_track->isa()) { - shape.push_back((GetValue(value_track))); - } else if (value_track->isa()) { - shape.push_back((GetValue(value_track))); - } else { - MS_LOG(ERROR) << "Value type is ValueSequence is not integer, it is " << value_track->ToString() << "."; - return RET_ERROR; - } - } - auto typePtr = abstractTuple->elements()[0]->GetTypeTrack(); - paramTensor->dataType = kNumberTypeInt32; - paramTensor->dims = {static_cast(shape.size())}; - paramTensor->nodeType = schema::NodeType_ValueNode; - paramTensor->data.resize(shape.size() * sizeof(int)); - auto ret = memcpy_s(paramTensor->data.data(), shape.size() * sizeof(int32_t), shape.data(), - shape.size() * sizeof(int32_t)); - if (ret != RET_OK) { - MS_LOG(ERROR) << "memcpy_s data into paramTensor failed."; - return RET_ERROR; - } - node_id_map_[valueNode->fullname_with_scope()] = 
meta_graphT->allTensors.size(); - output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); - meta_graphT->allTensors.emplace_back(std::move(paramTensor)); - } - } else if (value->isa()) { - auto valueAbstract = valueNode->abstract(); - auto abstractScalar = utils::cast(valueAbstract); - auto typePtr = abstractScalar->GetTypeTrack(); - paramTensor->dataType = typePtr->type_id(); - paramTensor->dims = {1}; - paramTensor->nodeType = schema::NodeType_ValueNode; - auto data = value->cast(); - paramTensor->data.emplace_back(data->value()); - node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size(); - output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); - meta_graphT->allTensors.emplace_back(std::move(paramTensor)); + ret = ProcessValueSequence(valueNode, ¶mTensor, value, output_cnode, meta_graphT); } else if (value->isa()) { MS_LOG(INFO) << "Value is a number."; return RET_OK; } else if (value->isa()) { - auto valueLite = std::dynamic_pointer_cast(value); - paramTensor->data.resize(valueLite->tensor_size()); - paramTensor->format = schema::Format(valueLite->format()); - paramTensor->dataType = valueLite->tensor_type(); - paramTensor->dims = valueLite->tensor_shape(); -#ifdef SUPPORT_TRAIN - if (paramTensor->dims.size() == 0) { - paramTensor->dims = {1}; - } -#endif - auto ret = memcpy_s(paramTensor->data.data(), valueLite->tensor_size() * sizeof(uint8_t), valueLite->tensor_addr(), - valueLite->tensor_size()); - if (ret != EOK) { - MS_LOG(ERROR) << "memcpy_s data into tensor failed."; - return RET_ERROR; - } - node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size(); - output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); - meta_graphT->allTensors.emplace_back(std::move(paramTensor)); + ret = ProcessParamValueLite(valueNode, ¶mTensor, value, output_cnode, meta_graphT); } else if (value->isa()) { MS_LOG(INFO) << "op name:" << input_anode->fullname_with_scope() << " input is 
func_graph"; return RET_OK; @@ -656,7 +691,7 @@ int AnfExporter::ConvertInputValueNode(const std::shared_ptr &input_ano MS_LOG(ERROR) << "Not support value type , need add support."; return RET_ERROR; } - return RET_OK; + return ret; } int AnfExporter::SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr &meta_graphT, @@ -666,6 +701,11 @@ int AnfExporter::SetOpInputNode(const CNodePtr &cnode, const std::unique_ptrinputs().size() <= 1) { return RET_OK; } + auto primitive_c = GetValueNode>(cnode->input(0)); + if (primitive_c == nullptr) { + MS_LOG(ERROR) << "primitive_c is nullptr: " << cnode->fullname_with_scope(); + return RET_ERROR; + } bool is_graph_input = false; for (size_t i = 1; i < cnode->inputs().size(); i++) { auto input_node = cnode->input(i); @@ -676,7 +716,7 @@ int AnfExporter::SetOpInputNode(const CNodePtr &cnode, const std::unique_ptrisa()) { - auto ret = ConvertInputParameter(input_node, meta_graphT, fb_node); + auto ret = ConvertInputParameter(input_node, primitive_c, meta_graphT, fb_node); if (ret != RET_OK) { MS_LOG(ERROR) << "ConvertInputParameter failed"; return ret; diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.h b/mindspore/lite/tools/anf_exporter/anf_exporter.h index 877255a59d..64bae9af1e 100644 --- a/mindspore/lite/tools/anf_exporter/anf_exporter.h +++ b/mindspore/lite/tools/anf_exporter/anf_exporter.h @@ -45,10 +45,27 @@ class AnfExporter { protected: int ConvertInputCNode(const std::shared_ptr &input_anode, schema::CNodeT *output_cnode); - int ConvertInputParameter(const std::shared_ptr &input_anode, + int ConvertInputParameter(const std::shared_ptr &input_anode, const std::shared_ptr &primitive, const std::unique_ptr &meta_graphT, schema::CNodeT *output_cnode); int ConvertInputValueNode(const std::shared_ptr &input_anode, const std::unique_ptr &meta_graphT, schema::CNodeT *output_cnode); + int ProcessTensor(const ValueNodePtr &valueNode, std::unique_ptr *paramTensor, + const std::shared_ptr &value, schema::CNodeT 
*output_cnode, + const std::unique_ptr &meta_graphT); + int ProcessInt32OrInt64Imm(const ValueNodePtr &valueNode, std::unique_ptr *paramTensor, + const std::shared_ptr &value, schema::CNodeT *output_cnode, + const std::unique_ptr &meta_graphT); + void ProcessBoolImm(const ValueNodePtr &valueNode, std::unique_ptr *paramTensor, + const std::shared_ptr &value, schema::CNodeT *output_cnode, + const std::unique_ptr &meta_graphT); + void ProcessInt(const ValueNodePtr &valueNode, std::unique_ptr *paramTensor, + schema::CNodeT *output_cnode, const std::unique_ptr &meta_graphT); + int ProcessValueSequence(const ValueNodePtr &valueNode, std::unique_ptr *paramTensor, + const std::shared_ptr &value, schema::CNodeT *output_cnode, + const std::unique_ptr &meta_graphT); + int ProcessParamValueLite(const ValueNodePtr &valueNode, std::unique_ptr *paramTensor, + const std::shared_ptr &value, schema::CNodeT *output_cnode, + const std::unique_ptr &meta_graphT); int SetGraphInputIndex(const std::unique_ptr &meta_graphT, const size_t &subgraph_index); int SetGraphoutputIndex(const CNodePtr &cnode, const size_t subgraph_index, const std::unique_ptr &meta_graphT, @@ -58,6 +75,9 @@ class AnfExporter { static int ConvertQuantParam(const std::unique_ptr &meta_graph, const std::shared_ptr &primitive, const std::unique_ptr &dst_node); + int Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr &meta_graphT, + const size_t &subgraph_index, const bool &keep_graph, const bool ©_primitive, + const std::unique_ptr &sub_graphT); int ExportSubgraph(const FuncGraphPtr &func_graph, const std::unique_ptr &meta_graphT, const size_t &subgraph_index, bool keep_graph, bool copy_primitive, const std::shared_ptr &partial_anode = nullptr); diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index 262b581354..6e0c50920d 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -100,6 +100,7 @@ 
set(LITE_SRC ${SRC_DIR}/lite_model.cc ${SRC_DIR}/errorcode.cc ${SRC_DIR}/dequant.cc + ${SRC_DIR}/huffman_decode.cc ) if(SUPPORT_TRAIN) set(LITE_SRC diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc index 28407c2d75..e15711a446 100644 --- a/mindspore/lite/tools/converter/anf_transform.cc +++ b/mindspore/lite/tools/converter/anf_transform.cc @@ -51,6 +51,7 @@ #include "tools/optimizer/graph/functionalize_control_op_pass.h" #include "tools/converter/quantizer/post_training_quantizer.h" #include "tools/converter/quantizer/quant_cast.h" +#include "tools/converter/quantizer/huffman_encode.h" #include "tools/converter/quantizer/weight_quantizer.h" using std::string; @@ -252,6 +253,19 @@ int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const converter::Fla return RET_OK; } +int AnfTransform::DoHuffmanEncode(const converter::Flags *config, const FuncGraphPtr &new_graph) { + if (config->quantType == schema::QuantType_WeightQuant && config->bitNum == "8" && config->enableHuffmanCode) { + auto huffman_encode = std::make_unique(); + auto status = huffman_encode->DoHuffmanEncode(new_graph); + if (status != RET_OK) { + MS_LOG(ERROR) << "Huffman encode failed."; + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); + return RET_ERROR; + } + } + return RET_OK; +} + FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config) { MS_ASSERT(nullptr != old_graph); if (config == nullptr) { @@ -305,6 +319,13 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap MS_LOG(ERROR) << "Do Quantize failed."; return nullptr; } + + status = DoHuffmanEncode(config, new_graph); + if (status != RET_OK) { + MS_LOG(ERROR) << "Do HuffmanCode failed."; + return nullptr; + } + return new_graph; } diff --git a/mindspore/lite/tools/converter/anf_transform.h b/mindspore/lite/tools/converter/anf_transform.h index e4de7d5d3d..2491cf32b7 100644 --- 
a/mindspore/lite/tools/converter/anf_transform.h +++ b/mindspore/lite/tools/converter/anf_transform.h @@ -58,6 +58,8 @@ class AnfTransform { int RunTFAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config); int DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config, const FuncGraphPtr &new_graph); + + int DoHuffmanEncode(const converter::Flags *config, const FuncGraphPtr &new_graph); }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/converter_flags.cc b/mindspore/lite/tools/converter/converter_flags.cc index 145f5662da..940b03968d 100644 --- a/mindspore/lite/tools/converter/converter_flags.cc +++ b/mindspore/lite/tools/converter/converter_flags.cc @@ -42,53 +42,17 @@ Flags::Flags() { AddFlag(&Flags::quantWeightSize, "quantWeightSize", "Weight quantization size threshold", "0"); AddFlag(&Flags::quantWeightChannel, "quantWeightChannel", "Channel threshold for weight quantization", "16"); AddFlag(&Flags::configFile, "configFile", "Configuration for post-training.", ""); + AddFlag(&Flags::enableHuffmanCodeIn, "enableHuffmanCode", + "whether the weight quant model is going to use huffman code." + "true | false", + "false"); AddFlag(&Flags::trainModelIn, "trainModel", "whether the model is going to be trained on device." 
"true | false", "false"); } -int Flags::Init(int argc, const char **argv) { - if (argc == 1) { - std::cout << this->Usage() << std::endl; - return RET_SUCCESS_EXIT; - } - Option err = this->ParseFlags(argc, argv); - - if (err.IsSome()) { - std::cerr << err.Get(); - std::cerr << this->Usage() << std::endl; - return RET_INPUT_PARAM_INVALID; - } - - if (this->help) { - std::cout << this->Usage() << std::endl; - return RET_SUCCESS_EXIT; - } - if (this->modelFile.empty()) { - std::cerr << "INPUT MISSING: model file path is necessary"; - return RET_INPUT_PARAM_INVALID; - } - if (this->outputFile.empty()) { - std::cerr << "INPUT MISSING: output file path is necessary"; - return RET_INPUT_PARAM_INVALID; - } - -#ifdef _WIN32 - replace(this->outputFile.begin(), this->outputFile.end(), '/', '\\'); -#endif - - if (this->outputFile.rfind('/') == this->outputFile.length() - 1 || - this->outputFile.rfind('\\') == this->outputFile.length() - 1) { - std::cerr << "INPUT ILLEGAL: outputFile must be a valid file path"; - return RET_INPUT_PARAM_INVALID; - } - - if (this->fmkIn.empty()) { - std::cerr << "INPUT MISSING: fmk is necessary"; - return RET_INPUT_PARAM_INVALID; - } - +int Flags::InitInputOutputDataType() { if (this->inputDataTypeIn == "FLOAT") { this->inputDataType = TypeId::kNumberTypeFloat32; } else if (this->inputDataTypeIn == "INT8") { @@ -117,7 +81,10 @@ int Flags::Init(int argc, const char **argv) { this->outputDataTypeIn.c_str(); return RET_INPUT_PARAM_INVALID; } + return RET_OK; +} +int Flags::InitFmk() { if (this->fmkIn == "CAFFE") { this->fmk = FmkType_CAFFE; } else if (this->fmkIn == "MINDIR") { @@ -137,7 +104,10 @@ int Flags::Init(int argc, const char **argv) { std::cerr << "INPUT ILLEGAL: weightFile is not a valid flag"; return RET_INPUT_PARAM_INVALID; } + return RET_OK; +} +int Flags::InitQuantType() { if (this->quantTypeIn == "WeightQuant") { this->quantType = QuantType_WeightQuant; } else if (this->quantTypeIn == "PostTraining") { @@ -148,7 +118,22 @@ int 
Flags::Init(int argc, const char **argv) { std::cerr << "INPUT ILLEGAL: quantType must be WeightQuant|PostTraining"; return RET_INPUT_PARAM_INVALID; } + return RET_OK; +} + +int Flags::InitHuffmanCode() { + if (this->enableHuffmanCodeIn == "true") { + this->enableHuffmanCode = true; + } else if (this->enableHuffmanCodeIn == "false") { + this->enableHuffmanCode = false; + } else { + std::cerr << "INPUT ILLEGAL: enableHuffmanCode must be true|false "; + return RET_INPUT_PARAM_INVALID; + } + return RET_OK; +} +int Flags::InitTrainModel() { + if (this->trainModelIn == "true") { + this->trainModel = true; + } else if (this->trainModelIn == "false") { + this->trainModel = false; + } else { @@ -160,24 +145,99 @@ int Flags::Init(int argc, const char **argv) { if (this->trainModel) { if (this->fmk != FmkType_MS) { - std::cerr << "INPUT ILLEGAL: train model convertor supporting only MINDIR format"; + std::cerr << "INPUT ILLEGAL: train model converter supporting only MINDIR format"; return RET_INPUT_PARAM_INVALID; } if ((this->inputDataType != TypeId::kNumberTypeFloat32) && (this->inputDataType != TypeId::kTypeUnknown)) { - std::cerr << "INPUT ILLEGAL: train model convertor supporting only FP32 input tensors"; + std::cerr << "INPUT ILLEGAL: train model converter supporting only FP32 input tensors"; return RET_INPUT_PARAM_INVALID; } if ((this->outputDataType != TypeId::kNumberTypeFloat32) && (this->outputDataType != TypeId::kTypeUnknown)) { - std::cerr << "INPUT ILLEGAL: train model convertor supporting only FP32 output tensors"; + std::cerr << "INPUT ILLEGAL: train model converter supporting only FP32 output tensors"; return RET_INPUT_PARAM_INVALID; } if (this->quantType != QuantType_QUANT_NONE) { - std::cerr << "INPUT ILLEGAL: train model convertor is not supporting quantization"; + std::cerr << "INPUT ILLEGAL: train model converter is not supporting quantization"; return RET_INPUT_PARAM_INVALID; } } return RET_OK; } + +int Flags::Init(int argc, const char **argv) { + int ret; + if (argc == 1) { + std::cout << this->Usage() 
<< std::endl; + return RET_SUCCESS_EXIT; + } + Option err = this->ParseFlags(argc, argv); + + if (err.IsSome()) { + std::cerr << err.Get(); + std::cerr << this->Usage() << std::endl; + return RET_INPUT_PARAM_INVALID; + } + + if (this->help) { + std::cout << this->Usage() << std::endl; + return RET_SUCCESS_EXIT; + } + if (this->modelFile.empty()) { + std::cerr << "INPUT MISSING: model file path is necessary"; + return RET_INPUT_PARAM_INVALID; + } + if (this->outputFile.empty()) { + std::cerr << "INPUT MISSING: output file path is necessary"; + return RET_INPUT_PARAM_INVALID; + } + +#ifdef _WIN32 + replace(this->outputFile.begin(), this->outputFile.end(), '/', '\\'); +#endif + + if (this->outputFile.rfind('/') == this->outputFile.length() - 1 || + this->outputFile.rfind('\\') == this->outputFile.length() - 1) { + std::cerr << "INPUT ILLEGAL: outputFile must be a valid file path"; + return RET_INPUT_PARAM_INVALID; + } + + if (this->fmkIn.empty()) { + std::cerr << "INPUT MISSING: fmk is necessary"; + return RET_INPUT_PARAM_INVALID; + } + + ret = InitInputOutputDataType(); + if (ret != RET_OK) { + std::cerr << "Init input output datatype failed."; + return RET_INPUT_PARAM_INVALID; + } + + ret = InitFmk(); + if (ret != RET_OK) { + std::cerr << "Init fmk failed."; + return RET_INPUT_PARAM_INVALID; + } + + ret = InitQuantType(); + if (ret != RET_OK) { + std::cerr << "Init quant type failed."; + return RET_INPUT_PARAM_INVALID; + } + + ret = InitHuffmanCode(); + if (ret != RET_OK) { + std::cerr << "Init huffman code failed."; + return RET_INPUT_PARAM_INVALID; + } + + ret = InitTrainModel(); + if (ret != RET_OK) { + std::cerr << "Init train model failed."; + return RET_INPUT_PARAM_INVALID; + } + + return RET_OK; +} } // namespace converter } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/converter_flags.h b/mindspore/lite/tools/converter/converter_flags.h index 1c1d6fea34..214b98b514 100644 --- 
a/mindspore/lite/tools/converter/converter_flags.h +++ b/mindspore/lite/tools/converter/converter_flags.h @@ -45,6 +45,16 @@ class Flags : public virtual mindspore::lite::FlagParser { ~Flags() override = default; + int InitInputOutputDataType(); + + int InitFmk(); + + int InitQuantType(); + + int InitHuffmanCode(); + + int InitTrainModel(); + int Init(int argc, const char **argv); public: @@ -70,6 +80,8 @@ class Flags : public virtual mindspore::lite::FlagParser { std::string bitNum; std::string configFile; std::string quantWeightChannel; + std::string enableHuffmanCodeIn; + bool enableHuffmanCode = false; std::string trainModelIn; bool trainModel = false; }; diff --git a/mindspore/lite/tools/converter/quantizer/CMakeLists.txt b/mindspore/lite/tools/converter/quantizer/CMakeLists.txt index 0bfa01911f..f2bf1f52e0 100644 --- a/mindspore/lite/tools/converter/quantizer/CMakeLists.txt +++ b/mindspore/lite/tools/converter/quantizer/CMakeLists.txt @@ -10,6 +10,7 @@ file(GLOB QUANTIZER ${CMAKE_CURRENT_SOURCE_DIR}/post_training_quantizer.cc ${CMAKE_CURRENT_SOURCE_DIR}/quant_cast.cc ${CMAKE_CURRENT_SOURCE_DIR}/weight_quantizer.cc + ${CMAKE_CURRENT_SOURCE_DIR}/huffman_encode.cc ) set_property(SOURCE ${QUANTIZER} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) add_library(quantizer_mid OBJECT ${QUANTIZER}) diff --git a/mindspore/lite/tools/converter/quantizer/huffman_encode.cc b/mindspore/lite/tools/converter/quantizer/huffman_encode.cc new file mode 100644 index 0000000000..4dad3fe68e --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/huffman_encode.cc @@ -0,0 +1,281 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/quantizer/huffman_encode.h" + +#include +#include +#include +#include + +#include "securec/include/securec.h" +#include "src/param_value_lite.h" + +namespace mindspore { +namespace lite { + +STATUS huffman_encode::DoHuffmanEncode(const FuncGraphPtr &func_graph) { + auto cnodes = func_graph->GetOrderedCnodes(); + STATUS status; + for (auto &cnode : cnodes) { + auto primitive_c = GetValueNode>(cnode->input(0)); + if (primitive_c == nullptr) { + MS_LOG(ERROR) << "primitive_c is nullptr: " << cnode->fullname_with_scope(); + return RET_ERROR; + } + if (primitive_c->quant_type() != schema::QuantType_WeightQuant) { + continue; + } + for (size_t i = 1; i < cnode->inputs().size(); i++) { + auto input_node = cnode->input(i); + if (!input_node->isa()) { + continue; + } + auto abstract_base = input_node->abstract(); + if (abstract_base == nullptr) { + MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << input_node->fullname_with_scope(); + return RET_ERROR; + } + if (!utils::isa(abstract_base)) { + MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << input_node->fullname_with_scope(); + return RET_ERROR; + } + auto abstract_tensor = utils::cast(abstract_base); + if (abstract_tensor->element() == nullptr) { + MS_LOG(ERROR) << "abstract tensor element is nullptr, " << input_node->fullname_with_scope(); + return RET_ERROR; + } + auto tensor_type = abstract_tensor->element()->GetTypeTrack(); + MS_ASSERT(tensor_type != nullptr); + auto tensor_type_id = tensor_type->type_id(); + if (tensor_type_id != 
kNumberTypeInt8) { + continue; + } + auto param_node = input_node->cast(); + if (param_node == nullptr) { + MS_LOG(ERROR) << "parameter node is nullptr, " << input_node->fullname_with_scope(); + return RET_ERROR; + } + if (!param_node->has_default()) { + MS_LOG(WARNING) << "param_node don't have default: " << cnode->fullname_with_scope(); + continue; + } + ParamValueLitePtr param_value = std::static_pointer_cast(param_node->default_param()); + size_t elem_count = param_value->tensor_shape_size(); + auto *raw_datas = static_cast(param_value->tensor_addr()); + if (raw_datas == nullptr) { + MS_LOG(ERROR) << "rawDatas is nullptr"; + return RET_ERROR; + } + HuffmanPriorityQueue pq; + status = GetHuffmanPriorityQueue(raw_datas, elem_count, &pq); + if (status != RET_OK) { + MS_LOG(ERROR) << "GetHuffmanPriorityQueue failed"; + return status; + } + status = BuildHuffmanTree(&pq); + if (status != RET_OK) { + MS_LOG(ERROR) << "BuildHuffmanTree failed"; + return status; + } + status = DoHuffmanCompress(raw_datas, elem_count); + if (status != RET_OK) { + MS_LOG(ERROR) << "DoHuffmanCompress failed"; + return status; + } + size_t ch_size = huffman_encoded_str_.length(); + if (ch_size < elem_count) { + auto encode_data = new (std::nothrow) char[ch_size]; + if (encode_data == nullptr) { + MS_LOG(ERROR) << "new char[] failed."; + return RET_MEMORY_FAILED; + } + if (memcpy_s(encode_data, ch_size, huffman_encoded_str_.c_str(), ch_size) != EOK) { + MS_LOG(ERROR) << "memcpy_s failed."; + delete[] encode_data; + return RET_MEMORY_FAILED; + } + param_value->SetTensorData(encode_data, ch_size); + primitive_c->SetEnableHuffmanCode(true); + } + huffman_encoded_str_.clear(); + huffman_table_.clear(); + } + } + return RET_SUCCESS; +} + +STATUS huffman_encode::GetHuffmanPriorityQueue(const int8_t *data, const size_t data_size, HuffmanPriorityQueue *pq) { + MS_ASSERT(data != nullptr); + + std::map freq_map; + + for (size_t i = 0; i < data_size; i++) { + freq_map[data[i]]++; + } + + for (auto &kv 
: freq_map) { + if (kv.second <= 0) { + continue; + } + auto node = new (std::nothrow) HuffmanNode(); + if (node == nullptr) { + MS_LOG(ERROR) << "new HuffmanNode failed."; + return RET_MEMORY_FAILED; + } + this->huffman_nodes_.push_back(node); + node->key = kv.first; + node->freq = kv.second; + node->code = ""; + node->left = nullptr; + node->right = nullptr; + node->parent = nullptr; + + pq->push(node); + } + + // insert pseudo-EOF + auto node = new (std::nothrow) HuffmanNode(); + if (node == nullptr) { + MS_LOG(ERROR) << "new HuffmanNode failed."; + return RET_MEMORY_FAILED; + } + this->huffman_nodes_.push_back(node); + node->key = PSEUDO_EOF; + node->freq = 1; + node->code = ""; + node->left = nullptr; + node->right = nullptr; + node->parent = nullptr; + + pq->push(node); + + return RET_OK; +} + +void huffman_encode::GenerateHuffmanTable(const HuffmanNodePtr node, bool is_left_node) { + if (is_left_node) { + node->code = node->parent->code + "0"; + } else { + node->code = node->parent->code + "1"; + } + + if (node->left == nullptr && node->right == nullptr) { + huffman_table_[node->key] = node->code; + } else { + if (node->left != nullptr) { + GenerateHuffmanTable(node->left, true); + } + if (node->right != nullptr) { + GenerateHuffmanTable(node->right, false); + } + } +} + +STATUS huffman_encode::BuildHuffmanTree(HuffmanPriorityQueue *pq) { + HuffmanNodePtr root = nullptr; + + while (!pq->empty()) { + HuffmanNodePtr first = pq->top(); + pq->pop(); + + if (pq->empty()) { + root = first; + break; + } + + HuffmanNodePtr second = pq->top(); + pq->pop(); + + auto new_node = new (std::nothrow) HuffmanNode(); + if (new_node == nullptr) { + MS_LOG(ERROR) << "new HuffmanNode failed."; + return RET_MEMORY_FAILED; + } + this->huffman_nodes_.push_back(new_node); + new_node->freq = first->freq + second->freq; + new_node->left = first; + new_node->right = second; + first->parent = new_node; + second->parent = new_node; + + pq->push(new_node); + } + + if (root == nullptr) { 
+ MS_LOG(ERROR) << "huffman tree root node is nullptr."; + return RET_ERROR; + } + + if (root->left != nullptr) { + GenerateHuffmanTable(root->left, true); + } + if (root->right != nullptr) GenerateHuffmanTable(root->right, false); + + return RET_OK; +} + +STATUS huffman_encode::DoHuffmanCompress(const int8_t *input_datas, const size_t data_size) { + unsigned char out_c; + string code_str; + std::map::iterator iter; + std::vector encode_str = {"", "", ""}; + + huffman_encoded_str_.clear(); + for (iter = huffman_table_.begin(); iter != huffman_table_.end(); ++iter) { + encode_str[0] += std::to_string(iter->first) + " "; + encode_str[1] += iter->second + " "; + } + + for (size_t i = 0; i < data_size; i++) { + auto raw_num = input_datas[i]; + iter = huffman_table_.find(raw_num); + if (iter != huffman_table_.end()) { + code_str += iter->second; + } else { + MS_LOG(ERROR) << "Can't find the huffman code " << raw_num; + return RET_ERROR; + } + } + iter = huffman_table_.find(PSEUDO_EOF); + if (iter != huffman_table_.end()) { + code_str += iter->second; + } else { + MS_LOG(ERROR) << "Can't find the huffman code pseudo-EOF"; + return RET_ERROR; + } + out_c = 0; + for (size_t i = 0; i < code_str.length(); i++) { + auto tmp_c = code_str[i] == '0' ? 
0 : 1; + out_c += tmp_c << (7 - (i % 8)); + if (0 == (i + 1) % 8 || i == code_str.length() - 1) { + encode_str[2] += out_c; + out_c = 0; + } + } + huffman_encoded_str_ = encode_str[0] + "#" + encode_str[1] + "#" + encode_str[2]; + return RET_OK; +} + +huffman_encode::~huffman_encode() { + for (auto &node : this->huffman_nodes_) { + delete node; + } + this->huffman_nodes_.resize(0); +} + +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/quantizer/huffman_encode.h b/mindspore/lite/tools/converter/quantizer/huffman_encode.h new file mode 100644 index 0000000000..f7418d9ba4 --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/huffman_encode.h @@ -0,0 +1,77 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_HUFFMANCODE_HUFFMAN_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_HUFFMANCODE_HUFFMAN_H + +#include +#include +#include +#include +#include +#include +#include +#include "src/common/log_adapter.h" +#include "src/ops/primitive_c.h" +#include "ir/func_graph.h" + +namespace mindspore { +namespace lite { + +using STATUS = int; + +const int PSEUDO_EOF = 128; + +struct HuffmanNode { + int key; + unsigned int freq; + std::string code; + HuffmanNode *left, *right, *parent; +}; +using HuffmanNodePtr = HuffmanNode *; + +struct cmp { + public: + bool operator()(const HuffmanNodePtr &c1, const HuffmanNodePtr &c2) const { return c1->freq > c2->freq; } +}; +using HuffmanPriorityQueue = std::priority_queue, cmp>; + +class huffman_encode { + public: + huffman_encode() = default; + + ~huffman_encode(); + + STATUS DoHuffmanEncode(const FuncGraphPtr &func_graph); + + private: + std::map huffman_table_; + std::string huffman_encoded_str_ = ""; + std::vector huffman_nodes_; + + STATUS GetHuffmanPriorityQueue(const int8_t *input_datas, size_t input_data_size, HuffmanPriorityQueue *pq); + + void GenerateHuffmanTable(HuffmanNodePtr node, bool is_left_node); + + STATUS BuildHuffmanTree(HuffmanPriorityQueue *pq); + + STATUS DoHuffmanCompress(const int8_t *input_datas, size_t data_size); +}; + +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_HUFFMANCODE_HUFFMAN_H