diff --git a/mindspore/lite/tools/converter/converter_flags.cc b/mindspore/lite/tools/converter/converter_flags.cc
index b0f2ec605f..bf88c5ddcc 100644
--- a/mindspore/lite/tools/converter/converter_flags.cc
+++ b/mindspore/lite/tools/converter/converter_flags.cc
@@ -29,12 +29,12 @@ Flags::Flags() {
   AddFlag(&Flags::outputFile, "outputFile", "Output model file path. Will add .ms automatically", "");
   AddFlag(&Flags::weightFile, "weightFile",
           "Input model weight file path. Needed when fmk is CAFFE. CAFFE: *.caffemodel", "");
-  AddFlag(&Flags::inferenceType, "inferenceType",
-          "Real data type saved in output file, reserved param, NOT used for now. FLOAT | FP16 | UINT8", "FLOAT");
-  AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. AwareTraining | WeightQuant | PostTraining", "");
-  AddFlag(&Flags::inputInferenceTypeIn, "inputInferenceType", "Input inference data type. FLOAT | UINT8", "FLOAT");
+  AddFlag(&Flags::inferenceTypeIn, "inferenceType",
+          "Real data type saved in output file, reserved param, NOT used for now. FLOAT | INT8", "FLOAT");
+  AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. AwareTraining | PostTraining", "");
+  AddFlag(&Flags::inputInferenceTypeIn, "inputInferenceType", "Input inference data type. FLOAT | INT8", "FLOAT");
   AddFlag(&Flags::stdDev, "stdDev", "Standard deviation value for aware-quantization", "128");
-  AddFlag(&Flags::mean, "mean", "Mean value for aware-quantization", "127");
+  AddFlag(&Flags::mean, "mean", "Mean value for aware-quantization", "-0.5");
   AddFlag(&Flags::quantSize, "quantSize", "Weight quantization size threshold", "0");
   AddFlag(&Flags::configFile, "config_file", "Configuration for post-training.", "");
   AddFlag(&Flags::formatTrans, "formatTrans", "whether transform format. true | false", "true");
@@ -77,14 +77,24 @@ int Flags::Init(int argc, const char **argv) {
   }
   if (this->inputInferenceTypeIn == "FLOAT") {
     this->inputInferenceType = TypeId::kNumberTypeFloat;
-  } else if (this->inputInferenceTypeIn == "UINT8") {
-    this->inputInferenceType = TypeId::kNumberTypeUInt8;
   } else if (this->inputInferenceTypeIn == "INT8") {
     this->inputInferenceType = TypeId::kNumberTypeInt8;
   } else {
-    std::cerr << "INPUT INVALID: inputInferenceType is invalid: %s", this->inputInferenceTypeIn.c_str();
+    std::cerr << "INPUT INVALID: inputInferenceType is invalid: " << this->inputInferenceTypeIn
+              << ", supported inputInferenceType: FLOAT | INT8";
     return 1;
   }
+
+  if (this->inferenceTypeIn == "FLOAT") {
+    this->inferenceType = TypeId::kNumberTypeFloat;
+  } else if (this->inferenceTypeIn == "INT8") {
+    this->inferenceType = TypeId::kNumberTypeInt8;
+  } else {
+    std::cerr << "INPUT INVALID: inferenceType is invalid: " << this->inferenceTypeIn
+              << ", supported inferenceType: FLOAT | INT8";
+    return 1;
+  }
+
   if (this->fmkIn == "CAFFE") {
     this->fmk = FmkType_CAFFE;
   } else if (this->fmkIn == "MS") {
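Note on the new `mean` default (a sanity check, not part of the patch): `InputArray` in the quantizer hunks below derives the float range that maps onto the quantized endpoints as `(qmin - mean) / stdDev` and `(qmax - mean) / stdDev`. With the endpoints moving from `[0, 255]` to `[-128, 127]`, `mean = -0.5` together with the unchanged `stdDev = 128` keeps that range at roughly `[-1, 1]`, matching what the old `mean = 127` default gave for uint8:

```cpp
#include <cstdio>

// Mirrors InputArray's range derivation: the printed pair is the float range
// that quantizes exactly onto the endpoints [qmin, qmax].
static void PrintRange(const char *tag, float qmin, float qmax, float mean, float stdDev) {
  printf("%s: [%.4f, %.4f]\n", tag, (qmin - mean) / stdDev, (qmax - mean) / stdDev);
}

int main() {
  PrintRange("uint8, mean=127 ", 0.0f, 255.0f, 127.0f, 128.0f);    // [-0.9922, 1.0000]
  PrintRange("int8,  mean=-0.5", -128.0f, 127.0f, -0.5f, 128.0f);  // [-0.9961, 0.9961]
  return 0;
}
```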
diff --git a/mindspore/lite/tools/converter/converter_flags.h b/mindspore/lite/tools/converter/converter_flags.h
index 98e3581d78..1067a8cb09 100644
--- a/mindspore/lite/tools/converter/converter_flags.h
+++ b/mindspore/lite/tools/converter/converter_flags.h
@@ -63,10 +63,10 @@ class Flags : public virtual mindspore::lite::FlagParser {
   // used for quantization
   std::string quantTypeIn;
   QuantType quantType;
-  std::string inferenceType;
+  std::string inferenceTypeIn;
+  TypeId inferenceType = TypeId::kNumberTypeFloat;
   // used for parse aware trainning
   std::string inputInferenceTypeIn;
-  // mindspore::predict::DataType inputInferenceType = DataType_DT_FLOAT;
   TypeId inputInferenceType = TypeId::kNumberTypeFloat;
   std::string stdDev;
   std::string mean;
diff --git a/mindspore/lite/tools/converter/graphdef_transform.cc b/mindspore/lite/tools/converter/graphdef_transform.cc
index 47dd18de33..b74d9b57e8 100644
--- a/mindspore/lite/tools/converter/graphdef_transform.cc
+++ b/mindspore/lite/tools/converter/graphdef_transform.cc
@@ -194,6 +194,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
       return RET_ERROR;
     }
     dTypeTransPass->SetInputDataDType(ctx.inputInferenceType);
+    dTypeTransPass->SetOutputDataDType(ctx.inferenceType);
     quantNodeOptimizer.AddPass(dTypeTransPass);
     quantNodeOptimizer.AddPass(new (std::nothrow) QuantCastFusionPass());
     quantNodeOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc
index 24e4ccec33..d8fbdd846f 100644
--- a/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc
+++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc
@@ -101,7 +101,7 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) {
 
 STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) {
   MS_ASSERT(graph != nullptr);
-  if (inputDataDType == TypeId::kNumberTypeInt8) {
+  if (outputDataDType == TypeId::kNumberTypeInt8) {
     return RET_OK;
   }
   MS_ASSERT(inputDataDType == TypeId::kNumberTypeFloat);
@@ -231,5 +231,8 @@ NodeIter DTypeTransPass::InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIte
 }
 
 void DTypeTransPass::SetInputDataDType(TypeId dataType) { this->inputDataDType = dataType; }
+
+void DTypeTransPass::SetOutputDataDType(TypeId dataType) { this->outputDataDType = dataType; }
+
 }  // namespace lite
 }  // namespace mindspore
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h
index 2b1906b6fe..a3fc5490a3 100644
--- a/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h
+++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h
@@ -38,6 +38,8 @@ class DTypeTransPass : public GraphPass {
 
   void SetInputDataDType(TypeId dataType);
 
+  void SetOutputDataDType(TypeId dataType);
+
  private:
   STATUS DoModelInputDTypeTrans(schema::MetaGraphT *graph);
 
@@ -51,6 +53,7 @@
  private:
   size_t id;
   TypeId inputDataDType = TypeId::kNumberTypeFloat;
+  TypeId outputDataDType = TypeId::kNumberTypeFloat;
   OpDefCopyer castOpCopyer = [](schema::CNodeT *inCNode) -> std::unique_ptr<schema::CNodeT> {
     std::unique_ptr<schema::CNodeT> newCNode(new (std::nothrow) schema::CNodeT);
diff --git a/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc b/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc
index 2a55aedf93..c921add7d3 100644
--- a/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc
+++ b/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc
@@ -88,7 +88,7 @@ AwareQuantizer::AwareQuantizer(schema::MetaGraphT *graph,
   if (inputInferType == "FLOAT") {
     inArr.reset(new (std::nothrow) InputArray(mean, stdValue));
   } else {
-    inArr.reset(new (std::nothrow) InputArray(mean, stdValue, TypeId::kNumberTypeUInt8));
+    inArr.reset(new (std::nothrow) InputArray(mean, stdValue, TypeId::kNumberTypeInt8));
   }
   mInputArray = inArr.get();
   mInputArray->InitQuantParam();
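Aside (a simplified sketch, not the pass's actual implementation): with separate input and output dtype fields, `DTypeTransPass` can decide the two graph boundaries independently, e.g. int8 in / float out. The decision reduces to two independent checks; in the converter itself, a firing check leads to cast nodes being spliced into the `MetaGraphT` rather than a bool:

```cpp
#include <cstdio>

enum class TypeId { kNumberTypeFloat, kNumberTypeInt8 };

// Hypothetical condensation of the patched pass's boundary decisions.
struct DTypeDecision {
  TypeId inputDataDType = TypeId::kNumberTypeFloat;
  TypeId outputDataDType = TypeId::kNumberTypeFloat;

  // Float inputs must be cast to int8 before entering the quantized graph.
  bool NeedInputCast() const { return inputDataDType != TypeId::kNumberTypeInt8; }
  // Int8 results are cast back to float unless the caller asked for int8 output.
  bool NeedOutputCast() const { return outputDataDType != TypeId::kNumberTypeInt8; }
};

int main() {
  DTypeDecision d{TypeId::kNumberTypeInt8, TypeId::kNumberTypeFloat};
  printf("input cast: %d, output cast: %d\n", d.NeedInputCast(), d.NeedOutputCast());  // 0, 1
  return 0;
}
```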
diff --git a/mindspore/lite/tools/converter/quantizer/aware_quantizer.h b/mindspore/lite/tools/converter/quantizer/aware_quantizer.h
index 574441701d..a9f046a47f 100644
--- a/mindspore/lite/tools/converter/quantizer/aware_quantizer.h
+++ b/mindspore/lite/tools/converter/quantizer/aware_quantizer.h
@@ -37,8 +37,8 @@ struct InputArray {
   InputArray(float mean, float stdDev,
              TypeId dataType = TypeId::kNumberTypeFloat) {
     this->dataType = dataType;
-    constexpr float qmin = 0;
-    constexpr float qmax = 255;
+    constexpr float qmin = -128;
+    constexpr float qmax = 127;
     mMin = (qmin - mean) / stdDev;
     mMax = (qmax - mean) / stdDev;
   }
diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.cc b/mindspore/lite/tools/converter/quantizer/quantize_util.cc
index 6753973ac7..4c262792bb 100644
--- a/mindspore/lite/tools/converter/quantizer/quantize_util.cc
+++ b/mindspore/lite/tools/converter/quantizer/quantize_util.cc
@@ -246,8 +246,8 @@ STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, doubl
     return RET_OK;
   }
 
-  int quantMin = narrowRange ? 1 : 0;
-  int quantMax = (1 << (unsigned int) numBits) - 1;
+  int quantMin = (narrowRange ? 1 : 0) - 128;
+  int quantMax = (1 << (unsigned int) numBits) - 1 - 128;
   auto quantMinFloat = static_cast<double>(quantMin);
   auto quantMaxFloat = static_cast<double>(quantMax);
   double scale = (mMax - mMin) / (quantMaxFloat - quantMinFloat);
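For context, a worked sketch (the textbook affine derivation, not this file's exact code) of what the shifted clamp range does to the computed parameters: for 8-bit data the patched `CalQuantizationParams` clamps to `[-128, 127]`, or `[-127, 127]` when `narrowRange` is set, and the scale and zero point follow from the real min/max:

```cpp
#include <cmath>
#include <cstdio>

// Affine quantization parameters over the shifted signed range used above.
// Illustrative only; names mirror the patch, the derivation is the standard one.
static void CalcInt8Params(double mMin, double mMax, bool narrowRange) {
  unsigned int numBits = 8;
  int quantMin = (narrowRange ? 1 : 0) - 128;  // -127 or -128
  int quantMax = (1 << numBits) - 1 - 128;     // 127 for numBits == 8
  double scale = (mMax - mMin) / (quantMax - quantMin);
  // Zero point: the quantized integer that represents real 0.
  int zeroPoint = static_cast<int>(std::round(quantMin - mMin / scale));
  printf("scale = %.6f, zeroPoint = %d\n", scale, zeroPoint);
}

int main() {
  CalcInt8Params(-1.0, 1.0, true);  // scale ~= 0.007874, zeroPoint = 0
  return 0;
}
```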