add weight quant input check

pull/6487/head
kai00 4 years ago
parent 076d8ae530
commit 3b4c36223b

@@ -80,9 +80,8 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
       return nullptr;
     }
   } else if (config->quantType == schema::QuantType_WeightQuant) {
-    auto bitNum = static_cast<size_t>(std::stoull(config->bitNum));
-    if (bitNum != quant::UINT8_QUANTIZATION) {
-      MS_LOG(ERROR) << "Current Only Support 8 bit weight quant";
+    if (quant::WeightQuantizer::WeightQuantInputCheck(config) != RET_OK) {
+      MS_LOG(ERROR) << "weight quant input param error";
       ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
       return nullptr;
     }
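
Note (not part of the patch): the caller now delegates all weight-quant input validation to the new static WeightQuantizer::WeightQuantInputCheck instead of parsing bitNum itself. A rough call-site sketch of what the consolidated check accepts and rejects, assuming it runs inside the converter where converter::Flags and the quantizer header are available, and assuming the three string members can be set directly here purely for illustration:

converter::Flags flags;
flags.bitNum = "8";                              // still has to be exactly "8"
flags.quantSize = "0";                           // passes IsPosNum: a lone "0" is allowed
flags.convWeightQuantChannelThreshold = "16.5";  // rejected: '.' is not a digit
if (quant::WeightQuantizer::WeightQuantInputCheck(&flags) != RET_OK) {
  // the converter now logs "weight quant input param error" and returns early here,
  // rather than only catching a wrong bit width as before
}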

@@ -124,7 +124,7 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti
   auto dims = weight->tensor_shape();
   if (per_channel) {
     if (dims.size() != 4) {
-      MS_LOG(ERROR) << "weight dims size error: " << dims.size() << " Back to per layer.";
+      MS_LOG(ERROR) << "weight dims size: " << dims.size() << " switch to per-layer quant mode.";
       per_channel = false;
     } else {
       uint32_t channels = dims[0];
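
Note (not part of the patch): only the log wording changes in this hunk; the fallback itself was already there, i.e. per-channel weight quant requires a 4-D tensor, and anything else logs and drops to per-layer. A minimal standalone sketch of that decision; the helper name and the int-vector shape type are assumptions, only the use of dims[0] as the channel count comes from the code above:

#include <cstdint>
#include <vector>

// How many scale/zero-point sets the fallback above implies for a weight shape.
uint32_t QuantParamCount(const std::vector<int> &dims, bool *per_channel) {
  if (dims.size() != 4) {
    *per_channel = false;                 // same fallback QuantFilter takes
    return 1;                             // one parameter set for the whole tensor
  }
  *per_channel = true;
  return static_cast<uint32_t>(dims[0]);  // one set per output channel
}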

@@ -27,6 +27,33 @@ using std::vector;
 namespace mindspore {
 namespace lite {
 namespace quant {
+bool WeightQuantizer::IsPosNum(const std::string &str) {
+  for (size_t i = 0; i < str.size(); i++) {
+    if (str.at(i) < '0' || str.at(i) > '9') {
+      return false;
+    }
+    if (str.at(i) == '0' && i == 0 && str.size() != 1) {
+      return false;
+    }
+  }
+  return true;
+}
+STATUS WeightQuantizer::WeightQuantInputCheck(const converter::Flags *config) {
+  MS_ASSERT(config != nullptr);
+  if (!WeightQuantizer::IsPosNum(config->convWeightQuantChannelThreshold)) {
+    MS_LOG(ERROR) << "convWeightQuantChannelThreshold must be valid pos num.";
+    return RET_ERROR;
+  }
+  if (!WeightQuantizer::IsPosNum(config->quantSize)) {
+    MS_LOG(ERROR) << "quantSize must be valid pos num.";
+    return RET_ERROR;
+  }
+  if (!WeightQuantizer::IsPosNum(config->bitNum) || config->bitNum != "8") {
+    MS_LOG(ERROR) << "bitNum must be valid pos num, current only support 8 bit weight quant.";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
 WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const string &weightSize,
                                  const std::string &convWeightChannelThreshold, const std::string &bitNum)
     : Quantizer(graph) {
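
Note (not part of the patch): a quick standalone check of the IsPosNum semantics added above. The function body is copied verbatim so it compiles outside the converter; the asserts are illustrative only. Keep in mind that WeightQuantInputCheck additionally pins bitNum to exactly "8".

#include <cassert>
#include <cstddef>
#include <string>

// Verbatim copy of WeightQuantizer::IsPosNum for a standalone check.
static bool IsPosNum(const std::string &str) {
  for (size_t i = 0; i < str.size(); i++) {
    if (str.at(i) < '0' || str.at(i) > '9') {
      return false;
    }
    if (str.at(i) == '0' && i == 0 && str.size() != 1) {
      return false;
    }
  }
  return true;
}

int main() {
  assert(IsPosNum("8"));      // plain digits pass
  assert(IsPosNum("0"));      // a lone "0" passes (the size() == 1 exemption)
  assert(IsPosNum("500"));    // only a leading zero is rejected, not trailing ones
  assert(!IsPosNum("08"));    // leading zero with more digits fails
  assert(!IsPosNum("-8"));    // sign characters fail
  assert(!IsPosNum("8bit"));  // any non-digit fails
  assert(IsPosNum(""));       // the empty string passes, since the loop never runs
  return 0;
}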

@@ -41,6 +41,8 @@ class WeightQuantizer : public Quantizer {
   STATUS DoQuantize(FuncGraphPtr funcGraph) override;
   STATUS DoConvQuantize(const std::list<CNodePtr> &nodes);
   STATUS DoMulQuantize(const std::list<CNodePtr> &nodes);
+  static STATUS WeightQuantInputCheck(const converter::Flags *config);
+  static bool IsPosNum(const std::string &str);
   int quant_max{INT8_MAX};
   int quant_min{INT8_MIN};
  private:
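
Note (not part of the patch): the header already fixes quant_max/quant_min at INT8_MAX/INT8_MIN, which is what the new bitNum == "8" restriction lines up with. A generic sketch of clamping into that range; this is textbook affine quantization, not the converter's actual QuantFilter math:

#include <algorithm>
#include <cmath>
#include <cstdint>

// Quantize one float into the signed 8-bit range [quant_min, quant_max].
int8_t QuantizeToInt8(float value, float scale, int zero_point,
                      int quant_min = INT8_MIN, int quant_max = INT8_MAX) {
  int q = static_cast<int>(std::round(value / scale)) + zero_point;
  q = std::max(quant_min, std::min(quant_max, q));  // clamp to [-128, 127]
  return static_cast<int8_t>(q);
}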
