diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc
index 3f47e71115..6f52a97a41 100644
--- a/mindspore/lite/tools/converter/anf_transform.cc
+++ b/mindspore/lite/tools/converter/anf_transform.cc
@@ -80,9 +80,8 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
       return nullptr;
     }
   } else if (config->quantType == schema::QuantType_WeightQuant) {
-    auto bitNum = static_cast<size_t>(std::stoull(config->bitNum));
-    if (bitNum != quant::UINT8_QUANTIZATION) {
-      MS_LOG(ERROR) << "Current Only Support 8 bit weight quant";
+    if (quant::WeightQuantizer::WeightQuantInputCheck(config) != RET_OK) {
+      MS_LOG(ERROR) << "weight quant input param error";
       ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
       return nullptr;
     }
diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.h b/mindspore/lite/tools/converter/quantizer/quantize_util.h
index 1c90eb7b03..979c0c3518 100644
--- a/mindspore/lite/tools/converter/quantizer/quantize_util.h
+++ b/mindspore/lite/tools/converter/quantizer/quantize_util.h
@@ -124,7 +124,7 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveTValue> primiti
   auto dims = weight->tensor_shape();
   if (per_channel) {
     if (dims.size() != 4) {
-      MS_LOG(ERROR) << "weight dims size error: " << dims.size() << " Back to per layer.";
+      MS_LOG(ERROR) << "weight dims size: " << dims.size() << " switch to per-layer quant mode.";
       per_channel = false;
     } else {
       uint32_t channels = dims[0];
diff --git a/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc b/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc
index f1861be3a0..cac77c98c7 100644
--- a/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc
+++ b/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc
@@ -27,6 +27,33 @@ using std::vector;
 namespace mindspore {
 namespace lite {
 namespace quant {
+bool WeightQuantizer::IsPosNum(const std::string &str) {
+  for (size_t i = 0; i < str.size(); i++) {
+    if (str.at(i) < '0' || str.at(i) > '9') {
+      return false;
+    }
+    if (str.at(i) == '0' && i == 0 && str.size() != 1) {
+      return false;
+    }
+  }
+  return true;
+}
+STATUS WeightQuantizer::WeightQuantInputCheck(const converter::Flags *config) {
+  MS_ASSERT(config != nullptr);
+  if (!WeightQuantizer::IsPosNum(config->convWeightQuantChannelThreshold)) {
+    MS_LOG(ERROR) << "convWeightQuantChannelThreshold must be valid pos num.";
+    return RET_ERROR;
+  }
+  if (!WeightQuantizer::IsPosNum(config->quantSize)) {
+    MS_LOG(ERROR) << "quantSize must be valid pos num.";
+    return RET_ERROR;
+  }
+  if (!WeightQuantizer::IsPosNum(config->bitNum) || config->bitNum != "8") {
+    MS_LOG(ERROR) << "bitNum must be valid pos num, current only support 8 bit weight quant.";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
 WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const string &weightSize,
                                  const std::string &convWeightChannelThreshold, const std::string &bitNum)
     : Quantizer(graph) {
diff --git a/mindspore/lite/tools/converter/quantizer/weight_quantizer.h b/mindspore/lite/tools/converter/quantizer/weight_quantizer.h
index 7485343873..7d0f42ff59 100644
--- a/mindspore/lite/tools/converter/quantizer/weight_quantizer.h
+++ b/mindspore/lite/tools/converter/quantizer/weight_quantizer.h
@@ -41,6 +41,8 @@ class WeightQuantizer : public Quantizer {
   STATUS DoQuantize(FuncGraphPtr funcGraph) override;
   STATUS DoConvQuantize(const std::list<CNodePtr> &nodes);
   STATUS DoMulQuantize(const std::list<CNodePtr> &nodes);
+  static STATUS WeightQuantInputCheck(const converter::Flags *config);
+  static bool IsPosNum(const std::string &str);
   int quant_max{INT8_MAX};
   int quant_min{INT8_MIN};
  private: