!5752 Fix bug of loss of quantization parameters in quantized models.

Merge pull request !5752 from wangshaocong/quant_params
pull/5752/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit a7f68f1045

@ -64,8 +64,6 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me
MS_ASSERT(dst_node != nullptr); MS_ASSERT(dst_node != nullptr);
// add quant param // add quant param
dst_node->quantType = primitive->GetQuantType(); dst_node->quantType = primitive->GetQuantType();
if (dst_node->quantType == schema::QuantType_PostTraining || dst_node->quantType == schema::QuantType_AwareTraining
|| dst_node->quantType == schema::QuantType_WeightQuant) {
MS_LOG(DEBUG) << "node: " << dst_node->name << " add QuantParam"; MS_LOG(DEBUG) << "node: " << dst_node->name << " add QuantParam";
// activation // activation
auto input_quant_params = primitive->GetInputQuantParams(); auto input_quant_params = primitive->GetInputQuantParams();
@ -124,7 +122,6 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me
} }
} }
} }
}
return RET_OK; return RET_OK;
} }

@ -83,12 +83,15 @@ ValueNodePtr AnfImporterFromMetaGraphT::ConvertPrimitive(const std::unique_ptr<s
auto primitiveCValue = PrimitiveC::UnPackFromSchemaPrimitiveT(cNode->primitive.release()); auto primitiveCValue = PrimitiveC::UnPackFromSchemaPrimitiveT(cNode->primitive.release());
cNode->primitive = nullptr; cNode->primitive = nullptr;
// add quant parameter // add quant parameter
if (cNode->quantType == schema::QuantType_AwareTraining) { if (cNode->quantType != schema::QuantType_PostTraining) {
primitiveCValue->SetQuantType(cNode->quantType); primitiveCValue->SetQuantType(cNode->quantType);
for (int index : cNode->inputIndex) { for (int index : cNode->inputIndex) {
if (meta_graph_->allTensors[index]->quantParams.size() > 0) { if (meta_graph_->allTensors[index]->quantParams.size() > 0) {
std::vector<schema::QuantParamT> quant_params = {*(meta_graph_->allTensors[index]->quantParams[0])}; std::vector<schema::QuantParamT> quant_params = {*(meta_graph_->allTensors[index]->quantParams[0])};
primitiveCValue->AddInputQuantParam(quant_params); primitiveCValue->AddInputQuantParam(quant_params);
} else {
std::vector<schema::QuantParamT> empty_quant_params;
primitiveCValue->AddInputQuantParam(empty_quant_params);
} }
} }
for (int index : cNode->outputIndex) { for (int index : cNode->outputIndex) {

@ -38,27 +38,37 @@ STATUS TfliteDequantizeParser::Parse(const std::unique_ptr<tflite::OperatorT> &t
return RET_NULL_PTR; return RET_NULL_PTR;
} }
std::unique_ptr<schema::CastT> attr = std::make_unique<schema::CastT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
const auto &in_tensor = tflite_tensors[tflite_op->inputs[0]]; const auto &in_tensor = tflite_tensors[tflite_op->inputs[0]];
if (in_tensor == nullptr) { if (in_tensor == nullptr) {
MS_LOG(ERROR) << "input tensor is null"; MS_LOG(ERROR) << "input tensor is null";
return RET_NULL_PTR; return RET_NULL_PTR;
} }
attr->srcT = GetTfliteDataType(in_tensor->type);
const auto &out_tensor = tflite_tensors[tflite_op->outputs[0]]; const auto &out_tensor = tflite_tensors[tflite_op->outputs[0]];
if (out_tensor == nullptr) { if (out_tensor == nullptr) {
MS_LOG(ERROR) << "output tensor is null"; MS_LOG(ERROR) << "output tensor is null";
return RET_NULL_PTR; return RET_NULL_PTR;
} }
if (GetTfliteDataType(in_tensor->type) == kNumberTypeInt8) {
std::unique_ptr<schema::QuantDTypeCastT> attr = std::make_unique<schema::QuantDTypeCastT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
attr->srcT = GetTfliteDataType(in_tensor->type);
attr->dstT = GetTfliteDataType(out_tensor->type);
op->primitive->value.value = attr.release();
op->primitive->value.type = schema::PrimitiveType_QuantDTypeCast;
} else {
std::unique_ptr<schema::CastT> attr = std::make_unique<schema::CastT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
attr->srcT = GetTfliteDataType(in_tensor->type);
attr->dstT = GetTfliteDataType(out_tensor->type); attr->dstT = GetTfliteDataType(out_tensor->type);
op->primitive->value.type = schema::PrimitiveType_Cast;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
op->primitive->value.type = schema::PrimitiveType_Cast;
}
AddOpInput(op, tensors_id, tensors_format, tensors_id_map, AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);

@ -200,6 +200,24 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An
FreeTensors(&input_tensors, &output_tensors); FreeTensors(&input_tensors, &output_tensors);
return nullptr; return nullptr;
} }
auto inputQuantParams = lite_primitive->GetInputQuantParams();
for (size_t m = 0; m < inputQuantParams.size(); m++) {
for (auto inputQuantParam : inputQuantParams[m]) {
lite::tensor::QuantArg quant_arg{};
quant_arg.scale = inputQuantParam.scale;
quant_arg.zeroPoint = inputQuantParam.zeroPoint;
input_tensors[m]->AddQuantParam(quant_arg);
}
}
auto outputQuantParams = lite_primitive->GetOutputQuantParams();
for (size_t m = 0; m < outputQuantParams.size(); m++) {
for (auto outputQuantParam : outputQuantParams[m]) {
lite::tensor::QuantArg quant_arg{};
quant_arg.scale = outputQuantParam.scale;
quant_arg.zeroPoint = outputQuantParam.zeroPoint;
output_tensors[m]->AddQuantParam(quant_arg);
}
}
// here, input_tensor's format need to be transposed nhwc according to fmkType, // here, input_tensor's format need to be transposed nhwc according to fmkType,
// but for the time being, we only transpose the tensor with 0/1/2/3D. // but for the time being, we only transpose the tensor with 0/1/2/3D.
// Others should be added in future. // Others should be added in future.

Loading…
Cancel
Save