fix post quantization

pull/4484/head
xutianchun 5 years ago
parent 5ca5c346bb
commit 252c13fedd

@@ -201,14 +201,16 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
       }
       auto activate_index = node->inputIndex[i];
       auto tensor_input = metaGraphT->allTensors[activate_index].get();
-      std::unique_ptr<schema::QuantParamT> input_quant_param =
-        std::make_unique<schema::QuantParamT>(input_quant_params[i]);
-      MS_LOG(DEBUG) << "[input]node: " << node->name << " scale: " << input_quant_param->scale
-                    << " zp: " << input_quant_param->zeroPoint;
-      tensor_input->quantParams.emplace_back(std::move(input_quant_param));
-      if (!(node_type == schema::PrimitiveType_QuantDTypeCast &&
-            primitiveT_value->GetPrimitiveT()->value.AsQuantDTypeCast()->srcT == kNumberTypeFloat32)) {
-        tensor_input->dataType = kNumberTypeInt8;
+      if (tensor_input->quantParams.empty()) {
+        std::unique_ptr<schema::QuantParamT> input_quant_param =
+          std::make_unique<schema::QuantParamT>(input_quant_params[i]);
+        MS_LOG(DEBUG) << "[input]node: " << node->name << " scale: " << input_quant_param->scale
+                      << " zp: " << input_quant_param->zeroPoint;
+        tensor_input->quantParams.emplace_back(std::move(input_quant_param));
+        if (!(node_type == schema::PrimitiveType_QuantDTypeCast &&
+              primitiveT_value->GetPrimitiveT()->value.AsQuantDTypeCast()->srcT == kNumberTypeFloat32)) {
+          tensor_input->dataType = kNumberTypeInt8;
+        }
       }
     }
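The new quantParams.empty() guard (applied to input tensors here and to output tensors in the next hunk) makes the export step idempotent: a tensor that already carries quant params, e.g. because it is reached through more than one node, no longer gets a duplicate entry appended. A schematic sketch of the pattern, with a stand-in Tensor type rather than the real schema structs:

#include <memory>
#include <vector>

struct QuantParam { double scale; int zeroPoint; };
struct Tensor { std::vector<std::unique_ptr<QuantParam>> quantParams; };

// Attach quant params only once; repeated calls leave the tensor unchanged.
void AttachQuantParam(Tensor *t, const QuantParam &p) {
  if (!t->quantParams.empty()) return;  // already annotated: do not append a duplicate
  t->quantParams.emplace_back(std::make_unique<QuantParam>(p));
}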
@@ -219,11 +221,13 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
       if (output_quant_params.empty()) {
         MS_LOG(WARNING) << "node: " << node->name << " output quant params is empty";
       } else {
-        std::unique_ptr<schema::QuantParamT> output_quant_param =
-          std::make_unique<schema::QuantParamT>(output_quant_params[0]);
-        MS_LOG(DEBUG) << "[output]node: " << node->name << " scale: " << output_quant_param->scale
-                      << " zp: " << output_quant_param->zeroPoint;
-        tensor_output->quantParams.emplace_back(std::move(output_quant_param));
+        if (tensor_output->quantParams.empty()) {
+          std::unique_ptr<schema::QuantParamT> output_quant_param =
+            std::make_unique<schema::QuantParamT>(output_quant_params[0]);
+          MS_LOG(DEBUG) << "[output]node: " << node->name << " scale: " << output_quant_param->scale
+                        << " zp: " << output_quant_param->zeroPoint;
+          tensor_output->quantParams.emplace_back(std::move(output_quant_param));
+        }
       }
       if (!(node_type == schema::PrimitiveType_QuantDTypeCast &&
             primitiveT_value->GetPrimitiveT()->value.AsQuantDTypeCast()->dstT == kNumberTypeFloat32)) {

@@ -62,7 +62,7 @@ struct DivergInfo {
     this->bin_num = bins;
     this->bit_num = bits;
     histogram.resize(bin_num);
-    max = FLT_MIN;
+    max = -FLT_MAX;
     min = FLT_MAX;
     this->quant_max = quant_max;
     this->quant_min = quant_min;
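The max = FLT_MIN initialization this hunk replaces is a classic pitfall: FLT_MIN is the smallest positive normalized float (about 1.18e-38), not the most negative value, so a running maximum seeded with it can never fall below zero and is wrong whenever every observed value is negative. A minimal standalone demonstration, not MindSpore code:

#include <algorithm>
#include <cfloat>
#include <cstdio>

int main() {
  const float data[] = {-3.0f, -1.5f, -0.5f};  // an all-negative tensor
  float bad_max = FLT_MIN;    // smallest POSITIVE float, ~1.18e-38
  float good_max = -FLT_MAX;  // true lower bound of the float range
  for (float v : data) {
    bad_max = std::max(bad_max, v);
    good_max = std::max(good_max, v);
  }
  // bad_max stays at ~1.18e-38 (no element ever exceeds it); good_max is -0.5.
  std::printf("bad: %g  good: %g\n", bad_max, good_max);
  return 0;
}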

@@ -313,7 +313,7 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_
     MS_LOG(ERROR) << "weight dims size error: " << dims.size() << " Back to per layer.";
     per_channel = false;
   } else {
-    uint32_t channels = dims[3];
+    uint32_t channels = dims[0];
     if (channels == 0) {
       MS_LOG(ERROR) << "channels is 0";
       return RET_ERROR;
@@ -325,7 +325,7 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_
   // at now for tflite model, Conv2D's weight format is KHWC, so is DepthwiseConv2D
   // if TransWeightFormat is done before PostTraingingQuantization, the DepthwiseCon2D's weight is CHWK
   size_t shapeSize = weightPtr->tensor_shape_size();
-  auto channels = dims[3];
+  auto channels = dims[0];
   size_t oneFilterSize = shapeSize / channels;
   auto *rawDatas = reinterpret_cast<const float *>(weightPtr->tensor_addr());
   if (rawDatas == nullptr) {
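Both dims[0] changes follow from the weight layout the comment describes: in KHWC the output-channel count K is the leading dimension, so the per-channel axis is dims[0], while dims[3] is the input-channel axis C. As a worked example, a 3x3 convolution with 16 output channels and 8 input channels has dims = {16, 3, 3, 8}: channels = dims[0] = 16 and oneFilterSize = (16 * 3 * 3 * 8) / 16 = 72, one H*W*C filter per output channel, whereas reading dims[3] would wrongly split the buffer into 8 slices.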
@@ -334,15 +334,20 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_
   }
   float min = FLT_MAX;
-  float max = FLT_MIN;
+  float max = -FLT_MAX;
   weightPtr->quant_param().clear();
   vector<int8_t> qDatas(shapeSize);
   for (uint32_t i = 0; i < channels; i++) {
     // find min and max
     for (uint32_t j = 0; j < oneFilterSize; j++) {
-      min = std::min(min, rawDatas[i + j * oneFilterSize]);
-      max = std::max(max, rawDatas[i + j * oneFilterSize]);
+      auto index = j + i * channels;
+      if (index >= shapeSize) {
+        MS_LOG(ERROR) << "over flow!";
+        return RET_ERROR;
+      }
+      min = std::min(min, rawDatas[index]);
+      max = std::max(max, rawDatas[index]);
     }
     std::unique_ptr<AnfQuantParam> quantParam = std::unique_ptr<AnfQuantParam>(new AnfQuantParam);
     STATUS status = CalQuantizationParams(quantParam, min, max, false, quant_max, quant_min, bitNum);
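CalQuantizationParams turns each per-channel [min, max] range into an affine scale and zero point. The diff does not show its body; a sketch of the standard asymmetric computation it presumably performs, mirroring the call site's quant_max/quant_min argument order (the real helper may adjust the range or handle edge cases differently):

#include <cmath>

struct Params { double scale; int zeroPoint; };

// Derive asymmetric quantization parameters from an observed [min, max] range.
Params CalcAffineParams(float min, float max, int quant_max, int quant_min) {
  if (min > 0) min = 0;  // range must contain 0 so that 0.0f quantizes exactly
  if (max < 0) max = 0;
  Params p;
  const double range = static_cast<double>(max) - min;
  p.scale = range > 0 ? range / (quant_max - quant_min) : 1.0;  // avoid div-by-zero on constant tensors
  p.zeroPoint = static_cast<int>(std::round(quant_min - min / p.scale));
  return p;
}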
@@ -350,11 +355,16 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_
       MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
       return status;
     }
-    // update data and datatype
+    // do quantization
     for (uint32_t j = 0; j < oneFilterSize; j++) {
-      float rawData = rawDatas[i + j * oneFilterSize];
+      auto index = j + i * channels;
+      if (index >= shapeSize) {
+        MS_LOG(ERROR) << "over flow!";
+        return RET_ERROR;
+      }
+      float rawData = rawDatas[index];
       auto qData = QuantizeData<int8_t>(rawData, quantParam.get(), quant_max, quant_min);
-      qDatas[i + j * oneFilterSize] = qData;
+      qDatas[index] = qData;
     }
     weightPtr->set_quant_param(quantParam);
   }
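QuantizeData<int8_t> then maps each float onto the integer grid defined by those parameters. A sketch of the usual round-and-clamp step, again mirroring the call site's argument order; the real template's rounding behavior is an assumption:

#include <algorithm>
#include <cmath>
#include <cstdint>

// Affine quantize one value: scale, shift by the zero point, clamp to range.
int8_t QuantizeValue(float real, double scale, int zero_point,
                     int quant_max, int quant_min) {
  int q = zero_point + static_cast<int>(std::round(real / scale));
  q = std::min(quant_max, std::max(quant_min, q));  // clamp into [quant_min, quant_max]
  return static_cast<int8_t>(q);
}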
