!5752 Fix bug of loss of quantization parameters in quantized models.

Merge pull request !5752 from wangshaocong/quant_params
pull/5752/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit a7f68f1045

@ -64,64 +64,61 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me
MS_ASSERT(dst_node != nullptr); MS_ASSERT(dst_node != nullptr);
// add quant param // add quant param
dst_node->quantType = primitive->GetQuantType(); dst_node->quantType = primitive->GetQuantType();
if (dst_node->quantType == schema::QuantType_PostTraining || dst_node->quantType == schema::QuantType_AwareTraining MS_LOG(DEBUG) << "node: " << dst_node->name << " add QuantParam";
|| dst_node->quantType == schema::QuantType_WeightQuant) { // activation
MS_LOG(DEBUG) << "node: " << dst_node->name << " add QuantParam"; auto input_quant_params = primitive->GetInputQuantParams();
// activation auto node_type = (schema::PrimitiveType)primitive->Type();
auto input_quant_params = primitive->GetInputQuantParams(); if (input_quant_params.empty()) {
auto node_type = (schema::PrimitiveType)primitive->Type(); MS_LOG(WARNING) << "node: " << dst_node->name << " input quant params is empty";
if (input_quant_params.empty()) { return RET_OK;
MS_LOG(WARNING) << "node: " << dst_node->name << " input quant params is empty"; }
return RET_OK; for (size_t i = 0; i < input_quant_params.size(); i++) {
if (i >= dst_node->inputIndex.size()) {
MS_LOG(ERROR) << "node: " << dst_node->name << " input has " << input_quant_params.size()
<< " quant_params; but only " << dst_node->inputIndex.size() << " input";
break;
} }
for (size_t i = 0; i < input_quant_params.size(); i++) { auto activate_index = dst_node->inputIndex[i];
if (i >= dst_node->inputIndex.size()) { auto tensor_input = meta_graph->allTensors[activate_index].get();
MS_LOG(ERROR) << "node: " << dst_node->name << " input has " << input_quant_params.size() if (tensor_input->quantParams.empty()) {
<< " quant_params; but only " << dst_node->inputIndex.size() << " input"; for (auto input_quant_param : input_quant_params[i]) {
break; std::unique_ptr<schema::QuantParamT> input_quant_param_ptr =
} std::make_unique<schema::QuantParamT>(input_quant_param);
auto activate_index = dst_node->inputIndex[i]; MS_LOG(DEBUG) << "[input][" << i << "]node: " << dst_node->name << " scale: " << input_quant_param_ptr->scale
auto tensor_input = meta_graph->allTensors[activate_index].get(); << " zp: " << input_quant_param_ptr->zeroPoint;
if (tensor_input->quantParams.empty()) { tensor_input->quantParams.emplace_back(std::move(input_quant_param_ptr));
for (auto input_quant_param : input_quant_params[i]) {
std::unique_ptr<schema::QuantParamT> input_quant_param_ptr =
std::make_unique<schema::QuantParamT>(input_quant_param);
MS_LOG(DEBUG) << "[input][" << i << "]node: " << dst_node->name << " scale: " << input_quant_param_ptr->scale
<< " zp: " << input_quant_param_ptr->zeroPoint;
tensor_input->quantParams.emplace_back(std::move(input_quant_param_ptr));
}
} }
} }
}
// output // output
auto output_index = dst_node->outputIndex[0]; auto output_index = dst_node->outputIndex[0];
auto tensor_output = meta_graph->allTensors[output_index].get(); auto tensor_output = meta_graph->allTensors[output_index].get();
auto output_quant_params = primitive->GetOutputQuantParams(); auto output_quant_params = primitive->GetOutputQuantParams();
if (output_quant_params.empty()) { if (output_quant_params.empty()) {
if (node_type != schema::PrimitiveType_QuantDTypeCast) { if (node_type != schema::PrimitiveType_QuantDTypeCast) {
MS_LOG(DEBUG) << "node: " << dst_node->name << " output quant params is empty"; MS_LOG(DEBUG) << "node: " << dst_node->name << " output quant params is empty";
} }
} else { } else {
for (auto output_quant_param : output_quant_params[0]) { for (auto output_quant_param : output_quant_params[0]) {
if (tensor_output->quantParams.empty() && dst_node->quantType != schema::QuantType_WeightQuant) { if (tensor_output->quantParams.empty() && dst_node->quantType != schema::QuantType_WeightQuant) {
std::unique_ptr<schema::QuantParamT> output_quant_param_ptr = std::unique_ptr<schema::QuantParamT> output_quant_param_ptr =
std::make_unique<schema::QuantParamT>(output_quant_param); std::make_unique<schema::QuantParamT>(output_quant_param);
MS_LOG(DEBUG) << "[output]node: " << dst_node->name << " scale: " << output_quant_param_ptr->scale MS_LOG(DEBUG) << "[output]node: " << dst_node->name << " scale: " << output_quant_param_ptr->scale
<< " zp: " << output_quant_param_ptr->zeroPoint; << " zp: " << output_quant_param_ptr->zeroPoint;
tensor_output->quantParams.emplace_back(std::move(output_quant_param_ptr)); tensor_output->quantParams.emplace_back(std::move(output_quant_param_ptr));
}
} }
} }
if (dst_node->quantType == schema::QuantType_PostTraining) { }
if (node_type != schema::PrimitiveType_QuantDTypeCast) { if (dst_node->quantType == schema::QuantType_PostTraining) {
if (node_type != schema::PrimitiveType_QuantDTypeCast) {
tensor_output->dataType = kNumberTypeInt8;
} else {
MS_ASSERT(utils::isa<std::shared_ptr<QuantDTypeCast>>(primitive));
auto primc = utils::cast<std::shared_ptr<QuantDTypeCast>>(primitive);
MS_ASSERT(primc != nullptr);
if (primc->GetDstT() != kNumberTypeFloat32) {
tensor_output->dataType = kNumberTypeInt8; tensor_output->dataType = kNumberTypeInt8;
} else {
MS_ASSERT(utils::isa<std::shared_ptr<QuantDTypeCast>>(primitive));
auto primc = utils::cast<std::shared_ptr<QuantDTypeCast>>(primitive);
MS_ASSERT(primc != nullptr);
if (primc->GetDstT() != kNumberTypeFloat32) {
tensor_output->dataType = kNumberTypeInt8;
}
} }
} }
} }

@ -83,12 +83,15 @@ ValueNodePtr AnfImporterFromMetaGraphT::ConvertPrimitive(const std::unique_ptr<s
auto primitiveCValue = PrimitiveC::UnPackFromSchemaPrimitiveT(cNode->primitive.release()); auto primitiveCValue = PrimitiveC::UnPackFromSchemaPrimitiveT(cNode->primitive.release());
cNode->primitive = nullptr; cNode->primitive = nullptr;
// add quant parameter // add quant parameter
if (cNode->quantType == schema::QuantType_AwareTraining) { if (cNode->quantType != schema::QuantType_PostTraining) {
primitiveCValue->SetQuantType(cNode->quantType); primitiveCValue->SetQuantType(cNode->quantType);
for (int index : cNode->inputIndex) { for (int index : cNode->inputIndex) {
if (meta_graph_->allTensors[index]->quantParams.size() > 0) { if (meta_graph_->allTensors[index]->quantParams.size() > 0) {
std::vector<schema::QuantParamT> quant_params = {*(meta_graph_->allTensors[index]->quantParams[0])}; std::vector<schema::QuantParamT> quant_params = {*(meta_graph_->allTensors[index]->quantParams[0])};
primitiveCValue->AddInputQuantParam(quant_params); primitiveCValue->AddInputQuantParam(quant_params);
} else {
std::vector<schema::QuantParamT> empty_quant_params;
primitiveCValue->AddInputQuantParam(empty_quant_params);
} }
} }
for (int index : cNode->outputIndex) { for (int index : cNode->outputIndex) {

@ -38,27 +38,37 @@ STATUS TfliteDequantizeParser::Parse(const std::unique_ptr<tflite::OperatorT> &t
return RET_NULL_PTR; return RET_NULL_PTR;
} }
std::unique_ptr<schema::CastT> attr = std::make_unique<schema::CastT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
const auto &in_tensor = tflite_tensors[tflite_op->inputs[0]]; const auto &in_tensor = tflite_tensors[tflite_op->inputs[0]];
if (in_tensor == nullptr) { if (in_tensor == nullptr) {
MS_LOG(ERROR) << "input tensor is null"; MS_LOG(ERROR) << "input tensor is null";
return RET_NULL_PTR; return RET_NULL_PTR;
} }
attr->srcT = GetTfliteDataType(in_tensor->type);
const auto &out_tensor = tflite_tensors[tflite_op->outputs[0]]; const auto &out_tensor = tflite_tensors[tflite_op->outputs[0]];
if (out_tensor == nullptr) { if (out_tensor == nullptr) {
MS_LOG(ERROR) << "output tensor is null"; MS_LOG(ERROR) << "output tensor is null";
return RET_NULL_PTR; return RET_NULL_PTR;
} }
attr->dstT = GetTfliteDataType(out_tensor->type); if (GetTfliteDataType(in_tensor->type) == kNumberTypeInt8) {
std::unique_ptr<schema::QuantDTypeCastT> attr = std::make_unique<schema::QuantDTypeCastT>();
op->primitive->value.type = schema::PrimitiveType_Cast; if (attr == nullptr) {
op->primitive->value.value = attr.release(); MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
attr->srcT = GetTfliteDataType(in_tensor->type);
attr->dstT = GetTfliteDataType(out_tensor->type);
op->primitive->value.value = attr.release();
op->primitive->value.type = schema::PrimitiveType_QuantDTypeCast;
} else {
std::unique_ptr<schema::CastT> attr = std::make_unique<schema::CastT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
attr->srcT = GetTfliteDataType(in_tensor->type);
attr->dstT = GetTfliteDataType(out_tensor->type);
op->primitive->value.value = attr.release();
op->primitive->value.type = schema::PrimitiveType_Cast;
}
AddOpInput(op, tensors_id, tensors_format, tensors_id_map, AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);

@ -200,6 +200,24 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An
FreeTensors(&input_tensors, &output_tensors); FreeTensors(&input_tensors, &output_tensors);
return nullptr; return nullptr;
} }
auto inputQuantParams = lite_primitive->GetInputQuantParams();
for (size_t m = 0; m < inputQuantParams.size(); m++) {
for (auto inputQuantParam : inputQuantParams[m]) {
lite::tensor::QuantArg quant_arg{};
quant_arg.scale = inputQuantParam.scale;
quant_arg.zeroPoint = inputQuantParam.zeroPoint;
input_tensors[m]->AddQuantParam(quant_arg);
}
}
auto outputQuantParams = lite_primitive->GetOutputQuantParams();
for (size_t m = 0; m < outputQuantParams.size(); m++) {
for (auto outputQuantParam : outputQuantParams[m]) {
lite::tensor::QuantArg quant_arg{};
quant_arg.scale = outputQuantParam.scale;
quant_arg.zeroPoint = outputQuantParam.zeroPoint;
output_tensors[m]->AddQuantParam(quant_arg);
}
}
// here, input_tensor's format need to be transposed nhwc according to fmkType, // here, input_tensor's format need to be transposed nhwc according to fmkType,
// but for the time being, we only transpose the tensor with 0/1/2/3D. // but for the time being, we only transpose the tensor with 0/1/2/3D.
// Others should be added in future. // Others should be added in future.

Loading…
Cancel
Save