!4484 fix per-layer quantization

Merge pull request !4484 from xutianchun/post_quant
pull/4484/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 6d549417f0

@ -201,14 +201,16 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
}
auto activate_index = node->inputIndex[i];
auto tensor_input = metaGraphT->allTensors[activate_index].get();
std::unique_ptr<schema::QuantParamT> input_quant_param =
std::make_unique<schema::QuantParamT>(input_quant_params[i]);
MS_LOG(DEBUG) << "[input]node: " << node->name << " scale: " << input_quant_param->scale
<< " zp: " << input_quant_param->zeroPoint;
tensor_input->quantParams.emplace_back(std::move(input_quant_param));
if (!(node_type == schema::PrimitiveType_QuantDTypeCast &&
primitiveT_value->GetPrimitiveT()->value.AsQuantDTypeCast()->srcT == kNumberTypeFloat32)) {
tensor_input->dataType = kNumberTypeInt8;
if (tensor_input->quantParams.empty()) {
std::unique_ptr<schema::QuantParamT> input_quant_param =
std::make_unique<schema::QuantParamT>(input_quant_params[i]);
MS_LOG(DEBUG) << "[input]node: " << node->name << " scale: " << input_quant_param->scale
<< " zp: " << input_quant_param->zeroPoint;
tensor_input->quantParams.emplace_back(std::move(input_quant_param));
if (!(node_type == schema::PrimitiveType_QuantDTypeCast &&
primitiveT_value->GetPrimitiveT()->value.AsQuantDTypeCast()->srcT == kNumberTypeFloat32)) {
tensor_input->dataType = kNumberTypeInt8;
}
}
}
@ -219,11 +221,13 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
if (output_quant_params.empty()) {
MS_LOG(WARNING) << "node: " << node->name << " output quant params is empty";
} else {
std::unique_ptr<schema::QuantParamT> output_quant_param =
std::make_unique<schema::QuantParamT>(output_quant_params[0]);
MS_LOG(DEBUG) << "[output]node: " << node->name << " scale: " << output_quant_param->scale
<< " zp: " << output_quant_param->zeroPoint;
tensor_output->quantParams.emplace_back(std::move(output_quant_param));
if (tensor_output->quantParams.empty()) {
std::unique_ptr<schema::QuantParamT> output_quant_param =
std::make_unique<schema::QuantParamT>(output_quant_params[0]);
MS_LOG(DEBUG) << "[output]node: " << node->name << " scale: " << output_quant_param->scale
<< " zp: " << output_quant_param->zeroPoint;
tensor_output->quantParams.emplace_back(std::move(output_quant_param));
}
}
if (!(node_type == schema::PrimitiveType_QuantDTypeCast &&
primitiveT_value->GetPrimitiveT()->value.AsQuantDTypeCast()->dstT == kNumberTypeFloat32)) {

@ -62,7 +62,7 @@ struct DivergInfo {
this->bin_num = bins;
this->bit_num = bits;
histogram.resize(bin_num);
max = FLT_MIN;
max = -FLT_MAX;
min = FLT_MAX;
this->quant_max = quant_max;
this->quant_min = quant_min;

@ -313,7 +313,7 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_
MS_LOG(ERROR) << "weight dims size error: " << dims.size() << " Back to per layer.";
per_channel = false;
} else {
uint32_t channels = dims[3];
uint32_t channels = dims[0];
if (channels == 0) {
MS_LOG(ERROR) << "channels is 0";
return RET_ERROR;
@ -325,7 +325,7 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_
// at now for tflite model, Conv2D's weight format is KHWC, so is DepthwiseConv2D
// if TransWeightFormat is done before PostTraingingQuantization, the DepthwiseCon2D's weight is CHWK
size_t shapeSize = weightPtr->tensor_shape_size();
auto channels = dims[3];
auto channels = dims[0];
size_t oneFilterSize = shapeSize / channels;
auto *rawDatas = reinterpret_cast<const float *>(weightPtr->tensor_addr());
if (rawDatas == nullptr) {
@ -334,15 +334,20 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_
}
float min = FLT_MAX;
float max = FLT_MIN;
float max = -FLT_MAX;
weightPtr->quant_param().clear();
vector<int8_t> qDatas(shapeSize);
for (uint32_t i = 0; i < channels; i++) {
// find min and max
for (uint32_t j = 0; j < oneFilterSize; j++) {
min = std::min(min, rawDatas[i + j * oneFilterSize]);
max = std::max(max, rawDatas[i + j * oneFilterSize]);
auto index = j + i * channels;
if (index >= shapeSize) {
MS_LOG(ERROR) << "over flow!";
return RET_ERROR;
}
min = std::min(min, rawDatas[index]);
max = std::max(max, rawDatas[index]);
}
std::unique_ptr<AnfQuantParam> quantParam = std::unique_ptr<AnfQuantParam>(new AnfQuantParam);
STATUS status = CalQuantizationParams(quantParam, min, max, false, quant_max, quant_min, bitNum);
@ -350,11 +355,16 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_
MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
return status;
}
// update data and datatype
// do quantization
for (uint32_t j = 0; j < oneFilterSize; j++) {
float rawData = rawDatas[i + j * oneFilterSize];
auto index = j + i * channels;
if (index >= shapeSize) {
MS_LOG(ERROR) << "over flow!";
return RET_ERROR;
}
float rawData = rawDatas[index];
auto qData = QuantizeData<int8_t>(rawData, quantParam.get(), quant_max, quant_min);
qDatas[i + j * oneFilterSize] = qData;
qDatas[index] = qData;
}
weightPtr->set_quant_param(quantParam);
}

Loading…
Cancel
Save