From c1dd0b8e3d7b51bcc9558a02352ad0a8a3f41ddf Mon Sep 17 00:00:00 2001 From: xuanyue Date: Mon, 14 Dec 2020 14:17:29 +0800 Subject: [PATCH] fix onnx models convert to quantized bug --- .../converter/parser/onnx/onnx_model_parser.cc | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc index 70cdcd2f90..562f5d8ee1 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc @@ -158,16 +158,16 @@ STATUS OnnxModelParser::ConvertNodes() { status = RET_ERROR; continue; } - if (IsSpecialOnnxNode(onnx_node)) { - auto status_node = ConvertSpecialOnnxNode(onnx_node, primitive_c); - status = status == RET_OK ? status_node : status; - continue; - } status = ConvertOpQuantParams(onnx_node, primitive_c); if (status != RET_OK) { MS_LOG(ERROR) << "convert " << onnx_node.op_type() << " quant param failed."; continue; } + if (IsSpecialOnnxNode(onnx_node)) { + auto status_node = ConvertSpecialOnnxNode(onnx_node, primitive_c); + status = status == RET_OK ? status_node : status; + continue; + } // build CNode status = BuildCNode(onnx_node, primitive_c); if (status != RET_OK) { @@ -512,8 +512,10 @@ STATUS OnnxModelParser::BuildCNodeForGemm(const onnx::NodeProto &onnx_node, lite return RET_ERROR; } else { op_inputs.push_back(nodes_[onnx_node.input(i)]); + prim_ptr->AddInputQuantParam(primitive_c->input_quant_params().at(i)); } } + prim_ptr->AddOutputQuantParam(std::vector(1)); auto new_cnode = func_graph_ptr_->NewCNode(prim_ptr, op_inputs); new_cnode->set_fullname_with_scope("Gemm_MatMul_" + onnx_node.output(0)); new_cnode->set_abstract(std::make_shared(type_ptr, shape_vector)); @@ -526,6 +528,9 @@ STATUS OnnxModelParser::BuildCNodeForGemm(const onnx::NodeProto &onnx_node, lite } op_inputs.push_back(nodes_["Gemm_MatMul_" + onnx_node.output(0)]); op_inputs.push_back(nodes_[onnx_node.input(2)]); + prim_ptr->AddInputQuantParam(std::vector(1)); + prim_ptr->AddInputQuantParam(primitive_c->input_quant_params().at(2)); + prim_ptr->AddOutputQuantParam(primitive_c->output_quant_params().front()); auto new_cnode = func_graph_ptr_->NewCNode(prim_ptr, op_inputs); new_cnode->set_fullname_with_scope("Gemm_BiasAdd_" + onnx_node.output(0)); new_cnode->set_abstract(std::make_shared(type_ptr, shape_vector));