|
|
|
@ -313,7 +313,7 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_
|
|
|
|
|
MS_LOG(ERROR) << "weight dims size error: " << dims.size() << " Back to per layer.";
|
|
|
|
|
per_channel = false;
|
|
|
|
|
} else {
|
|
|
|
|
uint32_t channels = dims[3];
|
|
|
|
|
uint32_t channels = dims[0];
|
|
|
|
|
if (channels == 0) {
|
|
|
|
|
MS_LOG(ERROR) << "channels is 0";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
@ -325,7 +325,7 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_
|
|
|
|
|
// For now, in tflite models, Conv2D's weight format is KHWC, and so is DepthwiseConv2D's
|
|
|
|
|
// If TransWeightFormat is done before PostTrainingQuantization, DepthwiseConv2D's weight format is CHWK
|
|
|
|
|
size_t shapeSize = weightPtr->tensor_shape_size();
|
|
|
|
|
auto channels = dims[3];
|
|
|
|
|
auto channels = dims[0];
|
|
|
|
|
size_t oneFilterSize = shapeSize / channels;
|
|
|
|
|
auto *rawDatas = reinterpret_cast<const float *>(weightPtr->tensor_addr());
|
|
|
|
|
if (rawDatas == nullptr) {
|
|
|
|
@ -334,15 +334,20 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
float min = FLT_MAX;
|
|
|
|
|
float max = FLT_MIN;
|
|
|
|
|
float max = -FLT_MAX;
|
|
|
|
|
weightPtr->quant_param().clear();
|
|
|
|
|
vector<int8_t> qDatas(shapeSize);
|
|
|
|
|
|
|
|
|
|
for (uint32_t i = 0; i < channels; i++) {
|
|
|
|
|
// find min and max
|
|
|
|
|
for (uint32_t j = 0; j < oneFilterSize; j++) {
|
|
|
|
|
min = std::min(min, rawDatas[i + j * oneFilterSize]);
|
|
|
|
|
max = std::max(max, rawDatas[i + j * oneFilterSize]);
|
|
|
|
|
auto index = j + i * channels;
|
|
|
|
|
if (index >= shapeSize) {
|
|
|
|
|
MS_LOG(ERROR) << "over flow!";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
min = std::min(min, rawDatas[index]);
|
|
|
|
|
max = std::max(max, rawDatas[index]);
|
|
|
|
|
}
|
|
|
|
|
std::unique_ptr<AnfQuantParam> quantParam = std::unique_ptr<AnfQuantParam>(new AnfQuantParam);
|
|
|
|
|
STATUS status = CalQuantizationParams(quantParam, min, max, false, quant_max, quant_min, bitNum);
|
|
|
|
@ -350,11 +355,16 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_
|
|
|
|
|
MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
|
|
|
|
|
return status;
|
|
|
|
|
}
|
|
|
|
|
// update data and datatype
|
|
|
|
|
// do quantization
|
|
|
|
|
for (uint32_t j = 0; j < oneFilterSize; j++) {
|
|
|
|
|
float rawData = rawDatas[i + j * oneFilterSize];
|
|
|
|
|
auto index = j + i * channels;
|
|
|
|
|
if (index >= shapeSize) {
|
|
|
|
|
MS_LOG(ERROR) << "over flow!";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
float rawData = rawDatas[index];
|
|
|
|
|
auto qData = QuantizeData<int8_t>(rawData, quantParam.get(), quant_max, quant_min);
|
|
|
|
|
qDatas[i + j * oneFilterSize] = qData;
|
|
|
|
|
qDatas[index] = qData;
|
|
|
|
|
}
|
|
|
|
|
weightPtr->set_quant_param(quantParam);
|
|
|
|
|
}
|
|
|
|
|