|
|
|
@ -27,6 +27,33 @@ using std::vector;
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace lite {
|
|
|
|
|
namespace quant {
|
|
|
|
|
bool WeightQuantizer::IsPosNum(const std::string &str) {
|
|
|
|
|
for (size_t i = 0; i < str.size(); i++) {
|
|
|
|
|
if (str.at(i) < '0' || str.at(i) > '9') {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
if (str.at(i) == '0' && i == 0 && str.size() != 1) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
STATUS WeightQuantizer::WeightQuantInputCheck(const converter::Flags *config) {
|
|
|
|
|
MS_ASSERT(config != nullptr);
|
|
|
|
|
if (!WeightQuantizer::IsPosNum(config->convWeightQuantChannelThreshold)) {
|
|
|
|
|
MS_LOG(ERROR) << "convWeightQuantChannelThreshold must be valid pos num.";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
if (!WeightQuantizer::IsPosNum(config->quantSize)) {
|
|
|
|
|
MS_LOG(ERROR) << "quantSize must be valid pos num.";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
if (!WeightQuantizer::IsPosNum(config->bitNum) || config->bitNum != "8") {
|
|
|
|
|
MS_LOG(ERROR) << "bitNum must be valid pos num, current only support 8 bit weight quant.";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const string &weightSize,
|
|
|
|
|
const std::string &convWeightChannelThreshold, const std::string &bitNum)
|
|
|
|
|
: Quantizer(graph) {
|
|
|
|
|