|
|
|
@ -190,7 +190,7 @@ void CalFakeNode(const AnfNodePtr &inTensor) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
STATUS CalQuantizationParams(std::unique_ptr<AnfQuantParam> &quantParam, double mMin,
|
|
|
|
|
double mMax, bool narrowRange, int numBits) {
|
|
|
|
|
double mMax, bool narrowRange, int quant_max, int quant_min, int num_bits) {
|
|
|
|
|
MS_ASSERT(quantParam != nullptr);
|
|
|
|
|
if (mMin > 0.0f) {
|
|
|
|
|
MS_LOG(ERROR) << "min " << mMin << " is bigger then 0, set to 0, this may course low precision";
|
|
|
|
@ -215,28 +215,17 @@ STATUS CalQuantizationParams(std::unique_ptr<AnfQuantParam> &quantParam, double
|
|
|
|
|
quantParam->scale = 0.0f;
|
|
|
|
|
quantParam->zeroPoint = 0;
|
|
|
|
|
quantParam->narrowRange = narrowRange;
|
|
|
|
|
quantParam->numBits = numBits;
|
|
|
|
|
quantParam->numBits = num_bits;
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int quantMin = narrowRange ? 1 : 0;
|
|
|
|
|
int quantMax = (1 << (unsigned int)numBits) - 1;
|
|
|
|
|
auto quantMinFloat = static_cast<double>(quantMin);
|
|
|
|
|
auto quantMaxFloat = static_cast<double>(quantMax);
|
|
|
|
|
auto quantMinFloat = static_cast<double>(quant_min);
|
|
|
|
|
auto quantMaxFloat = static_cast<double>(quant_max);
|
|
|
|
|
double scale = (mMax - mMin) / (quantMaxFloat - quantMinFloat);
|
|
|
|
|
const double zeroPointFromMin = quantMinFloat - mMin / scale;
|
|
|
|
|
const double zeroPointFromMax = quantMaxFloat - mMax / scale;
|
|
|
|
|
const double zpFromMinError = std::abs(quantMinFloat) + std::abs(mMin / scale);
|
|
|
|
|
const double zpFromMaxError = std::abs(quantMaxFloat) + std::abs(mMax / scale);
|
|
|
|
|
const double zpDouble = zpFromMinError < zpFromMaxError ? zeroPointFromMin : zeroPointFromMax;
|
|
|
|
|
int zeroPoint;
|
|
|
|
|
if (zpDouble < quantMinFloat) {
|
|
|
|
|
zeroPoint = quantMin;
|
|
|
|
|
} else if (zpDouble > quantMaxFloat) {
|
|
|
|
|
zeroPoint = quantMax;
|
|
|
|
|
} else {
|
|
|
|
|
zeroPoint = static_cast<int32_t>(std::round(zpDouble));
|
|
|
|
|
}
|
|
|
|
|
// const double zeroPointFromMax = quantMaxFloat - mMax / scale;
|
|
|
|
|
int zeroPoint = static_cast<int32_t>(std::round(zeroPointFromMin));
|
|
|
|
|
|
|
|
|
|
// The zero point should always be in the range of quantized value,
|
|
|
|
|
// [qmin, qmax].
|
|
|
|
|
MS_ASSERT(zeroPoint >= quantMin);
|
|
|
|
@ -247,12 +236,12 @@ STATUS CalQuantizationParams(std::unique_ptr<AnfQuantParam> &quantParam, double
|
|
|
|
|
quantParam->scale = scale;
|
|
|
|
|
quantParam->zeroPoint = zeroPoint;
|
|
|
|
|
quantParam->narrowRange = narrowRange;
|
|
|
|
|
quantParam->numBits = numBits;
|
|
|
|
|
quantParam->numBits = num_bits;
|
|
|
|
|
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, size_t bitNum) {
|
|
|
|
|
STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_max, int quant_min, size_t bitNum) {
|
|
|
|
|
auto dims = weightPtr->tensor_shape();
|
|
|
|
|
if (dims.size() < 1) {
|
|
|
|
|
MS_LOG(ERROR) << "weight dims size error";
|
|
|
|
@ -284,7 +273,7 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, size_t bit
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<AnfQuantParam> quantParam = std::unique_ptr<AnfQuantParam>(new AnfQuantParam);
|
|
|
|
|
STATUS status = CalQuantizationParams(quantParam, min, max, false, bitNum);
|
|
|
|
|
STATUS status = CalQuantizationParams(quantParam, min, max, false, quant_max, quant_min, bitNum);
|
|
|
|
|
if (status != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
|
|
|
|
|
return status;
|
|
|
|
@ -308,8 +297,8 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, size_t bit
|
|
|
|
|
PostBitPack(const_cast<float*>(rawDatas), shapeSize, bitNum);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
weightPtr->set_tensor_type(kNumberTypeUInt8);
|
|
|
|
|
weightPtr->set_tensor_size(shapeSize * sizeof(uint8_t));
|
|
|
|
|
weightPtr->set_tensor_type(kNumberTypeInt8);
|
|
|
|
|
weightPtr->set_tensor_size(shapeSize * sizeof(int8_t));
|
|
|
|
|
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|