diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc index 1e0f72d6e5..7d04b804be 100644 --- a/mindspore/lite/src/lite_session.cc +++ b/mindspore/lite/src/lite_session.cc @@ -74,7 +74,7 @@ void LiteSession::ConvertTensorsQuantParam(const schema::Tensor *src_tensor, lit dst_tensor->AddQuantParam(quant_arg); } } - dst_tensor->SetEnableHuffmanCode(src_tensor->enableHuffmanCode()); + dst_tensor->set_enable_huffman_code(src_tensor->enableHuffmanCode()); auto quant_clusters = src_tensor->quantClusters(); if (quant_clusters != nullptr) { std::vector clusters; diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index f2d8cc36cf..60228a74ac 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -451,9 +451,9 @@ void PrimitiveC::set_quant_type(const schema::QuantType &quant_type) { this->qua schema::QuantType PrimitiveC::quant_type() const { return quant_type_; } -bool PrimitiveC::IsEnableHuffmanCode() const { return enableHuffmanCode; } +bool PrimitiveC::enable_huffman_code() const { return enable_huffman_code_; } -void PrimitiveC::SetEnableHuffmanCode(bool enableHuffmanCode) { this->enableHuffmanCode = enableHuffmanCode; } +void PrimitiveC::set_enable_huffman_code(bool enable_huffman_code) { this->enable_huffman_code_ = enable_huffman_code; } std::shared_ptr GetReturnPrim() { auto return_primitiveT = new (std::nothrow) schema::PrimitiveT; diff --git a/mindspore/lite/src/ops/primitive_c.h b/mindspore/lite/src/ops/primitive_c.h index abf41457cc..f972607ab2 100644 --- a/mindspore/lite/src/ops/primitive_c.h +++ b/mindspore/lite/src/ops/primitive_c.h @@ -123,9 +123,9 @@ class PrimitiveC : public mindspore::Primitive { schema::QuantType quant_type() const; - bool IsEnableHuffmanCode() const; + bool enable_huffman_code() const; - void SetEnableHuffmanCode(bool enableHuffmanCode); + void set_enable_huffman_code(bool enable_huffman_code); virtual int InferShape(std::vector inputs, std::vector outputs); @@ -158,7 +158,7 @@ class PrimitiveC : public mindspore::Primitive { schema::QuantType quant_type_{schema::QuantType_QUANT_NONE}; bool infer_flag_ = true; int op_type_ = OP_TYPE_NOT_SET; - bool enableHuffmanCode = false; + bool enable_huffman_code_ = false; }; std::shared_ptr GetReturnPrim(); diff --git a/mindspore/lite/src/tensor.cc b/mindspore/lite/src/tensor.cc index 3abcb0968f..0fe2ae7b30 100644 --- a/mindspore/lite/src/tensor.cc +++ b/mindspore/lite/src/tensor.cc @@ -366,9 +366,9 @@ std::vector Tensor::quant_clusters() const { return this->quant_clusters_ void Tensor::set_quant_clusters(const std::vector &clusters) { this->quant_clusters_ = clusters; } -bool Tensor::IsEnableHuffmanCode() const { return enableHuffmanCode; } +bool Tensor::enable_huffman_code() const { return enable_huffman_code_; } -void Tensor::SetEnableHuffmanCode(bool enableHuffmanCode) { this->enableHuffmanCode = enableHuffmanCode; } +void Tensor::set_enable_huffman_code(bool enable_huffman_code) { this->enable_huffman_code_ = enable_huffman_code; } std::vector TensorVectorCast(const std::vector &src) { std::vector target(src.size()); diff --git a/mindspore/lite/src/tensor.h b/mindspore/lite/src/tensor.h index 4b126597df..5f72854980 100644 --- a/mindspore/lite/src/tensor.h +++ b/mindspore/lite/src/tensor.h @@ -149,9 +149,9 @@ class Tensor : public mindspore::tensor::MSTensor { void set_quant_clusters(const std::vector &clusters); - bool IsEnableHuffmanCode() const; + bool enable_huffman_code() const; - void SetEnableHuffmanCode(bool enableHuffmanCode); + void set_enable_huffman_code(bool enable_huffman_code); virtual bool IsConst() const { return (this->category_ == CONST_TENSOR || this->category_ == CONST_SCALAR) && this->data_ != nullptr; @@ -202,7 +202,7 @@ class Tensor : public mindspore::tensor::MSTensor { std::vector quant_clusters_; mindspore::lite::Allocator *allocator_ = nullptr; Tensor *root_tensor_ = nullptr; - bool enableHuffmanCode = false; + bool enable_huffman_code_ = false; }; inline size_t DataTypeSize(const TypeId type) { diff --git a/mindspore/lite/test/models_mindspore_mixbit.cfg b/mindspore/lite/test/models_mindspore_mixbit.cfg index 287409b77c..f7ba8d439f 100644 --- a/mindspore/lite/test/models_mindspore_mixbit.cfg +++ b/mindspore/lite/test/models_mindspore_mixbit.cfg @@ -1 +1 @@ -efficientnet.mindir 40.64 9.98 +efficientnet.mindir 41.37 9.98 diff --git a/mindspore/lite/test/models_mindspore_weightquant.cfg b/mindspore/lite/test/models_mindspore_weightquant.cfg index 086f95658a..683fef9f59 100644 --- a/mindspore/lite/test/models_mindspore_weightquant.cfg +++ b/mindspore/lite/test/models_mindspore_weightquant.cfg @@ -1,3 +1,3 @@ -retinaface_732_1280_iod.mindir -mobilefacenet_iod.mindir +retinaface_732_1280_iod.mindir 16.9 +mobilefacenet_iod.mindir 13.5 #effnet_iod.mindir diff --git a/mindspore/lite/test/models_tflite_weightquant.cfg b/mindspore/lite/test/models_tflite_weightquant.cfg index c8d4a164ff..c8cc31ed40 100644 --- a/mindspore/lite/test/models_tflite_weightquant.cfg +++ b/mindspore/lite/test/models_tflite_weightquant.cfg @@ -1,2 +1,2 @@ ml_face_openclose.tflite 0.5 -hiai_ghostnet.tflite 5 +hiai_ghostnet.tflite 4.7 diff --git a/mindspore/lite/test/run_benchmark_nets.sh b/mindspore/lite/test/run_benchmark_nets.sh index fe8c5ad084..876720606a 100755 --- a/mindspore/lite/test/run_benchmark_nets.sh +++ b/mindspore/lite/test/run_benchmark_nets.sh @@ -238,10 +238,11 @@ function Run_Converter() { # Convert mindir weightquant models: while read line; do - model_name=${line} - if [[ $model_name == \#* ]]; then + weight_quant_line_info=${line} + if [[ $weight_quant_line_info == \#* ]]; then continue fi + model_name=`echo ${weight_quant_line_info}|awk -F ' ' '{print $1}'` echo ${model_name} >> "${run_converter_log_file}" echo './converter_lite --fmk=MINDIR --modelFile='${models_path}'/'${model_name}' --outputFile='${ms_models_path}'/'${model_name}' --quantType=WeightQuant --bitNum=8 --quantWeightSize=500 --quantWeightChannel=16' >> "${run_converter_log_file}" ./converter_lite --fmk=MINDIR --modelFile=$models_path/${model_name} --outputFile=${ms_models_path}/${model_name}_weightquant --quantType=WeightQuant --bitNum=8 --quantWeightSize=500 --quantWeightChannel=16 @@ -538,15 +539,17 @@ function Run_x86() { # Run mindir weight quantization converted models: while read line; do - model_name=${line} - if [[ $model_name == \#* ]]; then + weight_quant_line_info=${line} + if [[ $weight_quant_line_info == \#* ]]; then continue fi + model_name=`echo ${weight_quant_line_info}|awk -F ' ' '{print $1}'` + accuracy_limit=`echo ${weight_quant_line_info}|awk -F ' ' '{print $2}'` echo ${model_name} >> "${run_x86_log_file}" echo 'cd '${x86_path}'/mindspore-lite-'${version}'-inference-linux-x64' >> "${run_x86_log_file}" cd ${x86_path}/mindspore-lite-${version}-inference-linux-x64 || return 1 - echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./benchmark/benchmark --modelFile='${ms_models_path}'/'${model_name}'.ms --inDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/input/'${model_name}'.ms.bin --benchmarkDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/output/'${model_name}'.ms.out' >> "${run_x86_log_file}" - export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./benchmark/benchmark --modelFile=${ms_models_path}/${model_name}_weightquant.ms --inDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/input/${model_name}.ms.bin --benchmarkDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/output/${model_name}.weightquant.ms.out >> "${run_x86_log_file}" + echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./benchmark/benchmark --modelFile='${ms_models_path}'/'${model_name}'.ms --inDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/input/'${model_name}'.ms.bin --benchmarkDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/output/'${model_name}'.ms.out --accuracyThreshold=${accuracy_limit}' >> "${run_x86_log_file}" + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./benchmark/benchmark --modelFile=${ms_models_path}/${model_name}_weightquant.ms --inDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/input/${model_name}.ms.bin --benchmarkDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/output/${model_name}.ms.out --accuracyThreshold=${accuracy_limit} >> "${run_x86_log_file}" if [ $? = 0 ]; then run_result='x86: '${model_name}'[weight quant] pass'; echo ${run_result} >> ${run_benchmark_result_file} else @@ -809,15 +812,17 @@ function Run_x86_sse() { # Run mindir weight quantization converted models: while read line; do - model_name=${line} - if [[ $model_name == \#* ]]; then + weight_quant_line_info=${line} + if [[ $weight_quant_line_info == \#* ]]; then continue fi + model_name=`echo ${weight_quant_line_info}|awk -F ' ' '{print $1}'` + accuracy_limit=`echo ${weight_quant_line_info}|awk -F ' ' '{print $2}'` echo ${model_name} >> "${run_x86_sse_log_file}" echo 'cd '${x86_path}'/mindspore-lite-'${version}'-inference-linux-x64-sse' >> "${run_x86_sse_log_file}" cd ${x86_path}/mindspore-lite-${version}-inference-linux-x64-sse || return 1 - echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./benchmark/benchmark --modelFile='${ms_models_path}'/'${model_name}'.ms --inDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/input/'${model_name}'.ms.bin --benchmarkDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/output/'${model_name}'.ms.out' >> "${run_x86_sse_log_file}" - export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./benchmark/benchmark --modelFile=${ms_models_path}/${model_name}_weightquant.ms --inDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/input/${model_name}.ms.bin --benchmarkDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/output/${model_name}.weightquant.ms.out >> "${run_x86_sse_log_file}" + echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./benchmark/benchmark --modelFile='${ms_models_path}'/'${model_name}'.ms --inDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/input/'${model_name}'.ms.bin --benchmarkDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/output/'${model_name}'.ms.out --accuracyThreshold=${accuracy_limit}' >> "${run_x86_sse_log_file}" + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./benchmark/benchmark --modelFile=${ms_models_path}/${model_name}_weightquant.ms --inDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/input/${model_name}.ms.bin --benchmarkDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/output/${model_name}.ms.out --accuracyThreshold=${accuracy_limit} >> "${run_x86_sse_log_file}" if [ $? = 0 ]; then run_result='x86_sse: '${model_name}'[weight quant] pass'; echo ${run_result} >> ${run_benchmark_result_file} else @@ -1081,15 +1086,17 @@ function Run_x86_avx() { # Run mindir weight quantization converted models: while read line; do - model_name=${line} - if [[ $model_name == \#* ]]; then + weight_quant_line_info=${line} + if [[ $weight_quant_line_info == \#* ]]; then continue fi + model_name=`echo ${weight_quant_line_info}|awk -F ' ' '{print $1}'` + accuracy_limit=`echo ${weight_quant_line_info}|awk -F ' ' '{print $2}'` echo ${model_name} >> "${run_x86_avx_log_file}" echo 'cd '${x86_path}'/mindspore-lite-'${version}'-inference-linux-x64-avx' >> "${run_x86_avx_log_file}" cd ${x86_path}/mindspore-lite-${version}-inference-linux-x64-avx || return 1 - echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./benchmark/benchmark --modelFile='${ms_models_path}'/'${model_name}'.ms --inDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/input/'${model_name}'.ms.bin --benchmarkDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/output/'${model_name}'.ms.out' >> "${run_x86_avx_log_file}" - export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./benchmark/benchmark --modelFile=${ms_models_path}/${model_name}_weightquant.ms --inDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/input/${model_name}.ms.bin --benchmarkDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/output/${model_name}.weightquant.ms.out >> "${run_x86_avx_log_file}" + echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./benchmark/benchmark --modelFile='${ms_models_path}'/'${model_name}'.ms --inDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/input/'${model_name}'.ms.bin --benchmarkDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/output/'${model_name}'.ms.out --accuracyThreshold=${accuracy_limit}' >> "${run_x86_avx_log_file}" + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib;./benchmark/benchmark --modelFile=${ms_models_path}/${model_name}_weightquant.ms --inDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/input/${model_name}.ms.bin --benchmarkDataFile=/home/workspace/mindspore_dataset/mslite/models/hiai/input_output/output/${model_name}.ms.out --accuracyThreshold=${accuracy_limit} >> "${run_x86_avx_log_file}" if [ $? = 0 ]; then run_result='x86_avx: '${model_name}'[weight quant] pass'; echo ${run_result} >> ${run_benchmark_result_file} else @@ -1637,14 +1644,16 @@ function Run_arm64() { # Run mindir weightquant converted train models: while read line; do - model_name=${line} - if [[ $model_name == \#* ]]; then + weight_quant_line_info=${line} + if [[ $weight_quant_line_info == \#* ]]; then continue fi - echo ${model_name}'_train' >> "${run_arm64_log_file}" + model_name=`echo ${weight_quant_line_info}|awk -F ' ' '{print $1}'` + accuracy_limit=`echo ${weight_quant_line_info}|awk -F ' ' '{print $2}'` + echo ${model_name} >> "${run_arm64_log_file}" echo 'cd /data/local/tmp/benchmark_test' > adb_run_cmd.txt - echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --modelFile='${model_name}'_weightquant.ms --inDataFile=/data/local/tmp/input_output/input/'${model_name}'.ms.bin --benchmarkDataFile=/data/local/tmp/input_output/output/'${model_name}'.weightquant.ms.out --loopCount=1' >> "${run_arm64_log_file}" - echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --modelFile='${model_name}'_weightquant.ms --inDataFile=/data/local/tmp/input_output/input/'${model_name}'.ms.bin --benchmarkDataFile=/data/local/tmp/input_output/output/'${model_name}'.weightquant.ms.out --loopCount=1' >> adb_run_cmd.txt + echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --modelFile='${model_name}'_weightquant.ms --inDataFile=/data/local/tmp/input_output/input/'${model_name}'.ms.bin --benchmarkDataFile=/data/local/tmp/input_output/output/'${model_name}'.ms.out --loopCount=1 --accuracyThreshold='${accuracy_limit} >> "${run_arm64_log_file}" + echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/local/tmp/benchmark_test;./benchmark --modelFile='${model_name}'_weightquant.ms --inDataFile=/data/local/tmp/input_output/input/'${model_name}'.ms.bin --benchmarkDataFile=/data/local/tmp/input_output/output/'${model_name}'.ms.out --loopCount=1 --accuracyThreshold='${accuracy_limit} >> adb_run_cmd.txt adb -s ${device_id} shell < adb_run_cmd.txt >> "${run_arm64_log_file}" if [ $? = 0 ]; then run_result='arm64: '${model_name}'[weightQuant] pass'; echo ${run_result} >> ${run_benchmark_result_file} diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc index ba3a8bd206..a0330d1c36 100644 --- a/mindspore/lite/tools/anf_exporter/anf_exporter.cc +++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc @@ -508,7 +508,7 @@ int AnfExporter::ConvertInputParameter(const std::shared_ptr &input_ano } paramTensor->name = input_name; - if (primitive_c->IsEnableHuffmanCode() && paramTensor->dataType == kNumberTypeInt8) { + if (primitive_c->enable_huffman_code() && paramTensor->dataType == kNumberTypeInt8) { paramTensor->enableHuffmanCode = true; } node_id_map_[input_name] = meta_graphT->allTensors.size(); diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc index f33ce17cda..d8d86b611b 100644 --- a/mindspore/lite/tools/converter/anf_transform.cc +++ b/mindspore/lite/tools/converter/anf_transform.cc @@ -52,7 +52,6 @@ #include "tools/optimizer/graph/inputs_adjust_pass.h" #include "tools/converter/quantizer/post_training_quantizer.h" #include "tools/converter/quantizer/quant_cast.h" -#include "tools/converter/quantizer/huffman_encode.h" #include "tools/converter/quantizer/weight_quantizer.h" using std::string; @@ -243,24 +242,6 @@ int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const converter::Fla return RET_OK; } -int AnfTransform::DoHuffmanEncode(const converter::Flags *config, const FuncGraphPtr &new_graph, - bool enableHuffmanCode) { - if (config->quantType == schema::QuantType_WeightQuant && enableHuffmanCode) { - if (config->bitNum < 16 && config->bitNum > 8) { - MS_LOG(WARNING) << "don't support huffman encode when 8 < bitNum < 16 currently."; - return RET_OK; - } - auto huffman_encode = std::make_unique(); - auto status = huffman_encode->DoHuffmanEncode(new_graph, config->bitNum); - if (status != RET_OK) { - MS_LOG(ERROR) << "Huffman encode failed."; - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); - return RET_ERROR; - } - } - return RET_OK; -} - FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config) { MS_ASSERT(nullptr != old_graph); if (config == nullptr) { @@ -315,12 +296,6 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap return nullptr; } - status = DoHuffmanEncode(config, new_graph, false); - if (status != RET_OK) { - MS_LOG(ERROR) << "Do HuffmanCode failed."; - return nullptr; - } - return new_graph; } diff --git a/mindspore/lite/tools/converter/anf_transform.h b/mindspore/lite/tools/converter/anf_transform.h index 4ed86d16df..e4de7d5d3d 100644 --- a/mindspore/lite/tools/converter/anf_transform.h +++ b/mindspore/lite/tools/converter/anf_transform.h @@ -58,8 +58,6 @@ class AnfTransform { int RunTFAdjustPass(const FuncGraphPtr &old_graph, const converter::Flags *config); int DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config, const FuncGraphPtr &new_graph); - - int DoHuffmanEncode(const converter::Flags *config, const FuncGraphPtr &new_graph, bool enableHuffmanCode); }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/graphdef_transform.cc b/mindspore/lite/tools/converter/graphdef_transform.cc index d34654f068..472c867703 100644 --- a/mindspore/lite/tools/converter/graphdef_transform.cc +++ b/mindspore/lite/tools/converter/graphdef_transform.cc @@ -227,7 +227,6 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { quantNodeOptimizer.AddPass(dTypeTransPass); quantNodeOptimizer.AddPass(new (std::nothrow) QuantCastFusionPass()); quantNodeOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); - quantNodeOptimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass()); quantNodeOptimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes2)); status = quantNodeOptimizer.Run(graphDefT); if (status != RET_OK && status != RET_NO_CHANGE) { @@ -287,6 +286,15 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { } } + { + Optimizer quantNodeOptimizer; + quantNodeOptimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass()); + status = quantNodeOptimizer.Run(graphDefT); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "Run quantNodeOptimizer graphPasses Failed"; + return status; + } + } return RET_OK; } // namespace mindspore::lite } // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/quantizer/huffman_encode.cc b/mindspore/lite/tools/converter/quantizer/huffman_encode.cc index b3b50c8a05..4d764a95d1 100644 --- a/mindspore/lite/tools/converter/quantizer/huffman_encode.cc +++ b/mindspore/lite/tools/converter/quantizer/huffman_encode.cc @@ -24,120 +24,49 @@ namespace mindspore { namespace lite { -STATUS HuffmanEncode::GetParamValueLitePtr(const std::shared_ptr &input_node, ParamValueLitePtr *param_value) { - if (!input_node->isa()) { - return RET_CONTINUE; - } - auto abstract_base = input_node->abstract(); - if (abstract_base == nullptr) { - MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << input_node->fullname_with_scope(); - return RET_ERROR; - } - if (!utils::isa(abstract_base)) { - MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << input_node->fullname_with_scope(); - return RET_ERROR; - } - auto abstract_tensor = utils::cast(abstract_base); - if (abstract_tensor->element() == nullptr) { - MS_LOG(ERROR) << "abstract tensor element is nullptr, " << input_node->fullname_with_scope(); +STATUS HuffmanEncode::DoHuffmanEncode(const ParamValueLitePtr &weight, const std::shared_ptr &primitive_c, + void *quant_datas, const size_t &bit_num) { + if (quant_datas == nullptr) { + MS_LOG(ERROR) << "quant data is nullptr"; return RET_ERROR; } - auto tensor_type = abstract_tensor->element()->GetTypeTrack(); - MS_ASSERT(tensor_type != nullptr); - auto tensor_type_id = tensor_type->type_id(); - if (tensor_type_id != kNumberTypeInt8) { - return RET_CONTINUE; - } - auto param_node = input_node->cast(); - if (param_node == nullptr) { - MS_LOG(ERROR) << "parameter node is nullptr, " << input_node->fullname_with_scope(); - return RET_ERROR; - } - if (!param_node->has_default()) { - MS_LOG(WARNING) << "param_node don't have default: " << input_node->fullname_with_scope(); - return RET_CONTINUE; - } - *param_value = std::static_pointer_cast(param_node->default_param()); - return RET_OK; -} - -STATUS HuffmanEncode::DoHuffmanEncode(const FuncGraphPtr &func_graph, const int &bit_num) { - auto cnodes = func_graph->GetOrderedCnodes(); - for (auto &cnode : cnodes) { - auto primitive_c = GetValueNode>(cnode->input(0)); - if (primitive_c == nullptr) { - MS_LOG(ERROR) << "primitive_c is nullptr: " << cnode->fullname_with_scope(); - return RET_ERROR; - } - if (primitive_c->quant_type() != schema::QuantType_WeightQuant) { - continue; + auto *raw_datas = static_cast(quant_datas); + size_t elem_count = weight->tensor_shape_size(); + size_t packed_size = elem_count * bit_num; + + HuffmanPriorityQueue pq; + auto status = GetHuffmanPriorityQueue(raw_datas, elem_count, &pq); + if (status != RET_OK) { + MS_LOG(ERROR) << "GetHuffmanPriorityQueue failed"; + return status; + } + status = BuildHuffmanTree(&pq); + if (status != RET_OK) { + MS_LOG(ERROR) << "BuildHuffmanTree failed"; + return status; + } + status = DoHuffmanCompress(raw_datas, elem_count); + if (status != RET_OK) { + MS_LOG(ERROR) << "DoHuffmanCompress failed"; + return status; + } + size_t ch_size = huffman_encoded_str_.length(); + if (ch_size < packed_size) { + auto encode_data = new (std::nothrow) char[ch_size]; + if (encode_data == nullptr) { + MS_LOG(ERROR) << "new char[] failed."; + return RET_MEMORY_FAILED; } - for (size_t i = 1; i < cnode->inputs().size(); i++) { - auto input_node = cnode->input(i); - ParamValueLitePtr param_value; - auto status = GetParamValueLitePtr(input_node, ¶m_value); - if (status == RET_CONTINUE) { - continue; - } else if (status == RET_ERROR) { - MS_LOG(ERROR) << "Get param value lite ptr failed. " << cnode->fullname_with_scope(); - return RET_ERROR; - } - size_t elem_count = param_value->tensor_shape_size(); - size_t packed_size = param_value->tensor_size(); - auto *raw_datas = static_cast(param_value->tensor_addr()); - if (raw_datas == nullptr) { - MS_LOG(ERROR) << "rawDatas is nullptr"; - return RET_ERROR; - } - if (bit_num < 8 && bit_num > 0) { - auto dst_data = new (std::nothrow) int8_t[elem_count]; - if (dst_data == nullptr) { - MS_LOG(ERROR) << "new int8_t[] failed"; - return RET_ERROR; - } - DequantUtil::UnpackUtil(raw_datas, packed_size, bit_num, dst_data); - if (memcpy_s(raw_datas, elem_count, dst_data, elem_count) != EOK) { - MS_LOG(ERROR) << "memcpy_s failed."; - return RET_MEMORY_FAILED; - } - } - HuffmanPriorityQueue pq; - status = GetHuffmanPriorityQueue(raw_datas, elem_count, &pq); - if (status != RET_OK) { - MS_LOG(ERROR) << "GetHuffmanPriorityQueue failed"; - return status; - } - status = BuildHuffmanTree(&pq); - if (status != RET_OK) { - MS_LOG(ERROR) << "BuildHuffmanTree failed"; - return status; - } - status = DoHuffmanCompress(raw_datas, elem_count); - if (status != RET_OK) { - MS_LOG(ERROR) << "DoHuffmanCompress failed"; - return status; - } - size_t ch_size = huffman_encoded_str_.length(); - if (ch_size < packed_size) { - auto encode_data = new (std::nothrow) char[ch_size]; - if (encode_data == nullptr) { - MS_LOG(ERROR) << "new char[] failed."; - delete[] raw_datas; - return RET_MEMORY_FAILED; - } - delete[] raw_datas; - if (memcpy_s(encode_data, ch_size, huffman_encoded_str_.c_str(), ch_size) != EOK) { - MS_LOG(ERROR) << "memcpy_s failed."; - delete[] encode_data; - return RET_MEMORY_FAILED; - } - param_value->SetTensorData(encode_data, ch_size); - primitive_c->SetEnableHuffmanCode(true); - } - huffman_encoded_str_.clear(); - huffman_table_.clear(); + if (memcpy_s(encode_data, ch_size, huffman_encoded_str_.c_str(), ch_size) != EOK) { + MS_LOG(ERROR) << "memcpy_s failed."; + delete[] encode_data; + return RET_MEMORY_FAILED; } + weight->SetTensorData(encode_data, ch_size); + primitive_c->set_enable_huffman_code(true); } + huffman_encoded_str_.clear(); + huffman_table_.clear(); return RET_SUCCESS; } diff --git a/mindspore/lite/tools/converter/quantizer/huffman_encode.h b/mindspore/lite/tools/converter/quantizer/huffman_encode.h index ff02e45263..250e0fd143 100644 --- a/mindspore/lite/tools/converter/quantizer/huffman_encode.h +++ b/mindspore/lite/tools/converter/quantizer/huffman_encode.h @@ -60,7 +60,8 @@ class HuffmanEncode { STATUS GetParamValueLitePtr(const std::shared_ptr &input_node, ParamValueLitePtr *param_value); - STATUS DoHuffmanEncode(const FuncGraphPtr &func_graph, const int &bit_num); + STATUS DoHuffmanEncode(const ParamValueLitePtr &weight, const std::shared_ptr &primitive_c, + void *quant_datas, const size_t &bit_num); private: std::map huffman_table_; diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.h b/mindspore/lite/tools/converter/quantizer/quantize_util.h index 65a0e5ed9d..aec6f7cd26 100644 --- a/mindspore/lite/tools/converter/quantizer/quantize_util.h +++ b/mindspore/lite/tools/converter/quantizer/quantize_util.h @@ -36,6 +36,7 @@ #include "base/base.h" #include "ir/primitive.h" #include "abstract/dshape.h" +#include "tools/converter/quantizer/huffman_encode.h" #include "tools/converter/quantizer/bitpacking.h" #include "src/lite_session.h" #include "tools/converter/graphdef_transform.h" @@ -92,6 +93,7 @@ class QuantStrategy { constexpr float delta = 0.1; constexpr float ratio = 10.0; constexpr int percent = 10; +constexpr int quant_param_size = 32 * 8; STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, double mMax, bool narrowRange, int quant_max, int quant_min, int num_bits); @@ -158,163 +160,159 @@ T QuantizeData(float originData, const schema::QuantParamT &quantParam, int quan } template -STATUS QuantFilter(const ParamValueLitePtr &weight, const std::shared_ptr &primitive_c, QuantType quantType, - int quant_max, int quant_min, size_t bitNum, bool per_channel, int index = 1, bool k_means = false) { - MS_ASSERT(weight != nullptr); - MS_ASSERT(primitive_c != nullptr); +STATUS DoPerChannelQuant(const ParamValueLitePtr &weight, const QuantType &quant_type, + std::vector *quant_params, const int &quant_max, const int &quant_min, + const size_t &bit_num, const bool &k_means, std::vector *quant_datas, + std::vector *dequant_datas) { auto dims = weight->tensor_shape(); - if (per_channel) { - if (dims.size() <= 1) { - MS_LOG(WARNING) << "dims is " << dims.size() << " can not per_channel"; - per_channel = false; - } - } - - std::vector quant_params; size_t elem_count = weight->tensor_shape_size(); auto *raw_datas = static_cast(weight->tensor_addr()); - if (raw_datas == nullptr) { - MS_LOG(ERROR) << "rawDatas is nullptr"; + auto channels = dims[0]; + if (channels == 0) { + MS_LOG(ERROR) << "channels is zero"; return RET_ERROR; } - std::vector quant_datas(elem_count); - std::vector dequant_datas(elem_count); - if (per_channel) { - // notice: assume Con2D\DepthwiseConv2D's weight format are same: KHWC - // channel at first - auto channels = dims[0]; - if (channels == 0) { - MS_LOG(ERROR) << "channels is zero"; - return RET_ERROR; + size_t one_filter_size = elem_count / channels; + bool do_quant = quant_param_size / (sizeof(float) * 8 - bit_num) < one_filter_size; + if (!do_quant && quant_type == QuantType_WeightQuant) { + MS_LOG(INFO) << "too few elements in a filter, no need to quantize. " << one_filter_size; + return RET_CONTINUE; + } + for (int i = 0; i < channels; i++) { + float min = FLT_MAX; + float max = -FLT_MAX; + // find min and max + for (size_t j = 0; j < one_filter_size; j++) { + auto index = j + i * one_filter_size; + if (index >= elem_count) { + MS_LOG(ERROR) << "over flow!"; + return RET_ERROR; + } + min = std::min(min, raw_datas[index]); + max = std::max(max, raw_datas[index]); } - size_t one_filter_size = elem_count / channels; - - for (int i = 0; i < channels; i++) { - float min = FLT_MAX; - float max = -FLT_MAX; - // find min and max - for (size_t j = 0; j < one_filter_size; j++) { - auto index = j + i * one_filter_size; - if (index >= elem_count) { - MS_LOG(ERROR) << "over flow!"; - return RET_ERROR; - } - min = std::min(min, raw_datas[index]); - max = std::max(max, raw_datas[index]); + schema::QuantParamT quant_param; + STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bit_num); + if (status != RET_OK) { + MS_LOG(ERROR) << "CalQuantizationParams failed" << status; + return status; + } + // do quantization + double average_dequant = 0; + double average_raw = 0; + for (uint32_t j = 0; j < one_filter_size; j++) { + auto index = j + i * one_filter_size; + if (index >= elem_count) { + MS_LOG(ERROR) << "over flow!"; + return RET_ERROR; } - schema::QuantParamT quant_param; - STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bitNum); - if (status != RET_OK) { - MS_LOG(ERROR) << "CalQuantizationParams failed" << status; - return status; + float raw_data = raw_datas[index]; + auto quant_data = QuantizeData(raw_data, quant_param, quant_max, quant_min); + (*quant_datas)[index] = quant_data; + + if (quant_type == QuantType_WeightQuant) { + float dequant_data = quant_param.scale * (quant_data - quant_param.zeroPoint); + (*dequant_datas)[index] = dequant_data; + average_dequant += dequant_data; + average_raw += raw_data; } - // do quantization - double average_dequant = 0; - double average_raw = 0; + } + if (quant_type == QuantType_WeightQuant && !k_means) { + // mean + average_dequant = average_dequant / one_filter_size; + average_raw = average_raw / one_filter_size; + // std + double variance_dequant = 0; + double variance_raw = 0; for (uint32_t j = 0; j < one_filter_size; j++) { auto index = j + i * one_filter_size; if (index >= elem_count) { MS_LOG(ERROR) << "over flow!"; return RET_ERROR; } - float raw_data = raw_datas[index]; - auto quant_data = QuantizeData(raw_data, quant_param, quant_max, quant_min); - quant_datas[index] = quant_data; - - if (quantType == QuantType_WeightQuant) { - float dequant_data = quant_param.scale * (quant_data - quant_param.zeroPoint); - dequant_datas[index] = dequant_data; - average_dequant += dequant_data; - average_raw += raw_data; - } + variance_dequant += std::pow((*dequant_datas)[index] - average_dequant, 2); + variance_raw += std::pow(raw_datas[index] - average_raw, 2); } - if (quantType == QuantType_WeightQuant && !k_means) { - // mean - average_dequant = average_dequant / one_filter_size; - average_raw = average_raw / one_filter_size; - // std - double variance_dequant = 0; - double variance_raw = 0; - for (uint32_t j = 0; j < one_filter_size; j++) { - auto index = j + i * one_filter_size; - if (index >= elem_count) { - MS_LOG(ERROR) << "over flow!"; - return RET_ERROR; - } - variance_dequant += std::pow(dequant_datas[index] - average_dequant, 2); - variance_raw += std::pow(raw_datas[index] - average_raw, 2); - } - variance_dequant = std::sqrt(variance_dequant / one_filter_size); - variance_raw = std::sqrt(variance_raw / one_filter_size); - quant_param.varCorr = 1; - if (variance_raw != 0 && variance_dequant != 0) { - auto temp_var_corr = variance_raw / variance_dequant; - if (temp_var_corr > 0 && temp_var_corr < 10) { - quant_param.varCorr = temp_var_corr; - } else { - MS_LOG(WARNING) << "unexpected var_corr: " << temp_var_corr; - } + variance_dequant = std::sqrt(variance_dequant / one_filter_size); + variance_raw = std::sqrt(variance_raw / one_filter_size); + quant_param.varCorr = 1; + if (variance_raw != 0 && variance_dequant != 0) { + auto temp_var_corr = variance_raw / variance_dequant; + if (temp_var_corr > 0 && temp_var_corr < 10) { + quant_param.varCorr = temp_var_corr; + } else { + MS_LOG(WARNING) << "unexpected var_corr: " << temp_var_corr; } - quant_param.meanCorr = average_raw - average_dequant * quant_param.varCorr; } - quant_params.emplace_back(quant_param); + quant_param.meanCorr = average_raw - average_dequant * quant_param.varCorr; } - auto status = UpdateTensorDataAndSize(weight, quant_datas.data(), quant_datas.size() * sizeof(T)); + quant_params->emplace_back(quant_param); + } + auto status = UpdateTensorDataAndSize(weight, quant_datas->data(), quant_datas->size() * sizeof(T)); + if (status != RET_OK) { + MS_LOG(ERROR) << "UpdateTensorDataAndSize error"; + return RET_ERROR; + } + return RET_OK; +} + +template +STATUS DoPerLayerQuant(const ParamValueLitePtr &weight, const QuantType &quant_type, + std::vector *quant_params, const int &quant_max, const int &quant_min, + const size_t &bit_num, const bool &k_means, std::vector *quant_datas) { + auto dims = weight->tensor_shape(); + size_t elem_count = weight->tensor_shape_size(); + auto *raw_datas = static_cast(weight->tensor_addr()); + float min = FLT_MAX; + float max = -FLT_MIN; + for (uint32_t i = 0; i < elem_count; i++) { + // find max min + min = std::min(min, raw_datas[i]); + max = std::max(max, raw_datas[i]); + } + + schema::QuantParamT quant_param; + if (!k_means) { + STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bit_num); if (status != RET_OK) { - MS_LOG(ERROR) << "UpdateTensorDataAndSize error"; - return RET_ERROR; - } - } else { - // per layer - float min = FLT_MAX; - float max = -FLT_MIN; - for (uint32_t i = 0; i < elem_count; i++) { - // find max min - min = std::min(min, raw_datas[i]); - max = std::max(max, raw_datas[i]); + MS_LOG(ERROR) << "CalQuantizationParams failed" << status; + return status; } - - schema::QuantParamT quant_param; + } + quant_params->emplace_back(quant_param); + // update data and datatype + for (uint32_t i = 0; i < elem_count; i++) { + float raw_data = raw_datas[i]; if (!k_means) { - STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bitNum); - if (status != RET_OK) { - MS_LOG(ERROR) << "CalQuantizationParams failed" << status; - return status; - } - } - quant_params.emplace_back(quant_param); - // update data and datatype - for (uint32_t i = 0; i < elem_count; i++) { - float raw_data = raw_datas[i]; - if (!k_means) { - auto quant_data = QuantizeData(raw_data, quant_param, quant_max, quant_min); - quant_datas[i] = quant_data; - } - } - auto status = UpdateTensorDataAndSize(weight, quant_datas.data(), quant_datas.size() * sizeof(T)); - if (status != RET_OK) { - MS_LOG(ERROR) << "UpdateTensorDataAndSize error"; - return RET_ERROR; + auto quant_data = QuantizeData(raw_data, quant_param, quant_max, quant_min); + (*quant_datas)[i] = quant_data; } } - - // do bit pack - if (bitNum != 8 && bitNum != 16) { + auto status = UpdateTensorDataAndSize(weight, quant_datas->data(), quant_datas->size() * sizeof(T)); + if (status != RET_OK) { + MS_LOG(ERROR) << "UpdateTensorDataAndSize error"; + return RET_ERROR; + } + return RET_OK; +} +template +STATUS DoBitPack(const ParamValueLitePtr &weight, const size_t &bit_num, const std::vector &quant_datas) { + if (bit_num != 8 && bit_num != 16) { std::vector data{}; for (size_t i = 0; i < quant_datas.size(); ++i) { data.emplace_back((static_cast(quant_datas[i]))); } - if (bitNum > 0 && bitNum < 8) { + if (bit_num > 0 && bit_num < 8) { std::vector pack_data{}; - BitPack::BitPacking(bitNum, data, &pack_data); + BitPack::BitPacking(bit_num, data, &pack_data); auto status = UpdateTensorDataAndSize(weight, pack_data.data(), pack_data.size() * sizeof(uint8_t)); if (status != RET_OK) { MS_LOG(ERROR) << "UpdateTensorDataAndSize error"; return RET_ERROR; } - } else if (bitNum > 8 && bitNum < 16) { + } else if (bit_num > 8 && bit_num < 16) { std::vector pack_data{}; - BitPack::BitPacking(bitNum, data, &pack_data); + BitPack::BitPacking(bit_num, data, &pack_data); auto status = UpdateTensorDataAndSize(weight, pack_data.data(), pack_data.size() * sizeof(uint16_t)); if (status != RET_OK) { MS_LOG(ERROR) << "UpdateTensorDataAndSize error"; @@ -322,17 +320,79 @@ STATUS QuantFilter(const ParamValueLitePtr &weight, const std::shared_ptr +STATUS QuantFilter(const ParamValueLitePtr &weight, const std::shared_ptr &primitive_c, + QuantType quant_type, int quant_max, int quant_min, size_t bit_num, bool per_channel, int index = 1, + bool k_means = false) { + MS_ASSERT(weight != nullptr); + MS_ASSERT(primitive_c != nullptr); + auto dims = weight->tensor_shape(); + if (per_channel) { + if (dims.size() <= 1) { + MS_LOG(WARNING) << "dims is " << dims.size() << " can not per_channel"; + per_channel = false; + } + } + + std::vector quant_params; + size_t elem_count = weight->tensor_shape_size(); + auto *raw_datas = static_cast(weight->tensor_addr()); + if (raw_datas == nullptr) { + MS_LOG(ERROR) << "rawDatas is nullptr"; + return RET_ERROR; + } + + std::vector quant_datas(elem_count); + std::vector dequant_datas(elem_count); + int ret = RET_OK; + if (per_channel) { + // notice: assume Con2D\DepthwiseConv2D's weight format are same: KHWC + // channel at first + ret = DoPerChannelQuant(weight, quant_type, &quant_params, quant_max, quant_min, bit_num, k_means, &quant_datas, + &dequant_datas); + if (ret == RET_CONTINUE) { + return ret; + } else if (ret != RET_OK) { + MS_LOG(ERROR) << "Do per channel quant failed."; + return ret; + } + } else { + ret = DoPerLayerQuant(weight, quant_type, &quant_params, quant_max, quant_min, bit_num, k_means, &quant_datas); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Do per layer quant failed."; + return ret; + } + } + +#ifdef HUFFMAN_ENCODE + auto huffman_encode = std::make_unique(); + ret = huffman_encode->DoHuffmanEncode(weight, primitive_c, quant_datas.data(), bit_num); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Do huffman encode failed."; + return ret; + } +#else + // do bit pack + ret = DoBitPack(weight, bit_num, quant_datas); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Do bit pack failed."; + return ret; + } +#endif if (quant_params.empty()) { MS_LOG(ERROR) << "quant_params empty"; return RET_ERROR; } - if (quantType == QuantType_PostTraining) { + if (quant_type == QuantType_PostTraining) { primitive_c->AddInputQuantParam(quant_params); } else { primitive_c->set_input_quant_param(index, quant_params); } - return RET_OK; + return ret; } // utils diff --git a/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc b/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc index 993250587c..90b1a88577 100644 --- a/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc +++ b/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc @@ -109,7 +109,9 @@ STATUS WeightQuantizer::DoConvQuantize(CNodePtr cnode) { status = QuantFilter(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true); } - if (status != RET_OK) { + if (status == RET_CONTINUE) { + return RET_OK; + } else if (status != RET_OK) { MS_LOG(ERROR) << "QuantFilter failed : " << status; return status; } @@ -173,7 +175,9 @@ STATUS WeightQuantizer::DoMulQuantize(CNodePtr cnode) { status = QuantFilter(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true); } - if (status != RET_OK) { + if (status == RET_CONTINUE) { + return RET_OK; + } else if (status != RET_OK) { MS_LOG(ERROR) << "QuantFilter failed : " << status; return status; } @@ -246,7 +250,9 @@ STATUS WeightQuantizer::DoGatherQuantize(CNodePtr cnode) { status = QuantFilter(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false, 0); } - if (status != RET_OK) { + if (status == RET_CONTINUE) { + return RET_OK; + } else if (status != RET_OK) { MS_LOG(ERROR) << "QuantFilter failed : " << status; return status; } @@ -286,7 +292,9 @@ STATUS WeightQuantizer::ProcessLstmWeightByIndex(const CNodePtr &cnode, const st status = QuantFilter(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false, index - 1); } - if (status != RET_OK) { + if (status == RET_CONTINUE) { + return RET_OK; + } else if (status != RET_OK) { MS_LOG(ERROR) << "QuantFilter failed : " << status; return status; } @@ -503,7 +511,9 @@ STATUS WeightQuantizer::TryQuant(const int &bit_num_t, const ParameterPtr ¶m MS_LOG(ERROR) << "unexpected type_id_: " << type_id_; return RET_ERROR; } - if (status != RET_OK) { + if (status == RET_CONTINUE) { + return RET_OK; + } else if (status != RET_OK) { MS_LOG(ERROR) << "quant filter failed."; return RET_ERROR; }