!5752 Fix bug of loss of quantization parameters in quantized models.

Merge pull request !5752 from wangshaocong/quant_params
pull/5752/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit a7f68f1045

@ -64,64 +64,61 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me
MS_ASSERT(dst_node != nullptr);
// add quant param
dst_node->quantType = primitive->GetQuantType();
if (dst_node->quantType == schema::QuantType_PostTraining || dst_node->quantType == schema::QuantType_AwareTraining
|| dst_node->quantType == schema::QuantType_WeightQuant) {
MS_LOG(DEBUG) << "node: " << dst_node->name << " add QuantParam";
// activation
auto input_quant_params = primitive->GetInputQuantParams();
auto node_type = (schema::PrimitiveType)primitive->Type();
if (input_quant_params.empty()) {
MS_LOG(WARNING) << "node: " << dst_node->name << " input quant params is empty";
return RET_OK;
MS_LOG(DEBUG) << "node: " << dst_node->name << " add QuantParam";
// activation
auto input_quant_params = primitive->GetInputQuantParams();
auto node_type = (schema::PrimitiveType)primitive->Type();
if (input_quant_params.empty()) {
MS_LOG(WARNING) << "node: " << dst_node->name << " input quant params is empty";
return RET_OK;
}
for (size_t i = 0; i < input_quant_params.size(); i++) {
if (i >= dst_node->inputIndex.size()) {
MS_LOG(ERROR) << "node: " << dst_node->name << " input has " << input_quant_params.size()
<< " quant_params; but only " << dst_node->inputIndex.size() << " input";
break;
}
for (size_t i = 0; i < input_quant_params.size(); i++) {
if (i >= dst_node->inputIndex.size()) {
MS_LOG(ERROR) << "node: " << dst_node->name << " input has " << input_quant_params.size()
<< " quant_params; but only " << dst_node->inputIndex.size() << " input";
break;
}
auto activate_index = dst_node->inputIndex[i];
auto tensor_input = meta_graph->allTensors[activate_index].get();
if (tensor_input->quantParams.empty()) {
for (auto input_quant_param : input_quant_params[i]) {
std::unique_ptr<schema::QuantParamT> input_quant_param_ptr =
std::make_unique<schema::QuantParamT>(input_quant_param);
MS_LOG(DEBUG) << "[input][" << i << "]node: " << dst_node->name << " scale: " << input_quant_param_ptr->scale
<< " zp: " << input_quant_param_ptr->zeroPoint;
tensor_input->quantParams.emplace_back(std::move(input_quant_param_ptr));
}
auto activate_index = dst_node->inputIndex[i];
auto tensor_input = meta_graph->allTensors[activate_index].get();
if (tensor_input->quantParams.empty()) {
for (auto input_quant_param : input_quant_params[i]) {
std::unique_ptr<schema::QuantParamT> input_quant_param_ptr =
std::make_unique<schema::QuantParamT>(input_quant_param);
MS_LOG(DEBUG) << "[input][" << i << "]node: " << dst_node->name << " scale: " << input_quant_param_ptr->scale
<< " zp: " << input_quant_param_ptr->zeroPoint;
tensor_input->quantParams.emplace_back(std::move(input_quant_param_ptr));
}
}
}
// output
auto output_index = dst_node->outputIndex[0];
auto tensor_output = meta_graph->allTensors[output_index].get();
auto output_quant_params = primitive->GetOutputQuantParams();
if (output_quant_params.empty()) {
if (node_type != schema::PrimitiveType_QuantDTypeCast) {
MS_LOG(DEBUG) << "node: " << dst_node->name << " output quant params is empty";
}
} else {
for (auto output_quant_param : output_quant_params[0]) {
if (tensor_output->quantParams.empty() && dst_node->quantType != schema::QuantType_WeightQuant) {
std::unique_ptr<schema::QuantParamT> output_quant_param_ptr =
std::make_unique<schema::QuantParamT>(output_quant_param);
MS_LOG(DEBUG) << "[output]node: " << dst_node->name << " scale: " << output_quant_param_ptr->scale
<< " zp: " << output_quant_param_ptr->zeroPoint;
tensor_output->quantParams.emplace_back(std::move(output_quant_param_ptr));
}
// output
auto output_index = dst_node->outputIndex[0];
auto tensor_output = meta_graph->allTensors[output_index].get();
auto output_quant_params = primitive->GetOutputQuantParams();
if (output_quant_params.empty()) {
if (node_type != schema::PrimitiveType_QuantDTypeCast) {
MS_LOG(DEBUG) << "node: " << dst_node->name << " output quant params is empty";
}
} else {
for (auto output_quant_param : output_quant_params[0]) {
if (tensor_output->quantParams.empty() && dst_node->quantType != schema::QuantType_WeightQuant) {
std::unique_ptr<schema::QuantParamT> output_quant_param_ptr =
std::make_unique<schema::QuantParamT>(output_quant_param);
MS_LOG(DEBUG) << "[output]node: " << dst_node->name << " scale: " << output_quant_param_ptr->scale
<< " zp: " << output_quant_param_ptr->zeroPoint;
tensor_output->quantParams.emplace_back(std::move(output_quant_param_ptr));
}
}
if (dst_node->quantType == schema::QuantType_PostTraining) {
if (node_type != schema::PrimitiveType_QuantDTypeCast) {
}
if (dst_node->quantType == schema::QuantType_PostTraining) {
if (node_type != schema::PrimitiveType_QuantDTypeCast) {
tensor_output->dataType = kNumberTypeInt8;
} else {
MS_ASSERT(utils::isa<std::shared_ptr<QuantDTypeCast>>(primitive));
auto primc = utils::cast<std::shared_ptr<QuantDTypeCast>>(primitive);
MS_ASSERT(primc != nullptr);
if (primc->GetDstT() != kNumberTypeFloat32) {
tensor_output->dataType = kNumberTypeInt8;
} else {
MS_ASSERT(utils::isa<std::shared_ptr<QuantDTypeCast>>(primitive));
auto primc = utils::cast<std::shared_ptr<QuantDTypeCast>>(primitive);
MS_ASSERT(primc != nullptr);
if (primc->GetDstT() != kNumberTypeFloat32) {
tensor_output->dataType = kNumberTypeInt8;
}
}
}
}

@ -83,12 +83,15 @@ ValueNodePtr AnfImporterFromMetaGraphT::ConvertPrimitive(const std::unique_ptr<s
auto primitiveCValue = PrimitiveC::UnPackFromSchemaPrimitiveT(cNode->primitive.release());
cNode->primitive = nullptr;
// add quant parameter
if (cNode->quantType == schema::QuantType_AwareTraining) {
if (cNode->quantType != schema::QuantType_PostTraining) {
primitiveCValue->SetQuantType(cNode->quantType);
for (int index : cNode->inputIndex) {
if (meta_graph_->allTensors[index]->quantParams.size() > 0) {
std::vector<schema::QuantParamT> quant_params = {*(meta_graph_->allTensors[index]->quantParams[0])};
primitiveCValue->AddInputQuantParam(quant_params);
} else {
std::vector<schema::QuantParamT> empty_quant_params;
primitiveCValue->AddInputQuantParam(empty_quant_params);
}
}
for (int index : cNode->outputIndex) {

@ -38,27 +38,37 @@ STATUS TfliteDequantizeParser::Parse(const std::unique_ptr<tflite::OperatorT> &t
return RET_NULL_PTR;
}
std::unique_ptr<schema::CastT> attr = std::make_unique<schema::CastT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
const auto &in_tensor = tflite_tensors[tflite_op->inputs[0]];
if (in_tensor == nullptr) {
MS_LOG(ERROR) << "input tensor is null";
return RET_NULL_PTR;
}
attr->srcT = GetTfliteDataType(in_tensor->type);
const auto &out_tensor = tflite_tensors[tflite_op->outputs[0]];
if (out_tensor == nullptr) {
MS_LOG(ERROR) << "output tensor is null";
return RET_NULL_PTR;
}
attr->dstT = GetTfliteDataType(out_tensor->type);
op->primitive->value.type = schema::PrimitiveType_Cast;
op->primitive->value.value = attr.release();
if (GetTfliteDataType(in_tensor->type) == kNumberTypeInt8) {
std::unique_ptr<schema::QuantDTypeCastT> attr = std::make_unique<schema::QuantDTypeCastT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
attr->srcT = GetTfliteDataType(in_tensor->type);
attr->dstT = GetTfliteDataType(out_tensor->type);
op->primitive->value.value = attr.release();
op->primitive->value.type = schema::PrimitiveType_QuantDTypeCast;
} else {
std::unique_ptr<schema::CastT> attr = std::make_unique<schema::CastT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
attr->srcT = GetTfliteDataType(in_tensor->type);
attr->dstT = GetTfliteDataType(out_tensor->type);
op->primitive->value.value = attr.release();
op->primitive->value.type = schema::PrimitiveType_Cast;
}
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);

@ -200,6 +200,24 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An
FreeTensors(&input_tensors, &output_tensors);
return nullptr;
}
auto inputQuantParams = lite_primitive->GetInputQuantParams();
for (size_t m = 0; m < inputQuantParams.size(); m++) {
for (auto inputQuantParam : inputQuantParams[m]) {
lite::tensor::QuantArg quant_arg{};
quant_arg.scale = inputQuantParam.scale;
quant_arg.zeroPoint = inputQuantParam.zeroPoint;
input_tensors[m]->AddQuantParam(quant_arg);
}
}
auto outputQuantParams = lite_primitive->GetOutputQuantParams();
for (size_t m = 0; m < outputQuantParams.size(); m++) {
for (auto outputQuantParam : outputQuantParams[m]) {
lite::tensor::QuantArg quant_arg{};
quant_arg.scale = outputQuantParam.scale;
quant_arg.zeroPoint = outputQuantParam.zeroPoint;
output_tensors[m]->AddQuantParam(quant_arg);
}
}
// here, input_tensor's format need to be transposed nhwc according to fmkType,
// but for the time being, we only transpose the tensor with 0/1/2/3D.
// Others should be added in future.

Loading…
Cancel
Save