From 18c6ac9988b80b8ad1b148074b0ee25e69147205 Mon Sep 17 00:00:00 2001 From: cjh9368 Date: Tue, 11 Aug 2020 15:10:21 +0800 Subject: [PATCH] add quant aware compile success --- .../src/common/anf_exporter/anf_exporter.cc | 2 +- .../anf_importer/import_from_meta_graphT.cc | 21 + mindspore/lite/tools/common/graph_util.cc | 26 +- mindspore/lite/tools/common/graph_util.h | 2 +- mindspore/lite/tools/common/tensor_util.cc | 32 +- mindspore/lite/tools/common/tensor_util.h | 5 + mindspore/lite/tools/converter/CMakeLists.txt | 1 + mindspore/lite/tools/converter/converter.cc | 27 +- .../lite/tools/converter/converter_flags.cc | 12 +- .../lite/tools/converter/converter_flags.h | 3 +- .../tools/converter/graphdef_transform.cc | 117 +++- .../lite/tools/converter/graphdef_transform.h | 6 +- .../fusion/matmul_biasadd_fusion_pass.h | 2 +- .../legacy_optimizer/graph/CMakeLists.txt | 1 + .../graph/dtype_trans_pass.cc | 235 +++++++ .../legacy_optimizer/graph/dtype_trans_pass.h | 81 +++ .../node/weight_format_pass.cc | 81 ++- mindspore/lite/tools/converter/model_parser.h | 3 +- .../parser/caffe/caffe_model_parser.cc | 5 +- .../parser/caffe/caffe_model_parser.h | 3 +- .../converter/parser/onnx/onnx_model_parser.h | 3 +- .../parser/tflite/tflite_model_parser.cc | 137 ++-- .../parser/tflite/tflite_model_parser.h | 19 +- .../tools/converter/quantizer/CMakeLists.txt | 2 + .../converter/quantizer/aware_quantizer.cc | 594 ++++++++++++++++++ .../converter/quantizer/aware_quantizer.h | 65 ++ .../converter/quantizer/calc_quant_param.cc | 504 +++++++++++++++ .../converter/quantizer/calc_quant_param.h | 69 ++ .../converter/quantizer/quantize_util.cc | 399 +++++++----- .../tools/converter/quantizer/quantize_util.h | 35 ++ .../tools/converter/quantizer/quantizer.cc | 19 +- .../tools/converter/quantizer/quantizer.h | 57 +- 32 files changed, 2211 insertions(+), 357 deletions(-) mode change 100755 => 100644 mindspore/lite/tools/common/graph_util.cc create mode 100644 
mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc create mode 100644 mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h create mode 100644 mindspore/lite/tools/converter/quantizer/aware_quantizer.cc create mode 100644 mindspore/lite/tools/converter/quantizer/aware_quantizer.h create mode 100644 mindspore/lite/tools/converter/quantizer/calc_quant_param.cc create mode 100644 mindspore/lite/tools/converter/quantizer/calc_quant_param.h diff --git a/mindspore/lite/src/common/anf_exporter/anf_exporter.cc b/mindspore/lite/src/common/anf_exporter/anf_exporter.cc index 81cdbe9429..31bd1b3222 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_exporter.cc +++ b/mindspore/lite/src/common/anf_exporter/anf_exporter.cc @@ -188,7 +188,7 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) { // add quant param node->quantType = primitiveT_value->GetQuantType(); - if (node->quantType == schema::QuantType_PostTraining) { + if (node->quantType == schema::QuantType_PostTraining || node->quantType == schema::QuantType_AwareTrainning) { MS_LOG(INFO) << "node: " << node->name << " add QuantParam"; // activation auto input_quant_params = primitiveT_value->GetInputQuantParams(); diff --git a/mindspore/lite/src/common/anf_importer/import_from_meta_graphT.cc b/mindspore/lite/src/common/anf_importer/import_from_meta_graphT.cc index 703e0b0715..0b47a8e636 100644 --- a/mindspore/lite/src/common/anf_importer/import_from_meta_graphT.cc +++ b/mindspore/lite/src/common/anf_importer/import_from_meta_graphT.cc @@ -60,6 +60,17 @@ void AnfImporterFromMetaGraphT::ConverterConstTensor() { param_value->set_tensor_addr(tensor_data); param_value->set_tensor_size(size); } + if (tensor->quantParams.size() > 0) { + std::unique_ptr quantParam = std::make_unique(); + quantParam->scale = tensor->quantParams[0]->scale; + quantParam->zeroPoint = tensor->quantParams[0]->zeroPoint; + quantParam->min = tensor->quantParams[0]->min; + quantParam->max 
= tensor->quantParams[0]->max; + quantParam->narrowRange = tensor->quantParams[0]->narrowRange; + quantParam->numBits = tensor->quantParams[0]->numBits; + quantParam->inited = tensor->quantParams[0]->inited; + param_value->set_quant_param(quantParam); + } parameter->set_default_param(param_value); AddNode(i, parameter); } @@ -77,6 +88,16 @@ int AnfImporterFromMetaGraphT::ConverterCNode() { flag = true; } auto primTValue = std::make_shared(cNode->primitive.release()); + // add quant parameter + if (cNode->quantType == schema::QuantType_AwareTrainning || cNode->quantType == schema::QuantType_PostTraining) { + primTValue->SetQuantType(cNode->quantType); + for (int index : cNode->inputIndex) { + primTValue->AddInputQuantParam(*(meta_graph_->allTensors[index]->quantParams[0])); + } + for (int index : cNode->outputIndex) { + primTValue->AddOutputQuantParam(*(meta_graph_->allTensors[index]->quantParams[0])); + } + } cNode->primitive = nullptr; auto value_node = NewValueNode(primTValue); diff --git a/mindspore/lite/tools/common/graph_util.cc b/mindspore/lite/tools/common/graph_util.cc old mode 100755 new mode 100644 index 1b84029174..90973a246a --- a/mindspore/lite/tools/common/graph_util.cc +++ b/mindspore/lite/tools/common/graph_util.cc @@ -28,7 +28,7 @@ namespace mindspore { namespace lite { OpDefCopyer GetSimpleOpCopyer() { - return [](std::unique_ptr &inCNode) -> std::unique_ptr { + return [](CNodeT *inCNode) -> std::unique_ptr { std::unique_ptr newCNode(new CNodeT); newCNode->name = inCNode->name; @@ -421,9 +421,13 @@ NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, si } preTensor->refCount = 0; preTensor->data.clear(); + if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) { + preTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->dstT; + toAddTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->srcT; + } graphT->allTensors.emplace_back(std::move(toAddTensor)); size_t 
toAddTensorIdx = graphT->allTensors.size() - 1; - auto toAddNode = opDefCopyer(toAddNodeIn); + auto toAddNode = opDefCopyer(toAddNodeIn.get()); if (toAddNode == nullptr) { MS_LOG(ERROR) << "copy toAddNodeIn failed"; *errorCode = RET_NULL_PTR; @@ -456,9 +460,13 @@ NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, si // MS_LOG(ERROR)("Copy TensorT failed"); return graphT->nodes.end(); } + if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) { + preTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->srcT; + toAddTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->dstT; + } graphT->allTensors.emplace_back(std::move(toAddTensor)); size_t toAddTensorIdx = graphT->allTensors.size() - 1; - auto toAddNode = opDefCopyer(toAddNodeIn); + auto toAddNode = opDefCopyer(toAddNodeIn.get()); if (toAddNode == nullptr) { // MS_LOG(ERROR)("copy toAddNodeIn failed"); *errorCode = RET_NULL_PTR; @@ -505,9 +513,13 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz *errorCode = RET_NULL_PTR; return graphT->nodes.end(); } + if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) { + postTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->srcT; + toAddTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->dstT; + } graphT->allTensors.emplace_back(std::move(toAddTensor)); size_t toAddTensorIdx = graphT->allTensors.size() - 1; - auto toAddNode = opDefCopyer(toAddNodeIn); + auto toAddNode = opDefCopyer(toAddNodeIn.get()); if (toAddNode == nullptr) { // MS_LOG(ERROR)("copy toAddNodeIn failed"); *errorCode = RET_NULL_PTR; @@ -540,9 +552,13 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz *errorCode = RET_NULL_PTR; return graphT->nodes.end(); } + if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) { + postTensor->dataType = 
toAddNodeIn->primitive->value.AsQuantDTypeCast()->srcT; + toAddTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->dstT; + } graphT->allTensors.emplace_back(std::move(toAddTensor)); size_t toAddTensorIdx = graphT->allTensors.size() - 1; - auto toAddNode = opDefCopyer(toAddNodeIn); + auto toAddNode = opDefCopyer(toAddNodeIn.get()); if (toAddNode == nullptr) { // MS_LOG(ERROR)("copy toAddNodeIn failed"); *errorCode = RET_NULL_PTR; diff --git a/mindspore/lite/tools/common/graph_util.h b/mindspore/lite/tools/common/graph_util.h index 818c53502b..f3b4b97e3a 100644 --- a/mindspore/lite/tools/common/graph_util.h +++ b/mindspore/lite/tools/common/graph_util.h @@ -36,7 +36,7 @@ enum InsertPlace { kBefore, kAfter }; using NodeIter = std::vector>::iterator; -using OpDefCopyer = std::function(std::unique_ptr &)>; +using OpDefCopyer = std::function (schema::CNodeT *)>; OpDefCopyer GetSimpleOpCopyer(); diff --git a/mindspore/lite/tools/common/tensor_util.cc b/mindspore/lite/tools/common/tensor_util.cc index e41d5efd2c..a27de70197 100644 --- a/mindspore/lite/tools/common/tensor_util.cc +++ b/mindspore/lite/tools/common/tensor_util.cc @@ -19,8 +19,29 @@ #include "tools/common/tensor_util.h" #include "tools/common/graph_util.h" -namespace mindspore { -namespace lite { +namespace mindspore::lite { +std::unique_ptr GetTensorQuantParam(const std::unique_ptr &tensor) { + MS_ASSERT(tensor != nullptr); + auto &quantParams = tensor->quantParams; + if (!quantParams.empty()) { + return std::move(CopyQuantParamT(quantParams.front())); + } else { + return nullptr; + } +} +std::unique_ptr CopyQuantParamT(const std::unique_ptr &srcQuantParam) { + MS_ASSERT(srcQuantParam != nullptr); + std::unique_ptr dstQuantParam = std::make_unique(); + dstQuantParam->inited = srcQuantParam->inited; + dstQuantParam->scale = srcQuantParam->scale; + dstQuantParam->zeroPoint = srcQuantParam->zeroPoint; + dstQuantParam->min = srcQuantParam->min; + dstQuantParam->max = srcQuantParam->max; + 
dstQuantParam->narrowRange = srcQuantParam->narrowRange; + dstQuantParam->numBits = srcQuantParam->numBits; + return std::move(dstQuantParam); +} + std::unique_ptr CopyQuantParamArrayT(const std::unique_ptr &srcQuantParamArray) { MS_ASSERT(srcQuantParamArray != nullptr); auto dstQuantParamArrayT = std::unique_ptr(new (std::nothrow) QuantParamT()); @@ -164,6 +185,9 @@ std::unique_ptr CopyTensorDefT(const std::unique_ptr &oldTenso newTensor->refCount = oldTensor->refCount; newTensor->nodeType = oldTensor->nodeType; newTensor->data = oldTensor->data; + if (!oldTensor->quantParams.empty()) { + newTensor->quantParams.emplace_back(std::move(GetTensorQuantParam(oldTensor))); + } return std::move(newTensor); } @@ -186,6 +210,4 @@ size_t GetShapeSize(const std::vector &shape) { } return shapeSize; } -} // namespace lite -} // namespace mindspore - +} // namespace mindspore::lite diff --git a/mindspore/lite/tools/common/tensor_util.h b/mindspore/lite/tools/common/tensor_util.h index 93cb3520a1..ef3b58a842 100644 --- a/mindspore/lite/tools/common/tensor_util.h +++ b/mindspore/lite/tools/common/tensor_util.h @@ -38,6 +38,9 @@ using schema::FusedBatchNormT; using schema::Format_NCHW; using schema::Format_NHWC; using STATUS = int; + +std::unique_ptr GetTensorQuantParam(const std::unique_ptr &tensor); + size_t GetElementSize(const TensorT &tensor); size_t GetElementSize(const TypeId &dataType); @@ -50,6 +53,8 @@ std::unique_ptr CopyTensorDefT(const std::unique_ptr &); size_t GetRefCount(schema::MetaGraphT *graphT, uint32_t tensorIdx); +std::unique_ptr CopyQuantParamT(const std::unique_ptr &srcQuantParam); + std::unique_ptr \ CopyQuantParamArrayT(const std::unique_ptr &srcQuantParamArray); diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index 9119138545..645b40d8a3 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -101,6 +101,7 @@ 
target_link_libraries(converter_lite PRIVATE node_mid graph_pass_mid fusion_mid + quantizer_mid protobuf quantizer_mid pthread diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc index 61d1f20f1f..07dd48268e 100644 --- a/mindspore/lite/tools/converter/converter.cc +++ b/mindspore/lite/tools/converter/converter.cc @@ -77,7 +77,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) { MS_ASSERT(nullptr != modelParser); const std::string modelFile = flag->modelFile; const std::string weightFile = flag->weightFile; - auto meta_graph = modelParser->Parse(modelFile, weightFile); + auto meta_graph = modelParser->Parse(modelFile, weightFile, flag->quantType); if (meta_graph == nullptr) { MS_LOG(ERROR) << "Parse to metaGraph return nullptr"; return nullptr; @@ -118,6 +118,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) { // transform transform->SetGraphDef(meta_graph); + transform->CreateQuantizer(flag); auto status = transform->Transform(*flag); if (status != 0) { MS_LOG(ERROR) << "FBTransform model failed " << status; @@ -125,6 +126,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) { } return meta_graph; } + void Converter::CreateQuantizer(FuncGraphPtr funcGraph, const converter::Flags *flags) { auto type = flags->quantType; switch (type) { @@ -132,17 +134,18 @@ void Converter::CreateQuantizer(FuncGraphPtr funcGraph, const converter::Flags * // mQuantizer.reset(new AwareQuantizer(graphDefT, flags->inputInferenceTypeIn, flags->stdDev, flags->mean)); break; } - case mindspore::schema::QuantType_WeightQuant: { - MS_LOG(INFO) << "create WeightQuantizer!"; - mQuantizer.reset( - new quant::WeightQuantizer(funcGraph, flags->quantSize, flags->convWeightQuantChannelThreshold, flags->bitNum)); - break; - } - case mindspore::schema::QuantType_PostTraining: { - MS_LOG(INFO) << "create PostTrainningQuantizer!"; - mQuantizer.reset(new quant::PostTrainingQuantizer(funcGraph, flags->configFile, 
8)); - break; - } + // case mindspore::schema::QuantType_WeightQuant: { + // MS_LOG(INFO) << "create WeightQuantizer!"; + // mQuantizer.reset( + // new quant::WeightQuantizer(funcGraph, flags->quantSize, flags->convWeightQuantChannelThreshold, + // flags->bitNum)); + // break; + // } + // case mindspore::schema::QuantType_PostTraining: { + // MS_LOG(INFO) << "create PostTrainningQuantizer!"; + // mQuantizer.reset(new quant::PostTrainingQuantizer(funcGraph, flags->configFile, 8)); + // break; + // } case mindspore::schema::QuantType_QUANT_NONE: MS_LOG(INFO) << "Not do quantization for model!"; break; diff --git a/mindspore/lite/tools/converter/converter_flags.cc b/mindspore/lite/tools/converter/converter_flags.cc index 07347663f9..f91db803a0 100644 --- a/mindspore/lite/tools/converter/converter_flags.cc +++ b/mindspore/lite/tools/converter/converter_flags.cc @@ -14,8 +14,12 @@ * limitations under the License. */ -#include + #include "tools/converter/converter_flags.h" +#include +#include +#include "ir/dtype/type_id.h" + namespace mindspore { namespace lite { @@ -70,9 +74,11 @@ int Flags::Init(int argc, const char **argv) { return 1; } if (this->inputInferenceTypeIn == "FLOAT") { - this->inputInferenceType = 0; + this->inputInferenceType = TypeId::kNumberTypeFloat; } else if (this->inputInferenceTypeIn == "UINT8") { - this->inputInferenceType = 1; + this->inputInferenceType = TypeId::kNumberTypeUInt8; + } else if (this->inputInferenceTypeIn == "INT8") { + this->inputInferenceType = TypeId::kNumberTypeInt8; } else { std::cerr << "INPUT INVALID: inputInferenceType is invalid: %s", this->inputInferenceTypeIn.c_str(); return 1; diff --git a/mindspore/lite/tools/converter/converter_flags.h b/mindspore/lite/tools/converter/converter_flags.h index 0594663a9b..9ccfed6ceb 100644 --- a/mindspore/lite/tools/converter/converter_flags.h +++ b/mindspore/lite/tools/converter/converter_flags.h @@ -19,6 +19,7 @@ #include #include "tools/common/flag_parser.h" +#include 
"ir/dtype/type_id.h" #include "schema/inner/model_generated.h" namespace mindspore { @@ -66,7 +67,7 @@ class Flags : public virtual mindspore::lite::FlagParser { // used for parse aware trainning std::string inputInferenceTypeIn; // mindspore::predict::DataType inputInferenceType = DataType_DT_FLOAT; - int inputInferenceType = 0; + TypeId inputInferenceType = TypeId::kNumberTypeFloat; std::string stdDev; std::string mean; // used for post-trainning-weight diff --git a/mindspore/lite/tools/converter/graphdef_transform.cc b/mindspore/lite/tools/converter/graphdef_transform.cc index d0240ed9f3..9dc6cb3e56 100644 --- a/mindspore/lite/tools/converter/graphdef_transform.cc +++ b/mindspore/lite/tools/converter/graphdef_transform.cc @@ -16,11 +16,13 @@ #include "tools/converter/graphdef_transform.h" #include +#include #include #include "schema/model_generated.h" #include "utils/log_adapter.h" #include "src/common/op_utils.h" #include "tools/converter/converter_flags.h" +#include "tools/converter/legacy_optimizer/graph/dtype_trans_pass.h" #include "tools/converter/legacy_optimizer/fusion/conv_bn_fusion_pass.h" #include "tools/converter/legacy_optimizer/fusion/conv_scale_fusion_pass.h" #include "tools/converter/legacy_optimizer/fusion/conv_relu_fusion_pass.h" @@ -28,7 +30,7 @@ #include "tools/converter/legacy_optimizer/fusion/conv_biasadd_fusion_pass.h" // #include "tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.h" #include "tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.h" -// #include "tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.h" +#include "tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.h" // #include "tools/converter/legacy_optimizer/fusion/batchnorm_fold_fusion_pass.h" // // #include "tools/converter/legacy_optimizer/const_fold/add_const_fold_pass.h" @@ -52,18 +54,45 @@ #include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h" #include 
"tools/converter/legacy_optimizer/graph/unused_node_remove_pass.h" #include "tools/converter/legacy_optimizer/graph/topological_sort_pass.h" - +#include "tools/converter/quantizer/aware_quantizer.h" #include "tools/converter/converter.h" using std::string; -namespace mindspore { -namespace lite { +namespace mindspore::lite { GraphDefTransform::GraphDefTransform() = default; GraphDefTransform::~GraphDefTransform() = default; void GraphDefTransform::SetGraphDef(schema::MetaGraphT *_dstDef) { graphDefT = _dstDef; } +void GraphDefTransform::CreateQuantizer(const converter::Flags *flags) { + auto type = flags->quantType; + switch (type) { + case QuantType::QuantType_AwareTrainning: { + MS_LOG(INFO) << "create AwareTrainningQuantizer!"; + fbQuantizer = + std::make_unique(graphDefT, flags->inputInferenceTypeIn, flags->stdDev, flags->mean); + break; + } + // case QuantType::QuantType_WeightQuant: { + // MS_LOGI("create WeightQuantizer!"); + // mQuantizer.reset(new WeightQuantizer(graphDefT, flags->quantSize)); + // break; + // } + // case QuantType_PostTraining: { + // MS_LOGI("create PostTrainningQuantizer!"); + // mQuantizer.reset(new PostTrainingQuantizer(graphDefT, flags->configFile)); + // break; + // } + // case QuantType::QuantType_QUANT_NONE: + // MS_LOGD("Not do quantization for model!"); + // break; + default: + // MS_LOGI("will support quantizer type %s in the future!", flags->quantTypeIn.c_str()); + break; + } +} + int GraphDefTransform::Transform(const converter::Flags &ctx) { STATUS status; // // constant folding @@ -133,6 +162,53 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { } } + + { + Optimizer unusedOpRemoveOptimizer; + unusedOpRemoveOptimizer.AddPass(new UnusedNodeRemovePass()); + unusedOpRemoveOptimizer.AddPass(new IsolatedNodeRemovePass()); + status = unusedOpRemoveOptimizer.Run(graphDefT); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "Run unusedOpRemoveOptimizer graphPasses Failed"; + return status; + } 
+ } + // topological sorting + { + Optimizer topologicalOptimizer; + topologicalOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); + status = topologicalOptimizer.Run(graphDefT); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed"; + return status; + } + } + + // generate and infer quant parameters + { + if (mQuantizer != nullptr) { + Optimizer topologicalOptimizer; + topologicalOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); + status = topologicalOptimizer.Run(graphDefT); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed"; + return status; + } + if (!(this->graphDefT->fmkType == converter::FmkType_TF && + this->graphDefT->nodes.front()->quantType == QuantType::QuantType_AwareTrainning)) { + status = mQuantizer->GenerateQuantParam(); + if (status != RET_OK) { + MS_LOG(ERROR) << "GenerateQuantParam failed"; + return status; + } + status = mQuantizer->DetermineNodeQuantType(); + if (status != RET_OK) { + MS_LOG(ERROR) << "DetermineNodeQuant failed"; + } + } + } + } + // format transform if (ctx.formatTrans) { Optimizer formatTransOptimizer; @@ -156,13 +232,30 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { } } - { - Optimizer unusedOpRemoveOptimizer; - unusedOpRemoveOptimizer.AddPass(new UnusedNodeRemovePass()); - unusedOpRemoveOptimizer.AddPass(new IsolatedNodeRemovePass()); - status = unusedOpRemoveOptimizer.Run(graphDefT); + // do quantization + if (fbQuantizer != nullptr) { + status = fbQuantizer->DoQuantize(); + if (status != RET_OK) { + MS_LOG(ERROR) << "DoQuantize failed!"; + return status; + } + } + + // insert quantNode and deQuantNode + if (ctx.quantType == QuantType_AwareTrainning) { + Optimizer quantNodeOptimizer; + auto dTypeTransPass = new (std::nothrow) DTypeTransPass(); + if (dTypeTransPass == nullptr) { + MS_LOG(ERROR) << "new dTypeTransPass failed"; + return RET_ERROR; + } + 
dTypeTransPass->SetInputDataDType(ctx.inputInferenceType); + quantNodeOptimizer.AddPass(dTypeTransPass); + quantNodeOptimizer.AddPass(new (std::nothrow) QuantCastFusionPass()); + quantNodeOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); + status = quantNodeOptimizer.Run(graphDefT); if (status != RET_OK && status != RET_NO_CHANGE) { - MS_LOG(ERROR) << "Run unusedOpRemoveOptimizer graphPasses Failed"; + MS_LOG(ERROR) << "Run quantNodeOptimizer graphPasses Failed"; return status; } } @@ -178,6 +271,4 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { } return RET_OK; } -} // namespace lite -} // namespace mindspore - +} // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/graphdef_transform.h b/mindspore/lite/tools/converter/graphdef_transform.h index b50579ac99..0251b7d15a 100644 --- a/mindspore/lite/tools/converter/graphdef_transform.h +++ b/mindspore/lite/tools/converter/graphdef_transform.h @@ -17,8 +17,9 @@ #ifndef MS_GRAPHDEF_TRANSFORM_H #define MS_GRAPHDEF_TRANSFORM_H +#include #include "tools/converter/optimizer.h" -// #include "quantizer/quantizer.h" +#include "tools/converter/quantizer/quantizer.h" #include "schema/inner/model_generated.h" #include "tools/common/storage.h" #include "tools/converter/converter_flags.h" @@ -42,7 +43,8 @@ class GraphDefTransform { schema::MetaGraphT *graphDefT = nullptr; Optimizer *optimizer = nullptr; - // std::unique_ptr mQuantizer; + std::unique_ptr mQuantizer; + std::unique_ptr fbQuantizer; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.h index cc8ad536d0..3f3a42df84 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.h +++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.h @@ -53,7 +53,7 @@ class 
MatMulBiasAddFusionPass : public FusionPass { bool transB = false; size_t id = 0; - OpDefCopyer TransposeOpCopyer = [](const std::unique_ptr &inOpDef) -> std::unique_ptr { + OpDefCopyer TransposeOpCopyer = [](CNodeT *inOpDef) -> std::unique_ptr { std::unique_ptr newOpDef(new (std::nothrow) CNodeT); if (newOpDef == nullptr) { MS_LOG(ERROR) << "new OpDefT failed"; diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt b/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt index e5d2ceac19..c3c03af6a1 100755 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt @@ -1,5 +1,6 @@ add_library(graph_pass_mid OBJECT ${CMAKE_CURRENT_SOURCE_DIR}/format_trans_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/dtype_trans_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/isolated_node_remove_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/model_input_format_preprocess_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/topological_sort_pass.cc diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc new file mode 100644 index 0000000000..ec9d979e32 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc @@ -0,0 +1,235 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tools/converter/legacy_optimizer/graph/dtype_trans_pass.h" +#include +#include "tools/common/converter_op_utils.h" +#include "tools/common/node_util.h" +#include "src/common/common.h" +#include "src/common/utils.h" + +namespace mindspore { +namespace lite { +#define kMinInputNum 1 +#define kOutputNum 1 + +STATUS DTypeTransPass::Run(schema::MetaGraphT *graph) { + MS_ASSERT(graph != nullptr); + + auto status = DoModelInputDTypeTrans(graph); + if (status != RET_OK) { + MS_LOG(ERROR) << "DoModelInputDTypeTrans error: " << status; + return status; + } + + status = DoModelOutputDTypeTrans(graph); + if (status != RET_OK) { + MS_LOG(ERROR) << "DoModelOutputDTypeTrans error: " << status; + return status; + } + + status = DoNodeInoutDTypeTrans(graph); + if (status != RET_OK) { + MS_LOG(ERROR) << "DoNodeInoutDTypeTrans error: " << status; + return status; + } + return RET_OK; +} + +STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) { + MS_ASSERT(graph != nullptr); + // modify inputTensor first + auto &graphInIdxes = graph->inputIndex; + for (auto graphInIdx : graphInIdxes) { + MS_ASSERT(graph->allTensors.size() > graphInIdx); + auto &graphInTensor = graph->allTensors.at(graphInIdx); + graphInTensor->dataType = TypeId::kNumberTypeUInt8; + } + + if (this->inputDataDType == TypeId::kNumberTypeInt8) { + return RET_OK; + } + if (this->inputDataDType != TypeId::kNumberTypeFloat && this->inputDataDType != TypeId::kNumberTypeUInt8) { + MS_LOG(ERROR) << "Invalid inputDataType: " << this->inputDataDType; + return RET_ERROR; + } + // insert fp2int8 node + for (auto graphInIdx : graphInIdxes) { + MS_ASSERT(graphInIdx < graph->allTensors.size()); + auto &tensor = graph->allTensors.at(graphInIdx); + if (tensor->dims.size() != kNHWCDimNumber) { + continue; + } + + for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { + auto &node = *iter; + auto nodeName = node->name; + for (size_t inputIndexIdx = 0; inputIndexIdx < 
node->inputIndex.size(); inputIndexIdx++) { + if (node->inputIndex.at(inputIndexIdx) == graphInIdx) { + STATUS status = RET_OK; + + // insert dtype cast node between input tensor and input node + if (inputDataDType == TypeId::kNumberTypeFloat) { + iter = InsertDTypeTransNode(graph, iter, kBefore, inputIndexIdx, kFP32ToInt8, &status); + } else { + iter = InsertDTypeTransNode(graph, iter, kBefore, inputIndexIdx, kUInt8ToInt8, &status); + } + + if (status != RET_OK) { + MS_LOG(ERROR) << "InsertDTypeTransNode before " << nodeName.c_str() << " failed"; + return status; + } + } + } + } + } + return RET_OK; +} + +STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) { + MS_ASSERT(graph != nullptr); + if (inputDataDType == TypeId::kNumberTypeInt8) { + return RET_OK; + } + MS_ASSERT(inputDataDType == TypeId::kNumberTypeFloat); + auto &graphOutIdxes = graph->outputIndex; + for (auto graphOutIdx : graphOutIdxes) { + for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { + auto &node = *iter; + auto nodeName = node->name; + MS_ASSERT(node != nullptr); + for (size_t outputIndexIdx = 0; outputIndexIdx < node->outputIndex.size(); outputIndexIdx++) { + if (node->outputIndex.at(outputIndexIdx) == graphOutIdx) { + // insert transNode + STATUS status = RET_OK; + if (inputDataDType == TypeId::kNumberTypeFloat) { + iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, kInt8ToFP32, &status); + } else { + iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, kInt8ToUInt8, &status); + } + if (status != RET_OK) { + MS_LOG(ERROR) << "InsertDTypeTransNode after " << nodeName.c_str() << " failed"; + return status; + } + break; + } + } + } + } + return RET_OK; +} + +STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) { + MS_ASSERT(graph != nullptr); + // insert transNode before and after existNode + for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { + if (IsContain(GetUint8OpList(), 
GetCNodeTType(**iter)) && (*iter)->quantType == QuantType_AwareTrainning) { + continue; + } + auto &node = *iter; + if (GetCNodeTType(**iter) == PrimitiveType_QuantDTypeCast) { + continue; + } + bool needInsertPost = true; + if (GetCNodeTType(**iter) == PrimitiveType_Shape) { + needInsertPost = false; + } + auto nodeName = node->name; + if (node->inputIndex.size() < kMinInputNum) { + MS_LOG(ERROR) << "Op " << nodeName.c_str() << " should have " << kMinInputNum << " input tensor at least"; + return RET_ERROR; + } + STATUS status; + // insert pre + for (size_t i = 0; i < (*iter)->inputIndex.size(); i++) { + MS_ASSERT(graph->allTensors.size() > (*iter)->inputIndex.at(i)); + auto &preTensor = graph->allTensors.at((*iter)->inputIndex.at(i)); + auto &graphInIdxes = graph->inputIndex; + if (preTensor->nodeType == NodeType_ValueNode && !IsContain(graphInIdxes, (*iter)->inputIndex.at(i))) { + continue; + } + iter = InsertDTypeTransNode(graph, iter, kBefore, i, kInt8ToFP32, &status); + if (status != RET_OK) { + MS_LOG(ERROR) << "InsertInt8ToFloat32Node before " << nodeName.c_str() << " failed"; + return RET_ERROR; + } + } + + if (needInsertPost) { + for (size_t i = 0; i < (*iter)->outputIndex.size(); i++) { + iter = InsertDTypeTransNode(graph, iter, kAfter, i, kFP32ToInt8, &status); + if (status != RET_OK) { + MS_LOG(ERROR) << "InsertFloat32ToUint8Node after " << nodeName.c_str() << " failed"; + return RET_ERROR; + } + } + } + (*iter)->quantType = QuantType_QUANT_NONE; + } + + return RET_OK; +} + +NodeIter DTypeTransPass::InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, + size_t inoutIdx, DTypeTransNodeType nodeType, STATUS *errorCode) { + MS_ASSERT((*existNodeIter) != nullptr); + auto existNodeName = (*existNodeIter)->name; + std::string tileName; + if (place == kBefore) { + tileName = existNodeName + "_pre"; + } else { + tileName = existNodeName + "_post"; + } + auto transNode = std::unique_ptr(new (std::nothrow) CNodeT); + if 
(transNode == nullptr) { + MS_LOG(ERROR) << "new TransNode failed"; + *errorCode = RET_ERROR; + return graph->nodes.end(); + } + auto quantDTypeCastParam = new (std::nothrow) QuantDTypeCastT; + if (quantDTypeCastParam == nullptr) { + MS_LOG(ERROR) << "new quantDTypeCastParam failed"; + *errorCode = RET_ERROR; + return graph->nodes.end(); + } + transNode->primitive = std::make_unique(); + transNode->primitive->value.value = quantDTypeCastParam; + transNode->primitive->value.type = PrimitiveType_QuantDTypeCast; + transNode->quantType = QuantType_AwareTrainning; + if (nodeType == kInt8ToFP32) { + quantDTypeCastParam->srcT = TypeId::kNumberTypeInt8; + quantDTypeCastParam->dstT = TypeId::kNumberTypeFloat32; + transNode->name = "int8toft32_" + tileName + std::to_string(id++); + } else if (nodeType == kFP32ToInt8) { + quantDTypeCastParam->srcT = TypeId::kNumberTypeFloat32; + quantDTypeCastParam->dstT = TypeId::kNumberTypeInt8; + transNode->name = "ft32toint8_" + tileName + std::to_string(id++); + } else if (nodeType == kUInt8ToInt8) { + quantDTypeCastParam->srcT = TypeId::kNumberTypeUInt8; + quantDTypeCastParam->dstT = TypeId::kNumberTypeInt8; + transNode->name = "uint8toint8_" + tileName + std::to_string(id++); + } else if (nodeType == kInt8ToUInt8) { + quantDTypeCastParam->srcT = TypeId::kNumberTypeInt8; + quantDTypeCastParam->dstT = TypeId::kNumberTypeUInt8; + transNode->name = "int8touint8_" + tileName + std::to_string(id++); + } + transNode->primitive->value.value = quantDTypeCastParam; + return InsertNode(graph, existNodeIter, place, inoutIdx, std::move(transNode), errorCode, castOpCopyer); +} + +void DTypeTransPass::SetInputDataDType(TypeId dataType) { this->inputDataDType = dataType; } +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h new file mode 100644 index 0000000000..1c1c0a7284 --- /dev/null +++ 
b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h @@ -0,0 +1,81 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_DTYPE_TRANS_PASS_H +#define MINDSPORE_PREDICT_DTYPE_TRANS_PASS_H + +#include +#include +#include "tools/converter/optimizer.h" +#include "tools/common/graph_util.h" +#include "tools/converter/converter_flags.h" +#include "tools/common/tensor_util.h" + +namespace mindspore { +namespace lite { +enum DTypeTransNodeType { kInt8ToFP32, kFP32ToInt8, kUInt8ToInt8, kInt8ToUInt8 }; + +class DTypeTransPass : public GraphPass { + public: + DTypeTransPass() : id(0) {} + + ~DTypeTransPass() override = default; + + STATUS Run(schema::MetaGraphT *graph) override; + + void SetInputDataDType(TypeId dataType); + + private: + STATUS DoModelInputDTypeTrans(schema::MetaGraphT *graph); + + STATUS DoModelOutputDTypeTrans(schema::MetaGraphT *graph); + + STATUS DoNodeInoutDTypeTrans(schema::MetaGraphT *graph); + + NodeIter InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, size_t inoutIdx, + DTypeTransNodeType nodeType, STATUS *errorCode); + + private: + size_t id; + TypeId inputDataDType = TypeId::kNumberTypeFloat; + + OpDefCopyer castOpCopyer = [](schema::CNodeT *inCNode) -> std::unique_ptr { + std::unique_ptr newCNode(new (std::nothrow) schema::CNodeT); + if (newCNode == nullptr) { + MS_LOG(ERROR) << "new CNodeT 
failed"; + return nullptr; + } + newCNode->name = inCNode->name; + newCNode->quantType = inCNode->quantType; + newCNode->primitive = std::make_unique(); + newCNode->primitive->value.type = inCNode->primitive->value.type; + + auto oldQuantDTypeCastParam = inCNode->primitive->value.AsQuantDTypeCast(); + auto QuantDTypeCastParam = new (std::nothrow) QuantDTypeCastT; + if (QuantDTypeCastParam == nullptr) { + MS_LOG(ERROR) << "new QuantDTypeCast failed"; + return nullptr; + } + QuantDTypeCastParam->srcT = oldQuantDTypeCastParam->srcT; + QuantDTypeCastParam->dstT = oldQuantDTypeCastParam->dstT; + newCNode->primitive->value.value = QuantDTypeCastParam; + return std::move(newCNode); + }; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_DTYPE_TRANS_PASS_H diff --git a/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc index 06f6d39ca0..b0eb1d9ada 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc @@ -209,6 +209,9 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) { return 0; } +// inference needed filterFormat: +// conv deconv depth dedepth +// uint8 KHWC KHWC KHWC KHWC int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) { MS_ASSERT(graphNode != nullptr); auto &subGraph = graphNode->subGraph; @@ -227,7 +230,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) { auto &weightTensor = subGraph->allTensors[weightIndex]; MS_ASSERT(weightTensor->dataType == kNumberTypeInt8); // DataType_DT_FLOAT STATUS status = RET_OK; - if (opType == schema::PrimitiveType_Conv2D) { // weight should be HWCK + if (opType == schema::PrimitiveType_Conv2D) { // weight should be KHWC if (weightTensor->format == schema::Format_KCHW) { // from caffe if (weightTensor->dataType == kNumberTypeInt8) { // 
DataType_DT_UINT8) { MS_LOG(DEBUG) << "**weight tensor index: %d, format: %d, datatype: " << weightIndex << weightTensor->format @@ -236,58 +239,51 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) { } else { MS_LOG(DEBUG) << "--weight tensor index: %d, format: %d, datatype: " << weightIndex << weightTensor->format << weightTensor->dataType; - status = TransFilterFormat(weightTensor.get(), kKCHW2HWCK); + status = TransFilterFormat(weightTensor.get(), kKCHW2KHWC); } - } else if (weightTensor->format == schema::Format_KHWC) { // from onnx - return RET_OK; - // if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { - // status = TransFilterFormat(weightTensor.get(), kKHWC2HWCK); - // } else { - // status = TransFilterFormat(weightTensor.get(), kKHWC2HWCK); - // } - } else if (weightTensor->format == schema::Format_HWCK) { // from tf - return 0; - } else { + } else if (weightTensor->format != schema::Format_KHWC) { MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format; return -1; } if (status == 0) { node->primitive->value.AsConv2D()->format = schema::Format_NHWC; - weightTensor->format = schema::Format_HWCK; + weightTensor->format = schema::Format_KHWC; } else { - MS_LOG(WARNING) << "TransFilter %sToHWCK failed, node : " - << (weightTensor->format == schema::Format_KCHW ? "KCHW" : "KHWC"), - node->name.c_str(); + MS_LOG(WARNING) << "TransFilter %sToKHWC failed, node : " + << (weightTensor->format == schema::Format_KHWC ? 
"KHWC" : "KCHW") << node->name.c_str(); // todo(00445839): consider varible weight condition } - } else if (opType == schema::PrimitiveType_DepthwiseConv2D) { // weight should be HWCK + } else if (opType == schema::PrimitiveType_DepthwiseConv2D) { // weight should be KHWC if (weightTensor->format == schema::Format_CKHW) { // from caffe if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { - MS_LOG(DEBUG) << "**weight tensor index: %d, format: %d, datatype: " << weightIndex, weightTensor->format, - weightTensor->dataType; - status = TransFilterFormat(weightTensor.get(), kCKHW2HWCK); + MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format + << "datatype: " << weightTensor->dataType; + status = TransFilterFormat(weightTensor.get(), kCKHW2KHWC); + } else if (weightTensor->dataType == kNumberTypeUInt8) { + MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format + << "datatype: " << weightTensor->dataType; + status = TransFilterFormat(weightTensor.get(), kCKHW2KHWC); } else { - MS_LOG(DEBUG) << "--weight tensor index: %d, format: %d, datatype: " << weightIndex, weightTensor->format, - weightTensor->dataType; - status = TransFilterFormat(weightTensor.get(), kCKHW2HWCK); + MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format + << "datatype: " << weightTensor->dataType; + status = TransFilterFormat(weightTensor.get(), kCKHW2KHWC); } - } else if (weightTensor->format == schema::Format_HWCK) { // from tf - return 0; } else if (weightTensor->format == schema::Format_CHWK) { // from onnx - if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { + if (weightTensor->dataType == kNumberTypeInt8) { + MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format + << "datatype: " << weightTensor->dataType; status = TransFilterFormat(weightTensor.get(), kCHWK2KHWC); - MS_LOG(DEBUG) << 
node->name << " weight trans format: CHWK->KHWC"; + } else if (weightTensor->dataType == kNumberTypeUInt8) { + MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format + << "datatype: " << weightTensor->dataType; + status = TransFilterFormat(weightTensor.get(), kCHWK2KHWC); } else { + MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format + << "datatype: " << weightTensor->dataType; status = TransFilterFormat(weightTensor.get(), kCHWK2KHWC); } - } else if (weightTensor->format == schema::Format_KCHW) { - if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { - status = TransFilterFormat(weightTensor.get(), kKCHW2HWCK); - } else { - status = TransFilterFormat(weightTensor.get(), kKCHW2HWCK); - } - } else { + } else if (weightTensor->format != schema::Format_KHWC) { MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format; return -1; } @@ -295,14 +291,13 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) { node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NHWC; weightTensor->format = schema::Format_KHWC; } else { - MS_LOG(WARNING) << "TransFilter %ToHWCK failed, node : " - << (weightTensor->format == schema::Format_CHWK ? "CHWK" : "CKHW"), - node->name.c_str(); + MS_LOG(WARNING) << "TransFilter" << (weightTensor->format == schema::Format_KHWC ? 
"KHWC" : "CKHW") + << "To KHWC failed, node : " << node->name.c_str(); // todo(00445839): consider varible weight condition } - } else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be HWCK - node->primitive->value.AsDeConv2D()->format = schema::Format_NCHW; - weightTensor->format = schema::Format_CKHW; + } else { // weight should be HWCK + node->primitive->value.AsDeConv2D()->format = schema::Format_NHWC; + weightTensor->format = schema::Format_KHWC; } return 0; } @@ -354,7 +349,7 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) { if (graphNode->subGraph->fmkType == converter::FmkType_MS) { weightTensor->format = schema::Format_CKHW; } - if (weightTensor->format == schema::Format_CKHW) { // from caffe or onnx or ms + if (weightTensor->format == schema::Format_CKHW) { // from caffe or onnx or ms status = TransFilterFormat(weightTensor.get(), kCKHW2KHWC); } else if (weightTensor->format == schema::Format_KCHW) { status = TransFilterFormat(weightTensor.get(), kKCHW2KHWC); @@ -374,8 +369,8 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) { } else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be KHWC if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx or ms status = TransFilterFormat(weightTensor.get(), kKCHW2KHWC); - } else if (weightTensor->format == schema::Format_CHWK) { // from tf - status = TransFilterFormat(weightTensor.get(), kCHWK2KHWC); + } else if (weightTensor->format == schema::Format_KHWC) { // from tf + status = RET_OK; } else { MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format; return -1; diff --git a/mindspore/lite/tools/converter/model_parser.h b/mindspore/lite/tools/converter/model_parser.h index f9014fbc4c..02ebf0f1d8 100644 --- a/mindspore/lite/tools/converter/model_parser.h +++ b/mindspore/lite/tools/converter/model_parser.h @@ -40,7 +40,8 @@ class ModelParser { } return Fb2Anf(Parse(modelFile, weightFile)); } - virtual 
schema::MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile) = 0; + virtual schema::MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile, + const QuantType &quantType = QuantType_QUANT_NONE) = 0; public: static FuncGraphPtr Fb2Anf(schema::MetaGraphT *meta_graph) { diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc index 7378bb1f99..d9f9f04773 100644 --- a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc @@ -31,7 +31,8 @@ CaffeModelParser::~CaffeModelParser() {} const std::set CaffeModelParser::skipedLayerType = {"Dropout"}; -schema::MetaGraphT *CaffeModelParser::Parse(const std::string &modelFile, const std::string &weightFile) { +schema::MetaGraphT *CaffeModelParser::Parse(const std::string &modelFile, const std::string &weightFile, + const QuantType &quantType) { std::unique_ptr graph(new schema::MetaGraphT()); if (ValidateFileStr(modelFile, ".prototxt") != RET_OK) { @@ -91,7 +92,7 @@ schema::MetaGraphT *CaffeModelParser::Parse(const std::string &modelFile, const // ConvertCaffeBatchNorm(graph.get()); return graph.release(); - // return Fb2Anf(graph.release()); + // return Fb2Anf(graph.release()); } STATUS CaffeModelParser::SetOpInputIdx(const caffe::LayerParameter &layer, schema::CNodeT *op, diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h index 52297d3018..5f24a80600 100644 --- a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h @@ -33,7 +33,8 @@ class CaffeModelParser : public ModelParser { virtual ~CaffeModelParser(); - MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile) override; + MetaGraphT *Parse(const std::string 
&modelFile, const std::string &weightFile, + const QuantType &quantType = QuantType_QUANT_NONE) override; private: void ConvertCaffeBatchNorm(MetaGraphT *meta_graphT); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h index f179082a70..e6527246a5 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h @@ -37,7 +37,8 @@ class OnnxModelParser : public ModelParser { public: OnnxModelParser(); virtual ~OnnxModelParser(); - MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile) override; + MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile, + const QuantType &quantType = QuantType_QUANT_NONE) override; private: TypeId GetDateTypeFromOnnx(onnx::TensorProto_DataType onnx_type); diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc index 253a94c1e3..8c7163496f 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc @@ -20,7 +20,6 @@ #include "tools/common/graph_util.h" #include "tools/common/storage.h" #include "flatbuffers/flatbuffers.h" -#include "utils/log_adapter.h" #include "src/common/file_utils.h" namespace mindspore { @@ -60,42 +59,64 @@ STATUS TfliteModelParser::SetAllTensors(const TensorCache &tensor_cache, schema: } return RET_OK; } +void TfliteModelParser::SetMsTensorFromTflite(const std::unique_ptr &tflite_tensor, + schema::TensorT *tensor) { + std::unique_ptr quant_param(new QuantParamT()); + if (!tflite_tensor->quantization->scale.empty()) { + quant_param->scale = tflite_tensor->quantization->scale[0]; + } -STATUS TfliteModelParser::ParseTfliteQuantParams(const std::unique_ptr &tflite_subgraph, - const std::unique_ptr &tflite_op) { - auto 
dst_op = tfliteOpMap.at(tflite_op.get()); + if (!tflite_tensor->quantization->zero_point.empty()) { + quant_param->zeroPoint = tflite_tensor->quantization->zero_point[0]; + } - std::vector quant_params_index; - quant_params_index.insert(quant_params_index.end(), tflite_op->inputs.begin(), tflite_op->inputs.end()); - quant_params_index.insert(quant_params_index.end(), tflite_op->outputs.begin(), tflite_op->outputs.end()); - for (const auto &index : quant_params_index) { - const auto &tflite_tensor = tflite_subgraph->tensors[index]; - if (tflite_tensor == nullptr) { - MS_LOG(ERROR) << "tensor with id = " << index <<" is null"; - return RET_ERROR; - } + // change quant param min to 0 to fit ms-lite ops + if (tensor->dataType == TypeId::kNumberTypeInt8) { + quant_param->zeroPoint = quant_param->zeroPoint - 128; + } + + if (!tflite_tensor->quantization->min.empty()) { + quant_param->min = tflite_tensor->quantization->min[0]; + } + + if (!tflite_tensor->quantization->max.empty()) { + quant_param->max = tflite_tensor->quantization->max[0]; + } + quant_param->inited = true; + tensor->quantParams.clear(); + tensor->quantParams.emplace_back(std::move(quant_param)); +} + +STATUS TfliteModelParser::ParseTfliteQuantParams(const std::unique_ptr &tflite_subgraph, + const std::unique_ptr &tflite_op, + schema::CNodeT *op, TensorCache *tensor_cache) { + MS_ASSERT(op->outputIndex.size() == tflite_op->outputs.size()); + for (size_t i = 0; i < tflite_op->inputs.size() && i < op->inputIndex.size(); i++) { + const auto &tflite_tensor = tflite_subgraph->tensors[tflite_op->inputs.at(i)]; if (tflite_tensor->quantization->scale.empty() && tflite_tensor->quantization->zero_point.empty() && tflite_tensor->quantization->min.empty() && tflite_tensor->quantization->max.empty()) { continue; } - std::unique_ptr quant_param(new schema::QuantParamT()); - if (!tflite_tensor->quantization->scale.empty()) { - quant_param->scale = tflite_tensor->quantization->scale[0]; - } - - if 
(!tflite_tensor->quantization->zero_point.empty()) { - quant_param->zeroPoint = tflite_tensor->quantization->zero_point[0]; + auto &inTensor = tensor_cache->GetCachedTensor().at(op->inputIndex.at(i)); + if (inTensor == nullptr) { + MS_LOG(ERROR) << "Parse tflite quant params inTensor is null"; + return RET_NULL_PTR; } - - if (!tflite_tensor->quantization->min.empty()) { - quant_param->min = tflite_tensor->quantization->min[0]; + SetMsTensorFromTflite(tflite_tensor, inTensor); + } + for (size_t i = 0; i < tflite_op->outputs.size() && i < op->outputIndex.size(); i++) { + const auto &tflite_tensor = tflite_subgraph->tensors[tflite_op->outputs.at(i)]; + if (tflite_tensor->quantization->scale.empty() && tflite_tensor->quantization->zero_point.empty() && + tflite_tensor->quantization->min.empty() && tflite_tensor->quantization->max.empty()) { + continue; } - - if (!tflite_tensor->quantization->max.empty()) { - quant_param->max = tflite_tensor->quantization->max[0]; + auto &outTensor = tensor_cache->GetCachedTensor().at(op->outputIndex.at(i)); + if (outTensor == nullptr) { + MS_LOG(ERROR) << "Parse tflite quant params outTensor is null"; + return RET_NULL_PTR; } + SetMsTensorFromTflite(tflite_tensor, outTensor); } - dst_op->quantType = schema::QuantType_AwareTrainning; return RET_OK; } @@ -105,11 +126,15 @@ STATUS TfliteModelParser::SetOpOutputIdx(const std::unique_ptroutputs) { const auto &tflite_tensor = tflite_subgraph->tensors[index]; if (tflite_tensor == nullptr) { - MS_LOG(ERROR) << "tensor with id = " << index <<" is null"; + MS_LOG(ERROR) << "tensor with id = " << index << " is null"; return RET_ERROR; } std::unique_ptr tensor(new schema::TensorT()); tensor->dataType = GetTfliteDataType(tflite_tensor->type); + // change dataType to int8 to fit ms-lite op + if (tensor->dataType == TypeId::kNumberTypeUInt8) { + tensor->dataType = TypeId::kNumberTypeInt8; + } tensor->dims = tflite_tensor->shape; tensor->nodeType = schema::NodeType_Parameter; auto opOutputIndex = 
tensorCache->AddTensor(tflite_tensor->name, tensor.release(), OP_OUTPUT); @@ -120,7 +145,8 @@ STATUS TfliteModelParser::SetOpOutputIdx(const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, - const std::unique_ptr &tflite_op, TensorCache *tensorCache) { + const std::unique_ptr &tflite_op, schema::CNodeT *op, + TensorCache *tensor_cache) { auto op_type = GetTfliteNodeType(tflite_op, tflite_model); std::vector op_inputs(tflite_op->inputs); if (op_type == "DeConv2D") { @@ -130,12 +156,11 @@ STATUS TfliteModelParser::SetOpInputIdx(const std::unique_ptr &t for (const auto &tflite_index : op_inputs) { const auto &tflite_tensor = tflite_subgraph->tensors[tflite_index]; if (tflite_tensor == nullptr) { - MS_LOG(ERROR) << "tensor with id = " << tflite_index <<" is null"; + MS_LOG(ERROR) << "tensor with id = " << tflite_index << " is null"; return RET_ERROR; } auto tensor_name = tflite_tensor->name; - auto op = tfliteOpMap[tflite_op.get()]; - unsigned int index = tensorCache->FindTensor(tensor_name); + unsigned int index = tensor_cache->FindTensor(tensor_name); if (index != -1) { op->inputIndex.push_back(index); } @@ -146,19 +171,20 @@ STATUS TfliteModelParser::SetOpInputIdx(const std::unique_ptr &t STATUS TfliteModelParser::ParseOp(const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, - schema::MetaGraphT *subGraph, - mindspore::lite::TensorCache *tensorCache) { + schema::MetaGraphT *subGraph, mindspore::lite::TensorCache *tensorCache, + const QuantType &quantType) { auto i = 0; for (const auto &tflite_op : tflite_subgraph->operators) { auto opType = GetTfliteNodeType(tflite_op, tflite_model); std::unique_ptr op(new schema::CNodeT); op->name = opType + "-" + std::to_string(i++); + op->quantType = quantType; MS_LOG(INFO) << "parse op: " << op->name.c_str(); auto node_parser = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(opType); if (node_parser == nullptr) { - MS_LOG(ERROR) << "cannot find node parser, opType: "<< 
opType.c_str(); + MS_LOG(ERROR) << "cannot find node parser, opType: " << opType.c_str(); continue; // return RET_NULL_PTR; } @@ -172,7 +198,19 @@ STATUS TfliteModelParser::ParseOp(const std::unique_ptr &tflite_ status = SetOpOutputIdx(tflite_subgraph, tflite_op, op.get(), tensorCache); if (status != RET_OK) { - MS_LOG(ERROR) << "Set Op "<< op->name.c_str() << " Output Index Failed!"; + MS_LOG(ERROR) << "set op " << opType.c_str() << " output index failed"; + return RET_ERROR; + } + + status = SetOpInputIdx(tflite_model, tflite_subgraph, tflite_op, op.get(), tensorCache); + if (status != RET_OK) { + MS_LOG(ERROR) << "set op " << opType.c_str() << " input index failed"; + return RET_ERROR; + } + + status = ParseTfliteQuantParams(tflite_subgraph, tflite_op, op.get(), tensorCache); + if (status != RET_OK) { + MS_LOG(ERROR) << "parse op " << opType.c_str() << " quant parameters failed"; return RET_ERROR; } @@ -189,8 +227,10 @@ void TfliteModelParser::SetInputTensor(const std::unique_ptr const auto &tflite_tensor = tflite_subgraph->tensors[index]; std::unique_ptr tensor(new schema::TensorT()); tensor->format = schema::Format_NHWC; - tensor->dataType = GetTfliteDataType(tflite_tensor->type); - tensor->nodeType = schema::NodeType_ValueNode; + tensor->dataType = GetTfliteDataType(tflite_tensor->type) != TypeId::kNumberTypeUInt8 + ? 
GetTfliteDataType(tflite_tensor->type) + : TypeId::kNumberTypeInt8; + tensor->nodeType = schema::NodeType_Parameter; tensor->dims = tflite_tensor->shape; tensor_cache->AddTensor(tflite_tensor->name, tensor.release(), GRAPH_INPUT); } @@ -212,7 +252,8 @@ void TfliteModelParser::SetGraphTensorIndex(const mindspore::lite::TensorCache & } } -MetaGraphT *TfliteModelParser::Parse(const std::string &modelFile, const std::string &weightFile) { +MetaGraphT *TfliteModelParser::Parse(const std::string &modelFile, const std::string &weightFile, + const QuantType &quantType) { if (ValidateFileStr(modelFile, ".tflite") != RET_OK) { MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.tflite"; return nullptr; @@ -224,7 +265,6 @@ MetaGraphT *TfliteModelParser::Parse(const std::string &modelFile, const std::st MS_LOG(ERROR) << "read tflite model failed"; return nullptr; } - if (tflite_model->subgraphs.size() != 1) { MS_LOG(ERROR) << "read tflite model subgraphs failed"; return nullptr; @@ -238,30 +278,15 @@ MetaGraphT *TfliteModelParser::Parse(const std::string &modelFile, const std::st // set dst subGraph op attr and tensor_cache. 
std::unique_ptr subGraph(new schema::MetaGraphT); subGraph->name = "MS_model converted by TF-Lite"; - auto status = ParseOp(tflite_model, tflite_subgraph, subGraph.get(), &tensorCache); + auto status = ParseOp(tflite_model, tflite_subgraph, subGraph.get(), &tensorCache, quantType); if (status != RET_OK) { MS_LOG(ERROR) << "ParseOp failed."; return nullptr; } - for (const auto &tflite_op : tflite_subgraph->operators) { - auto status_tmp = SetOpInputIdx(tflite_model, tflite_subgraph, tflite_op, &tensorCache); - if (status_tmp != RET_OK) { - MS_LOG(ERROR) << "Set Op " << tfliteOpMap.at(tflite_op.get())->name.c_str() << " Input Index Failed!"; - } - } - - for (const auto &tflite_op : tflite_subgraph->operators) { - auto statusTmp = ParseTfliteQuantParams(tflite_subgraph, tflite_op); - if (statusTmp != RET_OK) { - MS_LOG(ERROR) << "ParseTfliteQuantParams " << tfliteOpMap.at(tflite_op.get())->name.c_str() << " Failed!"; - } - } - SetGraphTensorIndex(tensorCache, subGraph.get()); SetAllTensors(tensorCache, subGraph.get()); return subGraph.release(); } } // namespace lite } // namespace mindspore - diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h index 2b1a8d046b..2379bd2632 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h @@ -40,22 +40,25 @@ class TfliteModelParser : public ModelParser { virtual ~TfliteModelParser(); - MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile); + MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile, + const QuantType &quantType = QuantType_QUANT_NONE) override; private: std::unique_ptr ReadTfliteModelFromFlat(const char *buf); + void SetMsTensorFromTflite(const std::unique_ptr &tflite_tensor, schema::TensorT *tensor); + void SetInputTensor(const std::unique_ptr &tflite_subgraph, TensorCache 
*tensor_cache); - void SetGraphTensorIndex(const mindspore::lite::TensorCache &tensorCache, - schema::MetaGraphT *subGraphDef); + void SetGraphTensorIndex(const mindspore::lite::TensorCache &tensorCache, schema::MetaGraphT *subGraphDef); STATUS ParseOp(const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::MetaGraphT *sub_graph, - TensorCache *tensor_cache); + TensorCache *tensor_cache, const QuantType &quantType); STATUS ParseTfliteQuantParams(const std::unique_ptr &tflite_subgraph, - const std::unique_ptr &tflite_op); + const std::unique_ptr &tflite_op, schema::CNodeT *op, + TensorCache *tensor_cache); std::string GetTfliteNodeType(const std::unique_ptr &tflite_op, const std::unique_ptr &tflite_model); @@ -63,13 +66,13 @@ class TfliteModelParser : public ModelParser { STATUS SetAllTensors(const TensorCache &tensor_cache, schema::MetaGraphT *sub_graph); STATUS SetOpOutputIdx(const std::unique_ptr &tflite_subgraph, - const std::unique_ptr &tflite_op, - schema::CNodeT *op, + const std::unique_ptr &tflite_op, schema::CNodeT *op, TensorCache *tensorCache); STATUS SetOpInputIdx(const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, - const std::unique_ptr &tflite_op, TensorCache *tensorCache); + const std::unique_ptr &tflite_op, schema::CNodeT *op, + TensorCache *tensor_cache); std::map opMap; std::map tfliteOpMap; diff --git a/mindspore/lite/tools/converter/quantizer/CMakeLists.txt b/mindspore/lite/tools/converter/quantizer/CMakeLists.txt index 1777de07ed..335ebcbfec 100644 --- a/mindspore/lite/tools/converter/quantizer/CMakeLists.txt +++ b/mindspore/lite/tools/converter/quantizer/CMakeLists.txt @@ -4,7 +4,9 @@ include_directories(${3RD_DIR}/flatbuffers/include) include_directories(${3RD_DIR}/opencv/build/include/opencv4) add_library(quantizer_mid OBJECT + ${CMAKE_CURRENT_SOURCE_DIR}/calc_quant_param.cc ${CMAKE_CURRENT_SOURCE_DIR}/quantizer.cc + ${CMAKE_CURRENT_SOURCE_DIR}/aware_quantizer.cc 
${CMAKE_CURRENT_SOURCE_DIR}/weight_quantizer.cc ${CMAKE_CURRENT_SOURCE_DIR}/quantize_util.cc ${CMAKE_CURRENT_SOURCE_DIR}/general_bitpacking.cc diff --git a/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc b/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc new file mode 100644 index 0000000000..4dead32be9 --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc @@ -0,0 +1,594 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tools/converter/quantizer/aware_quantizer.h" +#include +#include +#include +#include +#include +#include "schema/inner/model_generated.h" +#include "utils/log_adapter.h" +#include "securec/include/securec.h" +#include "tools/converter/quantizer/quantize_util.h" +#include "src/common/utils.h" +#include "tools/converter/quantizer/calc_quant_param.h" +#include "tools/common/tensor_util.h" +#include "tools/common/converter_op_utils.h" +#include "tools/common/node_util.h" + +using std::string; +using std::vector; + +namespace mindspore::lite::quant { +struct InputArray { + std::unique_ptr quantParam; + float mMin = 0.0f; + float mMax = 0.0f; + bool narrowRange = false; + int numBits = 8; + TypeId dataType = TypeId::kTypeUnknown; + + InputArray(float mean, float stdDev, TypeId dataType = TypeId::kNumberTypeFloat) { + this->dataType = dataType; + constexpr float qmin = 0; + constexpr float qmax = 255; + mMin = (qmin - mean) / stdDev; + mMax = (qmax - mean) / stdDev; + } + + STATUS InitQuantParam() { + this->quantParam = std::make_unique(); + auto status = CalQuantizationParams(quantParam.get(), mMin, mMax, narrowRange, numBits); + if (status != RET_OK) { + return status; + } + return RET_OK; + } + + STATUS SetInputArrayQP(schema::MetaGraphT *graph, size_t inputTensorIdx) { + MS_ASSERT(graph != nullptr); + auto &tensor = graph->allTensors.at(inputTensorIdx); + MS_ASSERT(tensor != nullptr); + if (!tensor->quantParams.empty()) { + auto param = GetTensorQuantParam(tensor); + if (param != nullptr && param->inited) { + MS_LOG(DEBUG) << "tensor " << inputTensorIdx << " already has quantParam"; + return RET_OK; + } + tensor->quantParams.clear(); + } + std::unique_ptr tmpQuantParam(new QuantParamT()); + tmpQuantParam->inited = this->quantParam->inited; + tmpQuantParam->scale = this->quantParam->scale; + tmpQuantParam->zeroPoint = this->quantParam->zeroPoint; + tmpQuantParam->min = this->quantParam->min; + tmpQuantParam->max = this->quantParam->max; + 
tensor->quantParams.push_back(std::move(tmpQuantParam)); + return RET_OK; + } +}; + +const std::array AwareQuantizer::propagatedOps = { + {schema::PrimitiveType_Concat, schema::PrimitiveType_Resize, schema::PrimitiveType_Reshape, + schema::PrimitiveType_Squeeze, schema::PrimitiveType_RealDiv, schema::PrimitiveType_Activation, + schema::PrimitiveType_DetectionPostProcess}}; + +AwareQuantizer::AwareQuantizer(schema::MetaGraphT *graph, const string &inputInferType, const string &stdValues, + const string &meanValues) + : FbQuantizer(graph) { + MS_ASSERT(graph != nullptr); + string::size_type sz; + const float stdValue = std::stof(stdValues, &sz); + sz = 0; + const float mean = std::stof(meanValues, &sz); + if (inputInferType == "FLOAT") { + mInputArray = new InputArray(mean, stdValue); + } else { + mInputArray = new InputArray(mean, stdValue, TypeId::kNumberTypeUInt8); + } + mInputArray->InitQuantParam(); +} + +STATUS AwareQuantizer::RemoveFakeQuant() { + // for (auto &subGraph : graphDefT->subgraphs) { + // auto status = GenerateDefaultQuantParam(subGraph.get()); + // if (status != RET_OK) { + // MS_LOGE("GenerateDefaultQuantParam failed: %d", status); + // return RET_ERROR; + // } + // for (auto iter = subGraph->nodes.begin(); iter != subGraph->nodes.end(); iter++) { + // auto *node = (*iter).get(); + // if (GetCNodeTType(*node) != OpT_FakeQuantWithMinMaxVars && GetCNodeTType(*node) != OpT_FakeQuantWithMinMax) { + // continue; + // } + // auto inputIndexes = node->inputIndex; + // if (inputIndexes.size() != 3) { + // MS_LOGE("invalid fakequant node's input tensors count!"); + // return RET_ERROR; + // } + // bool narrorRange; + // int numBits; + // if (GetCNodeTType(*node) == OpT_FakeQuantWithMinMaxVars) { + // narrorRange = node->attr.AsFakeQuantWithMinMaxVars()->narrowRange; + // numBits = node->attr.AsFakeQuantWithMinMaxVars()->numBits; + // } + // if (GetCNodeTType(*node) == OpT_FakeQuantWithMinMax) { + // narrorRange = false; + // numBits = 8; + // } + // + // 
TensorDefT *tensor0 = subGraph->allTensors.at(inputIndexes[0]).get(); + // TensorDefT *tensor1 = subGraph->allTensors.at(inputIndexes[1]).get(); + // TensorDefT *tensor2 = subGraph->allTensors.at(inputIndexes[2]).get(); + // MS_ASSERT(tensor0 != nullptr); + // MS_ASSERT(tensor1 != nullptr); + // MS_ASSERT(tensor2 != nullptr); + // // calculate quant param + // MS_ASSERT(tensor1->dataType == DataType_DT_FLOAT); + // MS_ASSERT(tensor2->dataType == DataType_DT_FLOAT); + // auto *minData = reinterpret_cast(tensor1->data.data()); + // auto *maxData = reinterpret_cast(tensor2->data.data()); + // MS_ASSERT(minData != nullptr); + // MS_ASSERT(maxData != nullptr); + // std::unique_ptr quantParam(new (std::nothrow) QuantParamT()); + // if (quantParam == nullptr) { + // MS_LOGE("new quantParam failed"); + // return RET_ERROR; + // } + // auto realMin = (double)minData[0]; + // auto realMax = (double)maxData[0]; + // status = CalQuantizationParams(quantParam.get(), realMin, realMax, narrorRange, numBits); + // if (status != RET_OK) { + // MS_LOGE("in aware quantization run CalQuantizationParams failed, node: %s", node->name.c_str()); + // return RET_ERROR; + // } + // if (tensor0->refCount == MSCONST_WEIGHT_REFCOUNT) { + // CalFakeNode(tensor0, quantParam.get()); + // } + // std::unique_ptr quantParamArray(new (std::nothrow) QuantParamArrayT()); + // if (quantParamArray == nullptr) { + // MS_LOGE("new quantParamArray failed"); + // return RET_ERROR; + // } + // quantParamArray->param.push_back(std::move(quantParam)); + // auto quantParamArrayCopy = CopyQuantParamArrayT(quantParamArray); + // if (quantParamArrayCopy == nullptr) { + // MS_LOGE("CopyQuantParamArray %s return nullptr", iter->get()->name.c_str()); + // return RET_ERROR; + // } + // node->quantParam.emplace_back(std::move(quantParamArrayCopy)); + // node->quantParam.emplace_back(nullptr); // secondInTensor and thirdInTensor are weightTensors who have no + // preNode node->quantParam.emplace_back(nullptr); 
node->quantParam.emplace_back(std::move(quantParamArray)); + // + // // BroadCast fakeQuantNode QuantParam + // status = BroadCastQuantParam(subGraph, *iter); + // if (status != RET_OK) { + // MS_LOGE("BroadCastQuantParam %s failed: %d", iter->get()->name.c_str(), status); + // return status; + // } + // // save post node index for SetAttrToConvolution + // auto postNodeIdxes = GetOutputNodeIdx(*subGraph, *node); + // // remove fakequantwithminmax node + // status = IsolateNode(subGraph.get(), node); + // if (status != RET_OK) { + // MS_LOGE("in aware quant IsolateNode failed!"); + // return RET_ERROR; + // } + // // set filter param to node + // if (tensor0->refCount == MSCONST_WEIGHT_REFCOUNT && !postNodeIdxes.empty()) { + // auto postNode = subGraph->nodes.at(postNodeIdxes.front()).get(); + // if (GetCNodeTType(*postNode) == OpT_Conv2D || GetCNodeTType(*postNode) == OpT_DepthwiseConv2D || + // GetCNodeTType(*postNode) == OpT_DeConv2D || GetCNodeTType(*postNode) == OpT_DeDepthwiseConv2D) { + // auto status = SetAttrToConvolution(subGraph.get(), postNode); + // if (status != RET_OK) { + // MS_LOGE("in aware quant SetAttrToConvolution failed!"); + // return RET_ERROR; + // } + // } + // } + // } + // + // // remove IsolatedNode + // for (auto iter = subGraph->nodes.begin(); iter != subGraph->nodes.end();) { + // if ((*iter)->inputIndex.empty() && (*iter)->outputIndex.empty()) { + // iter = subGraph->nodes.erase(iter); + // } else { + // iter++; + // } + // } + // // set graphInputNode inputTensor quantParams + // MS_ASSERT(subGraph->inputIndex.size() == 1); + // for (auto graphInputIndex : subGraph->inputIndex) { + // auto linkedPostIdx = GetLinkedPostIdx(*(subGraph.get()), graphInputIndex); + // for (auto nodeIdx : linkedPostIdx) { + // MS_ASSERT(subGraph->nodes.size() > nodeIdx); + // mInputArray->SetInputArrayQP(subGraph->nodes.at(nodeIdx).get()); + // } + // } + // } + return RET_OK; +} + +STATUS AwareQuantizer::GenerateDefaultQuantParam(const 
schema::MetaGraphT *subGraph) { + MS_ASSERT(subGraph != nullptr); + for (const auto &tensor : subGraph->allTensors) { + if (!tensor->quantParams.empty()) { + continue; + } + std::unique_ptr defaultQuantParam(new QuantParamT()); + tensor->quantParams.emplace_back(std::move(defaultQuantParam)); + } + return RET_OK; +} + +STATUS AwareQuantizer::SetAttrToConvolution(const schema::MetaGraphT *subGraph, schema::CNodeT *node) { + // MS_ASSERT(subGraph != nullptr); + // MS_ASSERT(node != nullptr); + // auto inputIndexes = node->inputIndex; + // MS_ASSERT(GetCNodeTType(*node) == OpT_Conv2D || GetCNodeTType(*node) == OpT_DepthwiseConv2D || + // GetCNodeTType(*node) == OpT_DeConv2D || GetCNodeTType(*node) == OpT_DeDepthwiseConv2D); + // if (inputIndexes.size() < 2) { + // MS_LOGE("in aware quant %s node's input tensors is invalid(%zu)!", node->name.c_str(), inputIndexes.size()); + // return RET_ERROR; + // } + // TensorDefT *filterTensor = subGraph->allTensors.at(inputIndexes[1]).get(); + // MS_ASSERT(filterTensor != nullptr); + // auto filterDims = filterTensor->dims; + // MS_ASSERT(filterDims.size() == 4); + // if (GetCNodeTType(*node) == OpT_Conv2D) { + // if (node->fmkType == FmkType_MS) { + // node->attr.AsConv2D()->channelOut = (int32_t)filterDims[0]; + // node->attr.AsConv2D()->channelIn = (int32_t)filterDims[1]; + // node->attr.AsConv2D()->kernelH = (int32_t)filterDims[2]; + // node->attr.AsConv2D()->kernelW = (int32_t)filterDims[3]; + // } else if (node->fmkType == FmkType_TF) { + // node->attr.AsConv2D()->kernelH = (int32_t)filterDims[0]; + // node->attr.AsConv2D()->kernelW = (int32_t)filterDims[1]; + // node->attr.AsConv2D()->channelIn = (int32_t)filterDims[2]; + // node->attr.AsConv2D()->channelOut = (int32_t)filterDims[3]; + // } else { + // MS_LOGE("Unsupport"); + // } + // } + // if (GetCNodeTType(*node) == OpT_DepthwiseConv2D) { + // if (node->fmkType == FmkType_MS) { + // node->attr.AsDepthwiseConv2D()->channelIn = (int32_t)filterDims[0]; + // 
node->attr.AsDepthwiseConv2D()->channelMultiplier = (int32_t)filterDims[1]; + // node->attr.AsDepthwiseConv2D()->kernelH = (int32_t)filterDims[2]; + // node->attr.AsDepthwiseConv2D()->kernelW = (int32_t)filterDims[3]; + // } else if (node->fmkType == FmkType_TF) { + // node->attr.AsDepthwiseConv2D()->kernelH = (int32_t)filterDims[0]; + // node->attr.AsDepthwiseConv2D()->kernelW = (int32_t)filterDims[1]; + // node->attr.AsDepthwiseConv2D()->channelIn = (int32_t)filterDims[2]; + // node->attr.AsDepthwiseConv2D()->channelMultiplier = (int32_t)filterDims[3]; + // } else { + // MS_LOGE("Unsupport"); + // } + // } + // if (GetCNodeTType(*node) == OpT_DeConv2D) { + // MS_ASSERT(false); + // } + // if (GetCNodeTType(*node) == OpT_DeDepthwiseConv2D) { + // MS_ASSERT(false); + // } + return RET_OK; +} + +STATUS AwareQuantizer::GenerateQuantParam() { + // todo why? + MS_ASSERT(graph->inputIndex.size() == 1); + // set graphInputNode input + for (auto graphInputIndex : graph->inputIndex) { + auto status = mInputArray->SetInputArrayQP(graph.get(), graphInputIndex); + if (status != RET_OK) { + MS_LOG(ERROR) << "SetInputArrayQP failed"; + return status; + } + } + auto status = GenerateDefaultQuantParam(graph.get()); + if (status != RET_OK) { + MS_LOG(ERROR) << "GenerateDefaultQuantParam failed"; + return status; + } + auto *quantParamRegister = QuantParamCalcRegister::GetInstance(); + + for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { + auto &node = *iter; + MS_ASSERT(node != nullptr); + if (GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMax || + GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMaxVars) { + MS_ASSERT(false); + } + auto *quantParamCalcer = quantParamRegister->GetQuantParamCalcer(GetCNodeTType(*node)); + if (quantParamCalcer == nullptr) { + MS_LOG(ERROR) << "Can not find QuantParamCalcer for " << node->name.c_str() + << ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip"; + 
node->quantType = static_cast(QuantType_QUANT_NONE); + } else { + status = quantParamCalcer->Calc(graph.get(), *node); + if (status != RET_OK) { + MS_LOG(ERROR) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str(); + node->quantType = schema::QuantType_QUANT_NONE; + } else { + node->quantType = schema::QuantType_AwareTrainning; + } + } + } + return RET_OK; +} + +STATUS AwareQuantizer::DoQuantize() { + for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { + auto &node = *iter; + if (!IsContain(GetUint8OpList(), GetCNodeTType(*node))) { + continue; + } + if (node->quantType != schema::QuantType_AwareTrainning) { + continue; + } + STATUS status; + if (GetCNodeTType(*node) == schema::PrimitiveType_Conv2D || + GetCNodeTType(*node) == schema::PrimitiveType_DepthwiseConv2D) { + auto inputIndexes = node->inputIndex; + if (inputIndexes.size() < 2) { + MS_LOG(ERROR) << node->name.c_str() << " node input has invalid inputs tensor count"; + return RET_ERROR; + } + // quant weight + status = QuantConvWeight(graph.get(), node.get()); + if (status != RET_OK) { + MS_LOG(ERROR) << "QuantConvWeight failed!"; + return RET_ERROR; + } + // quant bias + if (inputIndexes.size() == 3) { + status = QuantConvBias(graph.get(), node.get()); + if (status != RET_OK) { + MS_LOG(ERROR) << "QuantConvBias failed!"; + return RET_ERROR; + } + } + } else if (GetCNodeTType(*node) == schema::PrimitiveType_DetectionPostProcess) { + status = QuantDetectionPostProcessConstTensor(graph.get(), node.get()); + if (status != RET_OK) { + MS_LOG(ERROR) << "QuantDetectionPostProcessConstTensor failed!"; + return RET_ERROR; + } + } else if (GetCNodeTType(*node) == schema::PrimitiveType_Add) { + status = QuantAddConstTensor(graph.get(), node.get()); + if (status != RET_OK) { + MS_LOG(ERROR) << "QuantAddConstTensor failed!"; + return RET_ERROR; + } + } + const auto nodeType = GetCNodeTType(*node); + auto find = std::find(propagatedOps.begin(), propagatedOps.end(), 
nodeType); + if (find != propagatedOps.end()) { + auto inputTensor = graph->allTensors.at(node->inputIndex[0]).get(); + auto outputTensor = graph->allTensors.at(node->outputIndex[0]).get(); + MS_ASSERT(inputTensor != nullptr); + MS_ASSERT(outputTensor != nullptr); + outputTensor->dataType = inputTensor->dataType; + } + } + return RET_OK; +} + +STATUS AwareQuantizer::QuantAddConstTensor(const schema::MetaGraphT *graph, schema::CNodeT *node) { + MS_ASSERT(graph != nullptr); + MS_ASSERT(node != nullptr); + for (size_t i = 0; i < node->inputIndex.size(); i++) { + auto inTensorIdx = node->inputIndex.at(i); + MS_ASSERT(graph->allTensors.size() > inTensorIdx); + auto &inTensor = graph->allTensors.at(inTensorIdx); + MS_ASSERT(inTensor != nullptr); + if (inTensor->refCount == 999) { + switch (inTensor->dataType) { + case TypeId::kNumberTypeFloat: { + auto quantParam = GetTensorQuantParam(inTensor); + MS_ASSERT(quantParam != nullptr); + MS_ASSERT(quantParam->inited); + auto constTensorShapeSize = GetShapeSize(*(inTensor.get())); + vector qDatas(constTensorShapeSize); + void *inData = inTensor->data.data(); + auto *castedInData = static_cast(inData); + for (size_t j = 0; j < constTensorShapeSize; j++) { + qDatas[j] = QuantizeData(castedInData[j], quantParam.get()); + } + inTensor->data = std::move(qDatas); + inTensor->dataType = kNumberTypeUInt8; + } break; + case kNumberTypeUInt8: + break; + default: + // MS_LOGE("Unsupported dataType: %d", inTensor->dataType); + return RET_ERROR; + } + } + } + return RET_OK; +} + +STATUS AwareQuantizer::QuantDetectionPostProcessConstTensor(const schema::MetaGraphT *subGraph, schema::CNodeT *node) { + MS_ASSERT(subGraph != nullptr); + MS_ASSERT(node != nullptr); + auto &constTensor = subGraph->allTensors.at(node->inputIndex[2]); + MS_ASSERT(constTensor != nullptr); + const auto *constData = reinterpret_cast(constTensor->data.data()); + + if (constTensor->refCount == 999 && constTensor->dataType == TypeId::kNumberTypeFloat) { + size_t 
constTensorShapeSize = GetShapeSize(*constTensor); + std::unique_ptr quantParam = GetTensorQuantParam(constTensor); + if (quantParam == nullptr) { + // MS_LOGE("new QuantParamT failed"); + return RET_NULL_PTR; + } + vector qDatas(constTensorShapeSize); + for (size_t j = 0; j < constTensorShapeSize; j++) { + float rawData = constData[j]; + qDatas[j] = QuantizeData(rawData, quantParam.get()); + } + constTensor->data = std::move(qDatas); + constTensor->dataType = TypeId::kNumberTypeUInt8; + } + return RET_OK; +} + +STATUS AwareQuantizer::QuantConvBias(const mindspore::schema::MetaGraphT *graph, mindspore::schema::CNodeT *node) { + MS_ASSERT(graph != nullptr); + MS_ASSERT(node != nullptr); + auto inputIndexes = node->inputIndex; + MS_ASSERT(inputIndexes.size() >= 3); + MS_ASSERT(graph->allTensors.size() > inputIndexes.at(0)); + MS_ASSERT(graph->allTensors.size() > inputIndexes.at(1)); + MS_ASSERT(graph->allTensors.size() > inputIndexes.at(2)); + auto &biasTensor = graph->allTensors.at(inputIndexes.at(2)); + MS_ASSERT(biasTensor != nullptr); + if (biasTensor->dataType != TypeId::kNumberTypeFloat) { + // MS_LOGD("conv %s's bias data is not float", node->name.c_str()); + return RET_OK; + } + + if (biasTensor->dataType == TypeId::kNumberTypeInt32) { + return RET_OK; + } + if (biasTensor->dataType != TypeId::kNumberTypeFloat) { + // MS_LOGE("conv %s's bias data is not float", node->name.c_str()); + return RET_ERROR; + } + auto &inputTensor = graph->allTensors.at(inputIndexes.at(0)); + auto &weightTensor = graph->allTensors.at(inputIndexes.at(1)); + + MS_ASSERT(inputTensor != nullptr); + MS_ASSERT(weightTensor != nullptr); + auto inputScale = inputTensor->quantParams.front()->scale; + auto weightScale = weightTensor->quantParams.front()->scale; + auto scale = inputScale * weightScale; + // set bias quant param + std::unique_ptr biasQuantParam = GetTensorQuantParam(biasTensor); + if (biasQuantParam == nullptr) { + // MS_LOGE("new QuantParamT failed"); + return RET_ERROR; + } 
+ biasQuantParam->inited = true; + biasQuantParam->scale = scale; + biasQuantParam->zeroPoint = 0; + biasQuantParam->numBits = 8; + biasQuantParam->narrowRange = false; + biasQuantParam->min = 0.0; + biasQuantParam->max = 0.0; + + // quant bias data + auto bShapeSize = GetShapeSize(*(biasTensor.get())); + auto *qDatas = new (std::nothrow) int32_t[bShapeSize]; + if (qDatas == nullptr) { + // MS_LOGE("new qDatas failed"); + return RET_ERROR; + } + void *biasData = biasTensor->data.data(); + auto *rawDatas = static_cast(biasData); + for (size_t i = 0; i < bShapeSize; ++i) { + qDatas[i] = (int32_t)std::round(rawDatas[i] / scale); + } + biasTensor->dataType = TypeId::kNumberTypeInt32; + biasTensor->data.clear(); + biasTensor->data.resize(bShapeSize * sizeof(int32_t)); + auto ret = memcpy_s(biasTensor->data.data(), bShapeSize * sizeof(int32_t), qDatas, bShapeSize * sizeof(int32_t)); + if (ret != EOK) { + // MS_LOGE("memcpy_s failed: %d", ret); + return RET_ERROR; + } + delete[] qDatas; + return RET_OK; +} + +STATUS AwareQuantizer::QuantConvWeight(const schema::MetaGraphT *subGraph, schema::CNodeT *node) { + MS_ASSERT(subGraph != nullptr); + MS_ASSERT(node != nullptr); + MS_ASSERT(node->quantParam.size() == node->inputIndex.size() + node->outputIndex.size()); + auto inputIndexes = node->inputIndex; + MS_ASSERT(inputIndexes.size() >= 2); + MS_ASSERT(subGraph->allTensors.size() > inputIndexes.at(1)); + auto &weightTensor = subGraph->allTensors.at(inputIndexes.at(1)); + if (weightTensor->dataType == TypeId::kNumberTypeInt8) { + return RET_OK; + } + if (weightTensor->dataType != TypeId::kNumberTypeFloat && weightTensor->dataType != TypeId::kNumberTypeUInt8) { + MS_LOG(ERROR) << "conv " << node->name.c_str() << "'s weight data is not float or uint8"; + return RET_ERROR; + } + size_t wShapeSize = GetShapeSize(*(weightTensor.get())); + void *oriWeightData = weightTensor->data.data(); + MS_ASSERT(node->quantParam.at(1)->param.front() != nullptr); + vector qDatas(wShapeSize); + 
auto weightQauntParam = GetTensorQuantParam(weightTensor); + if (weightTensor->dataType == TypeId::kNumberTypeFloat) { // normal awareing quant + auto *weightData = static_cast(oriWeightData); + for (size_t j = 0; j < wShapeSize; j++) { + qDatas[j] = QuantizeData(weightData[j], weightQauntParam.get()); + } + } else { // tflite awareing quant + auto *weightData = static_cast(oriWeightData); + for (size_t j = 0; j < wShapeSize; j++) { + qDatas[j] = (int32_t)weightData[j] - 128; + } + weightQauntParam->zeroPoint -= 128; + weightTensor->quantParams.clear(); + weightTensor->quantParams.emplace_back(weightQauntParam.release()); + } + + ::memcpy(weightTensor->data.data(), qDatas.data(), wShapeSize); + weightTensor->dataType = TypeId::kNumberTypeInt8; + return RET_OK; +} +STATUS AwareQuantizer::DetermineNodeQuantType() { + MS_ASSERT(graph != nullptr); + for (auto &node : graph->nodes) { + MS_ASSERT(node != nullptr); + bool canQuant = true; + for (auto &inTensorIdx : node->inputIndex) { + MS_ASSERT(graph->allTensors.size() > inTensorIdx); + auto &inTensor = graph->allTensors.at(inTensorIdx); + MS_ASSERT(inTensor != nullptr); + if (inTensor->quantParams.empty() || inTensor->quantParams.front() == nullptr || + !inTensor->quantParams.front()->inited) { + canQuant = false; + break; + } + } + + if (canQuant) { + for (auto &outTensorIdx : node->outputIndex) { + MS_ASSERT(graph->allTensors.size() > outTensorIdx); + auto &outTensor = graph->allTensors.at(outTensorIdx); + MS_ASSERT(outTensor != nullptr); + if (outTensor->quantParams.empty() || outTensor->quantParams.front() == nullptr || + !outTensor->quantParams.front()->inited) { + canQuant = false; + break; + } + } + } + if (canQuant && IsContain(GetUint8OpList(), GetCNodeTType(*node))) { + node->quantType = schema::QuantType_AwareTrainning; + } else { + node->quantType = schema::QuantType_QUANT_NONE; + } + } + return RET_OK; +} +} // namespace mindspore::lite::quant diff --git 
a/mindspore/lite/tools/converter/quantizer/aware_quantizer.h b/mindspore/lite/tools/converter/quantizer/aware_quantizer.h new file mode 100644 index 0000000000..b2cf74e0b6 --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/aware_quantizer.h @@ -0,0 +1,65 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MS_AWARE_QUANTIZER_H +#define MS_AWARE_QUANTIZER_H + +#include +#include +#include "tools/converter/quantizer/quantizer.h" +#include "schema/inner/model_generated.h" +#include "include/errorcode.h" + +namespace mindspore::lite::quant { +struct InputArray; + +class AwareQuantizer : public FbQuantizer { + public: + AwareQuantizer(schema::MetaGraphT *graph, const std::string &inputInferType, const std::string &stdValues, + const std::string &meanValues); + + ~AwareQuantizer() { delete (mInputArray); } + + STATUS RemoveFakeQuant() override; + + STATUS GenerateQuantParam() override; + + STATUS DetermineNodeQuantType() override; + + STATUS DoQuantize() override; // override; + + private: + // RemoveFakeQuant + STATUS SetAttrToConvolution(const schema::MetaGraphT *subGraph, schema::CNodeT *node); + + STATUS GenerateDefaultQuantParam(const schema::MetaGraphT *subGraph); + + STATUS QuantAddConstTensor(const schema::MetaGraphT *graph, schema::CNodeT *node); + + STATUS QuantDetectionPostProcessConstTensor(const schema::MetaGraphT *subGraph, schema::CNodeT *node); + + STATUS 
QuantConvBias(const schema::MetaGraphT *graph, schema::CNodeT *node); + + STATUS QuantConvWeight(const schema::MetaGraphT *subGraph, schema::CNodeT *node); + + float inputScale = 0.0f; + + InputArray *mInputArray; + + static const std::array propagatedOps; +}; +} // namespace mindspore::lite::quant +#endif diff --git a/mindspore/lite/tools/converter/quantizer/calc_quant_param.cc b/mindspore/lite/tools/converter/quantizer/calc_quant_param.cc new file mode 100644 index 0000000000..a35b9ad350 --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/calc_quant_param.cc @@ -0,0 +1,504 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tools/converter/quantizer/calc_quant_param.h" +#include +#include +#include +#include +#include "tools/common/graph_util.h" +#include "tools/common/tensor_util.h" +#include "tools/converter/quantizer/quantize_util.h" +#include "schema/inner/ops_generated.h" +#include "src/common/utils.h" + +namespace mindspore::lite { +STATUS QuantParamCalcer::ComputeConstQuantParam(const schema::TensorT &tensor, QuantParamT *quantParam) { + MS_ASSERT(quantParam != nullptr); + // int32 weight no need to quant + if (tensor.dataType == TypeId::kNumberTypeInt32 || tensor.dataType == TypeId::kNumberTypeUInt8) { + return RET_OK; + } + if (tensor.dataType != TypeId::kNumberTypeFloat) { + // MS_LOGW("Const Tensor without quantParam should has float dataType, in fact: %d", tensor.dataType); + return RET_ERROR; + } + const auto *constData = reinterpret_cast(tensor.data.data()); + size_t constTensorShapeSize = GetShapeSize(tensor); + float min = 0.0f; + float max = 0.0f; + // find min and max + for (size_t i = 0; i < constTensorShapeSize; i++) { + min = std::min(min, constData[i]); + max = std::max(max, constData[i]); + } + if (min == 0.0f && max == 0.0f) { + max = 1.0f; + } + bool isQuantExact = true; + for (size_t i = 0; i < constTensorShapeSize; i++) { + isQuantExact &= (constData[i] == min || constData[i] == max); + } + if (!isQuantExact) { + // //MS_LOGD("compute quantParam for const tensor may be a cause of poor inference accuracy"); + } + return quant::CalQuantizationParams(quantParam, min, max); +} + +// init inTensor quantParam from preNode if possable +// init outTensor quantParam from postNode if possable +int QuantParamCalcer::Calc(MetaGraphT *graph, const CNodeT &node) { + MS_ASSERT(node.inputIndex.size() > 0); + MS_ASSERT(node.quantParam.size() == node.inputIndex.size() + node.outputIndex.size()); + inputParamDone = 0; + auto inputTensorSize = node.inputIndex.size(); + for (size_t i = 0; i < inputTensorSize; i++) { + MS_ASSERT(graph->allTensors.size() > 
node.inputIndex.at(i)); + auto &tensor = graph->allTensors.at(node.inputIndex.at(i)); + MS_ASSERT(tensor != nullptr); + auto quantParam = GetTensorQuantParam(tensor); + if (quantParam->inited) { // inited + inputParamDone++; + continue; + } + MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(i)); + + MS_ASSERT(tensor != nullptr); + if (tensor->refCount == schema::NodeType_ValueNode && !IsContain(graph->inputIndex, node.inputIndex.at(i))) { + auto status = ComputeConstQuantParam((*tensor), quantParam.get()); + if (status != RET_OK) { + // MS_LOGW("ComputeConstQuantParam failed: %d", status); + return status; + } + tensor->quantParams.front() = std::move(quantParam); + inputParamDone++; + continue; + } + } + outputParamDone = 0; + for (unsigned int i : node.outputIndex) { + MS_ASSERT(graph->allTensors.size() > i); + auto &tensor = graph->allTensors.at(i); + MS_ASSERT(tensor != nullptr); + auto quantParam = GetTensorQuantParam(tensor); + MS_ASSERT(quantParam != nullptr); + if (quantParam->inited) { // inited + outputParamDone++; + continue; + } + + if (tensor->refCount == 999) { + MS_ASSERT(false); + } + } + return RET_OK; +} + +int CommonCalcer::Calc(MetaGraphT *subGraph, const CNodeT &node) { + auto status = QuantParamCalcer::Calc(subGraph, node); + if (status != RET_OK) { + // MS_LOGW("Call QuantParamCalcer::Calc failed: %d", status); + return status; + } + if (inputParamDone != node.inputIndex.size()) { + MS_LOG(ERROR) << "Can not determine inputTensor quantParam, node " << node.name.c_str(); + return RET_ERROR; + } + if (outputParamDone != node.outputIndex.size()) { + MS_LOG(ERROR) << "Can not determine outputTensor quantParam, node " << node.name.c_str(); + return RET_ERROR; + } + return RET_OK; +} + +int LinearCalcer::Calc(MetaGraphT *graph, const CNodeT &node) { + auto status = QuantParamCalcer::Calc(graph, node); + if (status != RET_OK) { + // MS_LOGW("Call QuantParamCalcer::Calc failed: %d", status); + return status; + } + if (inputParamDone != 
node.inputIndex.size()) { + MS_ASSERT(graph->allTensors.size() > node.outputIndex.at(0)); + auto &outTensor = graph->allTensors.at(node.outputIndex.at(0)); + MS_ASSERT(outTensor != nullptr); + auto outputQuantParam = GetTensorQuantParam(outTensor); + MS_ASSERT(outputQuantParam != nullptr); + if (!outputQuantParam->inited) { + // MS_LOGW("Can not determine inputTensor quantParam from outputTensor for node %s", node.name.c_str()); + return RET_ERROR; + } + for (unsigned int i : node.inputIndex) { + MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(i)); + auto &inTensor = graph->allTensors.at(i); + MS_ASSERT(inTensor != nullptr); + auto inQuantParam = GetTensorQuantParam(inTensor); + if (inQuantParam->inited) { + continue; + } + inTensor->quantParams.front() = std::move(inQuantParam); + } + } + if (outputParamDone != node.outputIndex.size()) { + MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(0)); + auto &inTensor = graph->allTensors.at(node.inputIndex.at(0)); + MS_ASSERT(inTensor != nullptr); + auto inQuantParam = GetTensorQuantParam(inTensor); + if (!inQuantParam->inited) { + // MS_LOGW("Can not determine outputTensor quantParam from inputTensor for node %s", node.name.c_str()); + return RET_ERROR; + } + for (size_t i = 0; i < node.outputIndex.size(); i++) { + MS_ASSERT(graph->allTensors.size() > node.outputIndex.at(i)); + auto &outTensor = graph->allTensors.at(node.outputIndex.at(i)); + MS_ASSERT(outTensor != nullptr); + auto outQuantParam = GetTensorQuantParam(outTensor); + if (outQuantParam->inited) { + continue; + } + // todo copy quant params + outTensor->quantParams.front() = std::move(outQuantParam); + } + } + return RET_OK; +} + +class CalcConcat : public QuantParamCalcer { + public: + CalcConcat() = default; + + int Calc(MetaGraphT *graph, const CNodeT &node) override { + MS_ASSERT(node.outputIndex.size() == 1); + auto status = QuantParamCalcer::Calc(graph, node); + if (status != RET_OK) { + // MS_LOGW("Call QuantParamCalcer::Calc failed: %d", 
status);
+      return status;
+    }
+
+    if (inputParamDone != node.inputIndex.size()) {
+      // MS_LOGW("Can not determine concat inputTensor quantParam, node %s", node.name.c_str());
+      return RET_ERROR;
+    }
+
+    if (outputParamDone != 1) {
+      MS_ASSERT(outputParamDone == 0);
+      float minMin = FLT_MAX;
+      float maxMax = -FLT_MAX;  // most negative float; FLT_MIN is the smallest positive one
+      bool narrowRange = false;
+      int numBits = -1;
+      for (size_t i = 0; i < node.inputIndex.size(); i++) {
+        MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(i));
+        auto &inTensor = graph->allTensors.at(node.inputIndex.at(i));  // i-th input of this node
+        MS_ASSERT(inTensor != nullptr);
+        auto inQuantParam = GetTensorQuantParam(inTensor);
+        MS_ASSERT(inQuantParam != nullptr);
+        if (!inQuantParam->inited) {
+          return RET_ERROR;
+        }
+        if (numBits == -1) {
+          narrowRange = inQuantParam->narrowRange;
+          numBits = inQuantParam->numBits;
+        } else {
+          MS_ASSERT(narrowRange == inQuantParam->narrowRange);  // every input must share the first input's format
+          MS_ASSERT(numBits == inQuantParam->numBits);
+        }
+        if (minMin > inQuantParam->min) {
+          minMin = inQuantParam->min;
+        }
+        if (maxMax < inQuantParam->max) {
+          maxMax = inQuantParam->max;
+        }
+      }
+
+      MS_ASSERT(graph->allTensors.size() > node.outputIndex.front());
+      auto &outTensor = graph->allTensors.at(node.outputIndex.front());
+      MS_ASSERT(outTensor != nullptr);
+      auto outQuantParam = GetTensorQuantParam(outTensor);
+      // Output range is the union of the input ranges: [smallest min, largest max].
+      status = quant::CalQuantizationParams(outQuantParam.get(), minMin, maxMax, narrowRange, numBits);
+      if (status != RET_OK) {
+        // MS_LOGW("in aware quantization run CalQuantizationParams failed!");
+        return RET_ERROR;
+      }
+      outputParamDone++;
+    }
+
+    return RET_OK;
+  }
+};
+// Add with a single-value operand: output range = the other input's [min, max] shifted by that value.
+class CalcAdd : public QuantParamCalcer {
+ public:
+  CalcAdd() = default;
+
+  int Calc(MetaGraphT *graph, const CNodeT &node) override {
+    MS_ASSERT(node.inputIndex.size() == 2);
+    MS_ASSERT(node.outputIndex.size() == 1);
+    auto status = QuantParamCalcer::Calc(graph, node);
+    if (status != RET_OK) {
+      // MS_LOGW("Call QuantParamCalcer::Calc failed: %d", status);
+      return status;
+    }
+
+    if (inputParamDone != 2) {
+      // MS_LOGW("Can not determine add inputTensor quantParam, node %s", node.name.c_str());
+      return RET_ERROR;
+    }
+    if (outputParamDone != 1) {
+      MS_ASSERT(outputParamDone == 0);
+      MS_ASSERT(graph->allTensors.size() > node.outputIndex.front());
+      auto &outTensor = graph->allTensors.at(node.outputIndex.front());
+      MS_ASSERT(outTensor != nullptr);
+      auto outQuantParam = GetTensorQuantParam(outTensor);
+      // The operand with refCount == 999 (presumably the constant marker - TODO confirm) and <= 1 dim is the bias.
+      MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(0));
+      auto &tensor0 = graph->allTensors.at(node.inputIndex.at(0));
+      MS_ASSERT(tensor0 != nullptr);
+      MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(1));
+      auto &tensor1 = graph->allTensors.at(node.inputIndex.at(1));
+      MS_ASSERT(tensor1 != nullptr);
+      auto biasTensor = &tensor0;
+      auto paramTensor = &tensor1;
+      if (tensor0->refCount == 999 && (tensor0->dims.empty() || tensor0->dims.size() == 1)) {
+        biasTensor = &tensor0;
+        paramTensor = &tensor1;
+      } else if (tensor1->refCount == 999 && (tensor1->dims.empty() || tensor1->dims.size() == 1)) {
+        biasTensor = &tensor1;
+        paramTensor = &tensor0;
+      } else {
+        // MS_LOGW("Can not determine add outputTensor quantParam, node %s", node.name.c_str());
+        return RET_ERROR;
+      }
+      auto quantParam = GetTensorQuantParam(*paramTensor);
+      MS_ASSERT(quantParam != nullptr);
+      MS_ASSERT(quantParam->inited);
+      auto min = quantParam->min;
+      auto max = quantParam->max;
+      {
+        if ((*biasTensor)->dataType == TypeId::kNumberTypeFloat) {
+          MS_ASSERT((*biasTensor)->data.size() == sizeof(float) / sizeof(uint8_t));
+          void *oriTensorData = (*biasTensor)->data.data();
+          auto *bias = static_cast(oriTensorData);
+          status = quant::CalQuantizationParams(outQuantParam.get(), min + (*bias), max + (*bias));
+          if (status != RET_OK) {
+            // MS_LOGW("in aware quantization run CalQuantizationParams failed!");
+            return RET_ERROR;
+          }
+        } else if ((*biasTensor)->dataType == TypeId::kNumberTypeUInt8) {
+          MS_ASSERT((*biasTensor)->data.size() == 1);
+          void *oriTensorData = (*biasTensor)->data.data();
+          auto *bias = static_cast(oriTensorData);
+          status = quant::CalQuantizationParams(outQuantParam.get(), min + (*bias), max + (*bias));
+          if (status != RET_OK) {
+            // MS_LOGW("in aware quantization run CalQuantizationParams failed!");
+            return RET_ERROR;
+          }
+        } else {
+          // MS_LOGW("Unsupported tensor dataType: %d", (*biasTensor)->dataType);
+          return RET_ERROR;
+        }
+      }
+    }
+    return RET_OK;
+  }
+};
+// RealDiv by a single-value second operand: output range = the dividend's [min, max] divided by that value.
+class CalcRealDiv : public QuantParamCalcer {
+ public:
+  CalcRealDiv() = default;
+
+  int Calc(MetaGraphT *graph, const CNodeT &node) override {
+    MS_ASSERT(node.inputIndex.size() == 2);
+    MS_ASSERT(node.outputIndex.size() == 1);
+    auto status = QuantParamCalcer::Calc(graph, node);
+    if (status != RET_OK) {
+      // MS_LOGW("Call QuantParamCalcer::Calc failed: %d", status);
+      return status;
+    }
+
+    if (inputParamDone != 2) {
+      // MS_LOGW("Can not determine realdiv inputTensor quantParam, node %s", node.name.c_str());
+      return RET_ERROR;
+    }
+    if (outputParamDone != 1) {
+      MS_ASSERT(outputParamDone == 0);
+      MS_ASSERT(graph->allTensors.size() > node.outputIndex.front());
+      auto &outTensor = graph->allTensors.at(node.outputIndex.front());
+      MS_ASSERT(outTensor != nullptr);
+      auto outQuantParam = GetTensorQuantParam(outTensor);
+
+      MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(0));
+      auto &tensor0 = graph->allTensors.at(node.inputIndex.at(0));
+      MS_ASSERT(tensor0 != nullptr);
+      MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(1));
+      auto &tensor1 = graph->allTensors.at(node.inputIndex.at(1));
+      MS_ASSERT(tensor1 != nullptr);
+      if (tensor1->refCount == 999 && (tensor1->dims.empty() || tensor1->dims.size() == 1)) {
+        auto quantParam = GetTensorQuantParam(tensor0);  // range of the dividend; tensor1 is only the divisor
+        auto min = quantParam->min;
+        auto max = quantParam->max;
+        {
+          if (tensor1->dataType == TypeId::kNumberTypeFloat) {
+            MS_ASSERT(tensor1->data.size() == sizeof(float) / sizeof(uint8_t));
+            void *oriTensorData = tensor1->data.data();
+            auto *div = static_cast(oriTensorData);
+            MS_ASSERT(*div != 0);
+            status = quant::CalQuantizationParams(outQuantParam.get(), min / (*div), max / (*div));
+            if (status != RET_OK) {
+              // MS_LOGW("in aware quantization run CalQuantizationParams failed!");
+              return RET_ERROR;
+            }
+          } else if (tensor1->dataType == TypeId::kNumberTypeUInt8) {
+            MS_ASSERT(tensor1->data.size() == 1);
+            void *oriTensorData = tensor1->data.data();
+            auto *div = static_cast(oriTensorData);
+            status = quant::CalQuantizationParams(outQuantParam.get(), min / (*div), max / (*div));  // was max + div
+            if (status != RET_OK) {
+              // MS_LOGW("in aware quantization run CalQuantizationParams failed!");
+              return RET_ERROR;
+            }
+          } else {
+            // MS_LOGW("Unsupported tensor dataType: %d", tensor1->dataType);
+            return RET_ERROR;
+          }
+        }
+      } else {
+        // MS_LOGW("Can not determine realDiv outputTensor quantParam, node %s", node.name.c_str());
+        return RET_ERROR;
+      }
+    }
+    return RET_OK;
+  }
+};
+// Writes a fixed [min, max] range, given at construction, into the node's output tensor (scale spans 256 steps).
+class CalcToSet : public QuantParamCalcer {
+ public:
+  CalcToSet(float min, float max) : min(min), max(max) {}
+
+  int Calc(MetaGraphT *graph, const CNodeT &node) override {
+    MS_ASSERT(node.inputIndex.size() == 1);
+    MS_ASSERT(node.outputIndex.size() == 1);
+    auto status = QuantParamCalcer::Calc(graph, node);
+    if (status != RET_OK) {
+      // MS_LOGW("Call QuantParamCalcer::Calc failed: %d", status);
+      return status;
+    }
+    // input
+    if (inputParamDone != node.inputIndex.size()) {
+      // MS_LOGW("Can not determine inputTensor quantParam, node %s", node.name.c_str());
+      return RET_ERROR;
+    }
+    // output
+    std::unique_ptr quantParam(new (std::nothrow) QuantParamT());
+    if (quantParam == nullptr) {
+      // MS_LOGW("new QuantParamT failed");
+      return RET_ERROR;
+    }
+    quantParam->scale = (max - min) / 256;
+    MS_ASSERT(quantParam->scale != 0);
+    quantParam->zeroPoint = int32_t(std::round(256 - max / quantParam->scale));
+    quantParam->min = min;
+    quantParam->max = max;
+    quantParam->inited = true;
+    MS_ASSERT(graph->allTensors.size() > node.outputIndex.front());
+    auto &outTensor = graph->allTensors.at(node.outputIndex.front());
+    MS_ASSERT(outTensor != nullptr);
+    outTensor->quantParams.front() = std::move(quantParam);  // NOTE(review): assumes quantParams is non-empty
+    return RET_OK;
+  }
+
+ protected:
+  float min;
+  float max;
+};
+// Activation: sigmoid gets the fixed range [0, 1]; every other activation type is handled by CommonCalcer.
+class CalcActivation : public QuantParamCalcer {
+ public:
+  CalcActivation() = default;
+
+  int Calc(MetaGraphT *subGraph, const CNodeT &node) override {
+    MS_ASSERT(node.inputIndex.size() == 1);
+    MS_ASSERT(node.outputIndex.size() == 1);
+    MS_ASSERT(node.primitive->value.AsActivation() != nullptr);  // same accessor as the check below
+    if (node.primitive->value.AsActivation()->type == schema::ActivationType_SIGMOID) {
+      auto calcToSet = CalcToSet(0, 1);
+      return calcToSet.Calc(subGraph, node);
+    } else {
+      auto calCommon = CommonCalcer();
+      return calCommon.Calc(subGraph, node);
+    }
+  }
+};
+// Builds the op-type -> calcer table; the base, common and linear calcer instances are shared by several op types.
+QuantParamCalcRegister::QuantParamCalcRegister() {
+  bool hasError = false;
+  auto baseCalcer = new (std::nothrow) QuantParamCalcer();
+  if (baseCalcer == nullptr) {
+    // MS_LOGW("new QuantParamCalcer failed");
+    hasError = true;
+  }
+  auto commonCalcer = new (std::nothrow) CommonCalcer();
+  if (commonCalcer == nullptr) {
+    // MS_LOGW("new commonCalcer failed");
+    hasError = true;
+  }
+  auto linearCalcer = new (std::nothrow) LinearCalcer();
+  if (linearCalcer == nullptr) {
+    // MS_LOGW("new linearCalcer failed");
+    hasError = true;
+  }
+  if (!hasError) {
+    _registerMap[schema::PrimitiveType_Concat] = new CalcConcat();
+    _registerMap[schema::PrimitiveType_Activation] = new CalcActivation();
+    _registerMap[schema::PrimitiveType_Add] = new CalcAdd();
+    _registerMap[schema::PrimitiveType_Mul] = commonCalcer;
+    _registerMap[schema::PrimitiveType_Conv2D] = commonCalcer;
+    _registerMap[schema::PrimitiveType_DepthwiseConv2D] = commonCalcer;
+    _registerMap[schema::PrimitiveType_Pooling] = linearCalcer;
+    _registerMap[schema::PrimitiveType_Resize] = linearCalcer;
+    _registerMap[schema::PrimitiveType_Reshape] = linearCalcer;
+    _registerMap[schema::PrimitiveType_Shape] = linearCalcer;  // todo if shape influence postNode's output quantParam
+    _registerMap[schema::PrimitiveType_SoftMax] = new CalcToSet(0, 1);
+    _registerMap[schema::PrimitiveType_Squeeze] = linearCalcer;
+    _registerMap[schema::PrimitiveType_RealDiv] = new CalcRealDiv();
+    _registerMap[schema::PrimitiveType_Reduce] = commonCalcer;
+    _registerMap[schema::PrimitiveType_BiasAdd] = commonCalcer;
+    _registerMap[schema::PrimitiveType_Mean] = linearCalcer;
+    _registerMap[schema::PrimitiveType_Transpose] = linearCalcer;
+    _registerMap[schema::PrimitiveType_MatMul] = commonCalcer;
+    _registerMap[schema::PrimitiveType_FullConnection] = commonCalcer;
+    _registerMap[schema::PrimitiveType_Nchw2Nhwc] = linearCalcer;
+    _registerMap[schema::PrimitiveType_Nhwc2Nchw] = linearCalcer;
+    // todo
+    // detection_postprocess op's quant param will not infer only fetch from preNode or postNode
+    // because we will not insert quantTransNode after this node in tflite_graph_8bit model if input data is float.
+    // if quantTransNode is inserted after detection_postprocess node, there will be some errors
+    _registerMap[schema::PrimitiveType_DetectionPostProcess] = baseCalcer;
+  }
+}
+// Returns the single register instance (function-local static, built on first use).
+QuantParamCalcRegister *QuantParamCalcRegister::GetInstance() {
+  static QuantParamCalcRegister instance;
+  return &instance;
+}
+// Looks up the calcer registered for opType; returns nullptr when the op type has no entry.
+QuantParamCalcer *QuantParamCalcRegister::GetQuantParamCalcer(schema::PrimitiveType opType) {
+  auto it = _registerMap.find(opType);
+  if (it != _registerMap.end()) {
+    return it->second;
+  }
+  return nullptr;
+}
+}  // namespace mindspore::lite
diff --git a/mindspore/lite/tools/converter/quantizer/calc_quant_param.h b/mindspore/lite/tools/converter/quantizer/calc_quant_param.h
new file mode 100644
index 0000000000..31455f30d2
--- /dev/null
+++ b/mindspore/lite/tools/converter/quantizer/calc_quant_param.h
@@ -0,0 +1,69 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef CALC_QUANT_PARAM_H
+#define CALC_QUANT_PARAM_H
+
+#include
+#include
+#include "include/errorcode.h"
+#include "schema/inner/model_generated.h"
+
+namespace mindspore {
+namespace lite {
+static constexpr int CONVLUTION_INPUT_NUM = 3;
+// Base calcer: Calc() is the per-node entry point that the op-specific calcers override and call first.
+class QuantParamCalcer {
+ public:
+  virtual ~QuantParamCalcer() = default;
+  virtual int Calc(schema::MetaGraphT *graph, const schema::CNodeT &node);
+
+ protected:
+  STATUS ComputeConstQuantParam(const schema::TensorT &tensor, schema::QuantParamT *quantParam);
+
+ protected:
+  size_t inputParamDone = 0;   // subclasses compare this with the node's input count after Calc()
+  size_t outputParamDone = 0;  // subclasses only derive the output range while this is not 1
+};
+// Calcer registered for Mul, Conv2D, DepthwiseConv2D, Reduce, BiasAdd, MatMul and FullConnection.
+class CommonCalcer : public QuantParamCalcer {
+ public:
+  CommonCalcer() = default;
+  ~CommonCalcer() override = default;
+  int Calc(schema::MetaGraphT *subGraph, const schema::CNodeT &node) override;
+};
+// Calcer registered for Pooling, Resize, Reshape, Shape, Squeeze, Mean, Transpose and the NCHW/NHWC ops.
+class LinearCalcer : public QuantParamCalcer {
+ public:
+  LinearCalcer() = default;
+  ~LinearCalcer() override = default;
+  int Calc(schema::MetaGraphT *graph, const schema::CNodeT &node) override;
+};
+// Singleton table from primitive type to its calcer; GetQuantParamCalcer returns nullptr for unregistered types.
+class QuantParamCalcRegister {
+ public:
+  virtual ~QuantParamCalcRegister() = default;
+  QuantParamCalcer *GetQuantParamCalcer(schema::PrimitiveType opType);
+  static QuantParamCalcRegister *GetInstance();
+
+ private:
+  QuantParamCalcRegister();
+  std::unordered_map _registerMap;
+};
+}  // namespace lite
+}  // namespace mindspore
+
+#endif
diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.cc b/mindspore/lite/tools/converter/quantizer/quantize_util.cc
index 9fc5e55df0..26d5fce229 100644
--- 
a/mindspore/lite/tools/converter/quantizer/quantize_util.cc +++ b/mindspore/lite/tools/converter/quantizer/quantize_util.cc @@ -39,126 +39,127 @@ QuantStrategy::QuantStrategy(size_t weightSize, size_t convWeightQuantChannelThr : mWeightSize(weightSize), mConvWeightQuantChannelThreshold(convWeightQuantChannelThreshold) {} bool QuantStrategy::CanConvOpQuantized(const CNodePtr &node) const { - size_t i = 0; - for (i = 0; i < mConvTypes.size(); i++) { - if (node->fullname_with_scope().find(mConvTypes[i]) == 0) { - break; - } + size_t i = 0; + for (i = 0; i < mConvTypes.size(); i++) { + if (node->fullname_with_scope().find(mConvTypes[i]) == 0) { + break; } + } - if ((i == mConvTypes.size()) || (node->size() < 3)) { - return false; - } + if ((i == mConvTypes.size()) || (node->size() < 3)) { /* no conv-type prefix matched, or the node has fewer than 3 inputs */ + return false; + } - auto inputNode = node->input(2); - if (!inputNode->isa()) { - return false; - } - auto paramNode = inputNode->cast(); - auto abstract_base = paramNode->abstract(); - if (abstract_base == nullptr) { - return false; - } + auto inputNode = node->input(2); + if (!inputNode->isa()) { + return false; + } + auto paramNode = inputNode->cast(); + auto abstract_base = paramNode->abstract(); + if (abstract_base == nullptr) { + return false; + } - if (!utils::isa(abstract_base->GetShapeTrack())) { - MS_LOG(INFO) << "Shape of Abstract of parameter should be ShapePtr " << paramNode->name(); - return false; - } - auto weight_shape = utils::cast(abstract_base->GetShapeTrack())->shape(); - size_t shapeSize = 1; - for (auto dim : weight_shape) { - shapeSize = shapeSize * dim; - } - if (shapeSize < mWeightSize) { - MS_LOG(INFO) << "shapeSize Invalid!" << shapeSize; - return false; - } - if (weight_shape[0] <= mConvWeightQuantChannelThreshold) { - MS_LOG(INFO) << "channel less mConvWeightQuantChannelThreshold!" 
<< weight_shape[0]; - return false; - } + if (!utils::isa(abstract_base->GetShapeTrack())) { + MS_LOG(INFO) << "Shape of Abstract of parameter should be ShapePtr " << paramNode->name(); + return false; + } + auto weight_shape = utils::cast(abstract_base->GetShapeTrack())->shape(); + size_t shapeSize = 1; + for (auto dim : weight_shape) { + shapeSize = shapeSize * dim; + } + if (shapeSize < mWeightSize) { + MS_LOG(INFO) << "shapeSize Invalid!" << shapeSize; + return false; + } + if (weight_shape[0] <= mConvWeightQuantChannelThreshold) { /* first shape dimension must exceed the channel threshold */ + MS_LOG(INFO) << "channel less mConvWeightQuantChannelThreshold!" << weight_shape[0]; + return false; + } - return true; + return true; } bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const { - if (!node->isa()) { - return false; - } - auto cnode = std::dynamic_pointer_cast(node); + if (!node->isa()) { + return false; + } + auto cnode = std::dynamic_pointer_cast(node); - auto primitiveT_value = GetValueNode>(cnode->input(0)); - if (primitiveT_value == nullptr) { - MS_LOG(WARNING) << "PrimitiveT_value is nullptr: " << cnode->fullname_with_scope(); - return false; - } + auto primitiveT_value = GetValueNode>(cnode->input(0)); + if (primitiveT_value == nullptr) { + MS_LOG(WARNING) << "PrimitiveT_value is nullptr: " << cnode->fullname_with_scope(); + return false; + } - auto type = primitiveT_value->GetPrimitiveT()->value.type; - MS_LOG(INFO) << "Primitive type: " << type; - static const std::vector uint8OpList = { - schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw, schema::PrimitiveType_Conv2D, - schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_Add, schema::PrimitiveType_Pooling, - schema::PrimitiveType_Concat, /*schema::PrimitiveType_SoftMax,*/ schema::PrimitiveType_Reshape, - schema::PrimitiveType_Activation}; - return IsContain(uint8OpList, type); + auto type = primitiveT_value->GetPrimitiveT()->value.type; + MS_LOG(INFO) << "Primitive type: " << 
type; + static const std::vector uint8OpList = { /* primitive types this check accepts */ 
+ schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw, + schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D, + schema::PrimitiveType_Add, schema::PrimitiveType_Pooling, + schema::PrimitiveType_Concat, /*schema::PrimitiveType_SoftMax,*/ schema::PrimitiveType_Reshape, + schema::PrimitiveType_Activation}; + return IsContain(uint8OpList, type); } bool QuantStrategy::CanMulOpQuantized(const CNodePtr &node) const { - size_t i = 0; - for (i = 0; i < mMulTypes.size(); i++) { - if (node->fullname_with_scope().find(mMulTypes[i]) == 0) { - break; - } - } - if (i == mMulTypes.size()) { - return false; + size_t i = 0; + for (i = 0; i < mMulTypes.size(); i++) { + if (node->fullname_with_scope().find(mMulTypes[i]) == 0) { + break; } + } + if (i == mMulTypes.size()) { + return false; + } - if (node->size() < 3) { - MS_LOG(INFO) << "input size less!"; - return false; - } + if (node->size() < 3) { + MS_LOG(INFO) << "input size less!"; + return false; + } - auto inputNode1 = node->input(1); - auto inputNode2 = node->input(2); - if (inputNode1 == nullptr || inputNode2 == nullptr) { - MS_LOG(INFO) << "mul input is nullptr!"; - return false; - } + auto inputNode1 = node->input(1); + auto inputNode2 = node->input(2); + if (inputNode1 == nullptr || inputNode2 == nullptr) { + MS_LOG(INFO) << "mul input is nullptr!"; + return false; + } - ParameterPtr paramNode = nullptr; - if (inputNode1->isa()) { - paramNode = inputNode1->cast(); - } else if (inputNode2->isa()) { - paramNode = inputNode2->cast(); - } + ParameterPtr paramNode = nullptr; /* set from input 1 if it passes the isa check, otherwise from input 2 */ + if (inputNode1->isa()) { + paramNode = inputNode1->cast(); + } else if (inputNode2->isa()) { + paramNode = inputNode2->cast(); + } - if (paramNode == nullptr) { - MS_LOG(INFO) << "invalid paramNode!"; - return false; - } + if (paramNode == nullptr) { + MS_LOG(INFO) << "invalid paramNode!"; + return false; + } - auto abstract_base = paramNode->abstract(); - if (abstract_base == nullptr) { - MS_LOG(INFO) << "abstract is 
nullptr"; - return false; - } + auto abstract_base = paramNode->abstract(); + if (abstract_base == nullptr) { + MS_LOG(INFO) << "abstract is nullptr"; + return false; + } - if (!utils::isa(abstract_base->GetShapeTrack())) { - MS_LOG(INFO) << "Shape of Abstract of parameter should be ShapePtr " << paramNode->name(); - return false; - } - auto weight_shape = utils::cast(abstract_base->GetShapeTrack())->shape(); - size_t shapeSize = 1; - for (auto dim : weight_shape) { - shapeSize = shapeSize * dim; - } - if (shapeSize < mWeightSize) { - MS_LOG(INFO) << "shapeSize Invalid!" << shapeSize; - return false; - } + if (!utils::isa(abstract_base->GetShapeTrack())) { + MS_LOG(INFO) << "Shape of Abstract of parameter should be ShapePtr " << paramNode->name(); + return false; + } + auto weight_shape = utils::cast(abstract_base->GetShapeTrack())->shape(); + size_t shapeSize = 1; + for (auto dim : weight_shape) { + shapeSize = shapeSize * dim; + } + if (shapeSize < mWeightSize) { /* product of the shape dimensions is below the size threshold */ + MS_LOG(INFO) << "shapeSize Invalid!" 
<< shapeSize; + return false; + } - return true; + return true; } void CalFakeNode(const AnfNodePtr &inTensor) { @@ -190,56 +191,119 @@ void CalFakeNode(const AnfNodePtr &inTensor) { // } } -STATUS CalQuantizationParams(std::unique_ptr &quantParam, double mMin, - double mMax, bool narrowRange, int quant_max, int quant_min, int num_bits) { - MS_ASSERT(quantParam != nullptr); - if (mMin > 0.0f) { - MS_LOG(ERROR) << "min " << mMin << " is bigger then 0, set to 0, this may course low precision"; - mMin = 0.0f; - } - if (mMax < 0.0f) { - MS_LOG(ERROR) << "mMax " << mMax << " is smaller than 0, set to 0, this may course low precision"; - mMax = 0.0f; - } - if (mMin > mMax) { - MS_LOG(ERROR) << "cal error while min" << mMin << ">" << mMax; - return RET_PARAM_INVALID; - } - if (mMin == mMax) { - if (mMin != 0.0f) { - MS_LOG(ERROR) << "min and max should both be zero if they are equal to each other"; - return RET_ERROR; - } - quantParam->inited = true; - quantParam->min = mMin; - quantParam->max = mMax; - quantParam->scale = 0.0f; - quantParam->zeroPoint = 0; - quantParam->narrowRange = narrowRange; - quantParam->numBits = num_bits; - return RET_OK; +STATUS CalQuantizationParams(std::unique_ptr &quantParam, double mMin, double mMax, bool narrowRange, + int quant_max, int quant_min, int num_bits) { + MS_ASSERT(quantParam != nullptr); + if (mMin > 0.0f) { + MS_LOG(ERROR) << "min " << mMin << " is bigger then 0, set to 0, this may course low precision"; + mMin = 0.0f; + } + if (mMax < 0.0f) { + MS_LOG(ERROR) << "mMax " << mMax << " is smaller than 0, set to 0, this may course low precision"; + mMax = 0.0f; + } + if (mMin > mMax) { + MS_LOG(ERROR) << "cal error while min" << mMin << ">" << mMax; + return RET_PARAM_INVALID; + } + if (mMin == mMax) { + if (mMin != 0.0f) { + MS_LOG(ERROR) << "min and max should both be zero if they are equal to each other"; + return RET_ERROR; } - - auto quantMinFloat = static_cast(quant_min); - auto quantMaxFloat = static_cast(quant_max); - 
double scale = (mMax - mMin) / (quantMaxFloat - quantMinFloat); - const double zeroPointFromMin = quantMinFloat - mMin / scale; - // const double zeroPointFromMax = quantMaxFloat - mMax / scale; - int zeroPoint = static_cast(std::round(zeroPointFromMin)); - - // The zero point should always be in the range of quantized value, - // [qmin, qmax]. - MS_ASSERT(zeroPoint >= quantMin); - MS_ASSERT(zeroPoint <= quantMax); quantParam->inited = true; quantParam->min = mMin; quantParam->max = mMax; - quantParam->scale = scale; - quantParam->zeroPoint = zeroPoint; + quantParam->scale = 0.0f; + quantParam->zeroPoint = 0; quantParam->narrowRange = narrowRange; quantParam->numBits = num_bits; + return RET_OK; + } + + auto quantMinFloat = static_cast(quant_min); + auto quantMaxFloat = static_cast(quant_max); + double scale = (mMax - mMin) / (quantMaxFloat - quantMinFloat); + const double zeroPointFromMin = quantMinFloat - mMin / scale; + // const double zeroPointFromMax = quantMaxFloat - mMax / scale; + int zeroPoint = static_cast(std::round(zeroPointFromMin)); + + // The zero point should always be in the range of quantized value, + // [qmin, qmax]. 
+ MS_ASSERT(zeroPoint >= quant_min); /* this overload's bounds are the quant_min / quant_max parameters */ + MS_ASSERT(zeroPoint <= quant_max); + quantParam->inited = true; + quantParam->min = mMin; + quantParam->max = mMax; + quantParam->scale = scale; + quantParam->zeroPoint = zeroPoint; + quantParam->narrowRange = narrowRange; + quantParam->numBits = num_bits; + + return RET_OK; +} +STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, double mMax, + bool narrowRange, int numBits) { + MS_ASSERT(quantParam != nullptr); + if (mMin > 0.0f) { + MS_LOG(ERROR) << "min " << mMin << " is bigger then 0, set to 0, this may course low precision"; + mMin = 0.0f; + } + if (mMax < 0.0f) { + MS_LOG(ERROR) << "mMax " << mMax << " is smaller than 0, set to 0, this may course low precision"; + mMax = 0.0f; + } + if (mMin > mMax) { + MS_LOG(ERROR) << "cal error while min" << mMin << ">" << mMax; + return RET_PARAM_INVALID; + } + if (mMin == mMax) { + if (mMin != 0.0f) { + MS_LOG(ERROR) << "min and max should both be zero if they are equal to each other"; + return RET_ERROR; + } + quantParam->inited = true; + quantParam->min = mMin; + quantParam->max = mMax; + quantParam->scale = 0.0f; + quantParam->zeroPoint = 0; + quantParam->narrowRange = narrowRange; + quantParam->numBits = numBits; return RET_OK; + } + + int quantMin = narrowRange ? 1 : 0; /* narrow range excludes the quantized value 0 */ + int quantMax = (1 << (unsigned int)numBits) - 1; /* assumes 0 < numBits < 31 - TODO confirm callers */ + auto quantMinFloat = static_cast(quantMin); + auto quantMaxFloat = static_cast(quantMax); + double scale = (mMax - mMin) / (quantMaxFloat - quantMinFloat); + const double zeroPointFromMin = quantMinFloat - mMin / scale; + const double zeroPointFromMax = quantMaxFloat - mMax / scale; + const double zpFromMinError = std::abs(quantMinFloat) + std::abs(mMin / scale); + const double zpFromMaxError = std::abs(quantMaxFloat) + std::abs(mMax / scale); + const double zpDouble = zpFromMinError < zpFromMaxError ? 
zeroPointFromMin : zeroPointFromMax; /* keep the candidate with the smaller error term */ + int zeroPoint; + if (zpDouble < quantMinFloat) { + zeroPoint = quantMin; + } else if (zpDouble > quantMaxFloat) { + zeroPoint = quantMax; + } else { + zeroPoint = static_cast(std::round(zpDouble)); + } + // The zero point should always be in the range of quantized value, + // [qmin, qmax]. + MS_ASSERT(zeroPoint >= quantMin); + MS_ASSERT(zeroPoint <= quantMax); + quantParam->inited = true; + quantParam->min = mMin; + quantParam->max = mMax; + quantParam->scale = scale; + quantParam->zeroPoint = zeroPoint; + quantParam->narrowRange = narrowRange; + quantParam->numBits = numBits; + + return RET_OK; } STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_max, int quant_min, size_t bitNum, @@ -292,14 +356,14 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_ weightPtr->set_quant_param(quantParam); } - auto ret = memcpy_s(const_cast(rawDatas), weightPtr->tensor_size(), - qDatas.data(), shapeSize * sizeof(int8_t)); + auto ret = + memcpy_s(const_cast(rawDatas), weightPtr->tensor_size(), qDatas.data(), shapeSize * sizeof(int8_t)); if (ret != EOK) { MS_LOG(ERROR) << "memcpy error: " << ret; return RET_ERROR; } if (quantType == QuantType_WeightQuant) { - PostBitPack(const_cast(rawDatas), shapeSize, bitNum); + PostBitPack(const_cast(rawDatas), shapeSize, bitNum); } weightPtr->set_tensor_type(kNumberTypeInt8); @@ -338,14 +402,13 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_ qDatas[i] = quant_max; } else if (quant_data < quant_min) { qDatas[i] = quant_min; - } else { + } else { qDatas[i] = static_cast(quant_data); } } weightPtr->set_quant_param(quantParam); - auto ret = memcpy_s(rawDatas, weightPtr->tensor_size(), - qDatas.data(), shapeSize * sizeof(int8_t)); + auto ret = memcpy_s(rawDatas, weightPtr->tensor_size(), qDatas.data(), shapeSize * sizeof(int8_t)); if (ret != EOK) { MS_LOG(ERROR) << "memcpy error: " << ret; return RET_ERROR; @@ 
-358,34 +421,32 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_ weightPtr->set_tensor_size(shapeSize * sizeof(int8_t)); } - - return RET_OK; + return RET_OK; } STATUS PostBitPack(float *weight, size_t shapeSize, size_t bitNum) { - auto *rawDatas = reinterpret_cast(weight); - vector qDatas(rawDatas, rawDatas + shapeSize); - vector qDatas_packed; - if (bitNum < 8 && bitNum > 1) { - BitPack weight_bitpack(bitNum); - weight_bitpack.BitPacking(qDatas, qDatas_packed); - if (EOK != memcpy_s(rawDatas, shapeSize, &qDatas_packed[0], shapeSize)) { - MS_LOG(ERROR) << "PostBitPack memcpy_s qDatas_packed failed"; - return RET_ERROR; - } - } else if (bitNum == 8) { - if (EOK != memcpy_s(rawDatas, shapeSize, &qDatas[0], shapeSize)) { - MS_LOG(ERROR) << "PostBitPack memcpy_s qDatas failed"; - return RET_ERROR; - } - } else { - MS_LOG(ERROR) << "bitNum must be between 0 and 8 : " << bitNum; - return RET_ERROR; + auto *rawDatas = reinterpret_cast(weight); + vector qDatas(rawDatas, rawDatas + shapeSize); + vector qDatas_packed; + if (bitNum < 8 && bitNum > 1) { + BitPack weight_bitpack(bitNum); + weight_bitpack.BitPacking(qDatas, qDatas_packed); + if (EOK != memcpy_s(rawDatas, shapeSize, qDatas_packed.data(), qDatas_packed.size())) { /* copy only what was packed; reading shapeSize bytes would run past the packed buffer */ + MS_LOG(ERROR) << "PostBitPack memcpy_s qDatas_packed failed"; + return RET_ERROR; + } + } else if (bitNum == 8) { + if (EOK != memcpy_s(rawDatas, shapeSize, &qDatas[0], shapeSize)) { + MS_LOG(ERROR) << "PostBitPack memcpy_s qDatas failed"; + return RET_ERROR; } + } else { + MS_LOG(ERROR) << "bitNum must be between 0 and 8 : " << bitNum; + return RET_ERROR; + } - return RET_OK; + return RET_OK; } } // namespace quant } // namespace lite } // namespace mindspore - diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.h b/mindspore/lite/tools/converter/quantizer/quantize_util.h index a287473458..ceb8822779 100644 --- a/mindspore/lite/tools/converter/quantizer/quantize_util.h +++ 
b/mindspore/lite/tools/converter/quantizer/quantize_util.h @@ -62,6 +62,41 @@ class QuantStrategy { STATUS CalQuantizationParams(std::unique_ptr &quantParam, double mMin, double mMax, bool narrowRange, int quant_max, int quant_min, int num_bits); +STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, double mMax, + bool narrowRange = false, int numBits = UINT8_QUANTIZATION); + +template +T QuantizeData(const float originData, const schema::QuantParamT *quantParam) { + MS_ASSERT(quantParam != nullptr); + MS_ASSERT(quantParam->inited); + const auto scale = quantParam->scale; + const auto zeroPoint = quantParam->zeroPoint; + const auto numBit = quantParam->numBits; + const auto narrowRange = quantParam->narrowRange; + const double maxLimit = static_cast((1 << (unsigned int)numBit) - 1 - zeroPoint) * scale; /* largest real value the format can represent */ + double minLimit; + if (narrowRange) { + minLimit = static_cast(1 - zeroPoint) * scale; + } else { + minLimit = static_cast(0 - zeroPoint) * scale; + } + return [maxLimit, minLimit, zeroPoint, scale, narrowRange, originData] { + double tmp = 0.0f; + if (originData > maxLimit) { + tmp = maxLimit; + } else if (originData < minLimit) { + tmp = minLimit; + } else { + tmp = originData; + } + auto quantData = static_cast(std::round((scale == 0 ? 0.0 : tmp / scale) + zeroPoint)); /* scale is 0 for an all-zero range (min == max == 0); avoid 0/0 and map to the zero point */ + if (quantData == 0 && narrowRange) { + quantData++; + } + return quantData; + }(); +} + template T QuantizeData(float originData, const AnfQuantParam *quantParam, int quant_max, int quant_min) { MS_ASSERT(quantParam != nullptr); diff --git a/mindspore/lite/tools/converter/quantizer/quantizer.cc b/mindspore/lite/tools/converter/quantizer/quantizer.cc index 3480705c62..6613ae042c 100644 --- a/mindspore/lite/tools/converter/quantizer/quantizer.cc +++ b/mindspore/lite/tools/converter/quantizer/quantizer.cc @@ -15,22 +15,19 @@ */ #include "mindspore/lite/tools/converter/quantizer/quantizer.h" +#include "schema/inner/model_generated.h" -namespace mindspore { -namespace lite { -namespace quant { 
-Quantizer::Quantizer(FuncGraphPtr graph) : funcGraph(graph) { - if (funcGraph == nullptr) { - return; - } -} +namespace mindspore::lite::quant { STATUS Quantizer::GenerateQuantParam() { return RET_OK; } STATUS Quantizer::RemoveFakeQuant() { return RET_OK; } STATUS Quantizer::DetermineNodeQuantType() { return RET_OK; } -} // namespace quant -} // namespace lite -} // namespace mindspore +STATUS FbQuantizer::GenerateQuantParam() { return RET_OK; } + +STATUS FbQuantizer::RemoveFakeQuant() { return RET_OK; } + +STATUS FbQuantizer::DetermineNodeQuantType() { return RET_OK; } +} // namespace mindspore::lite::quant diff --git a/mindspore/lite/tools/converter/quantizer/quantizer.h b/mindspore/lite/tools/converter/quantizer/quantizer.h index 1cbd6f26cc..741c8f95cb 100644 --- a/mindspore/lite/tools/converter/quantizer/quantizer.h +++ b/mindspore/lite/tools/converter/quantizer/quantizer.h @@ -18,48 +18,63 @@ #define MS_QUANTIZER_H #include +#include +#include #include "include/errorcode.h" #include "ir/func_graph.h" #include "ir/anf.h" -#include "include/model.h" #include "base/base.h" #include "src/param_value_lite.h" +#include "schema/inner/model_generated.h" #include "tools/converter/converter_flags.h" -namespace mindspore { -namespace lite { -namespace quant { +namespace mindspore::lite::quant { using STATUS = int; enum QuantType { - QuantType_QUANT_NONE = 0, - QuantType_AwareTraining = 1, - QuantType_WeightQuant = 2, - QuantType_PostTraining = 3, - QuantType_MIN = QuantType_QUANT_NONE, - QuantType_MAX = QuantType_PostTraining + QuantType_QUANT_NONE = 0, + QuantType_AwareTraining = 1, + QuantType_WeightQuant = 2, + QuantType_PostTraining = 3, + QuantType_MIN = QuantType_QUANT_NONE, + QuantType_MAX = QuantType_PostTraining }; class Quantizer { public: - explicit Quantizer(FuncGraphPtr graph); + explicit Quantizer(FuncGraphPtr graph) : funcGraph(std::move(graph)) {} - ~Quantizer() = default; + virtual ~Quantizer() = default; /* virtual: the class has virtual members and is meant to be derived from */ - virtual STATUS RemoveFakeQuant(); + virtual STATUS 
RemoveFakeQuant(); - virtual STATUS GenerateQuantParam(); + virtual STATUS GenerateQuantParam(); - virtual STATUS DetermineNodeQuantType(); + virtual STATUS DetermineNodeQuantType(); - virtual STATUS DoQuantize(FuncGraphPtr funcGraph) = 0; + virtual STATUS DoQuantize(FuncGraphPtr funcGraph) = 0; mindspore::lite::converter::Flags flags; protected: - FuncGraphPtr funcGraph = nullptr; + FuncGraphPtr funcGraph = nullptr; }; -} // namespace quant -} // namespace lite -} // namespace mindspore -#endif +class FbQuantizer { + public: + explicit FbQuantizer(schema::MetaGraphT *graph) : graph(graph) {} /* NOTE(review): the shared_ptr member adopts this raw pointer - confirm the caller does not also delete the graph */ + + virtual ~FbQuantizer() = default; /* virtual: abstract base (DoQuantize is pure virtual) */ + + virtual STATUS RemoveFakeQuant(); + + virtual STATUS GenerateQuantParam(); + + virtual STATUS DetermineNodeQuantType(); + virtual STATUS DoQuantize() = 0; + + protected: + std::shared_ptr graph = nullptr; +}; +} // namespace mindspore::lite::quant + +#endif