From 2366442c971d1d15f287e581661a48e9f74afe97 Mon Sep 17 00:00:00 2001
From: "public (04f0281d2d30)"
Date: Sat, 7 Nov 2020 11:16:19 +0800
Subject: [PATCH] Fix bugs in the ONNX model parser.

---
 .../legacy_optimizer/fusion/fusion_pass.cc    | 11 +++++-
 .../converter/parser/onnx/onnx_conv_parser.cc |  2 +-
 .../parser/onnx/onnx_model_parser.cc          | 34 +++++++++++++----
 .../converter/parser/onnx/onnx_model_parser.h |  4 +-
 .../converter/parser/onnx/onnx_node_parser.h  |  2 +-
 ...node_parser.cc => onnx_quantize_parser.cc} | 37 +++++++++----------
 ...l_node_parser.h => onnx_quantize_parser.h} | 12 +++---
 .../graph/weight_format_hardcode_pass.cc      |  6 ++-
 8 files changed, 69 insertions(+), 39 deletions(-)
 rename mindspore/lite/tools/converter/parser/onnx/{onnx_unuseful_node_parser.cc => onnx_quantize_parser.cc} (57%)
 rename mindspore/lite/tools/converter/parser/onnx/{onnx_unuseful_node_parser.h => onnx_quantize_parser.h} (71%)

diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/fusion_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/fusion/fusion_pass.cc
index e978d49878..15d95af7c0 100644
--- a/mindspore/lite/tools/converter/legacy_optimizer/fusion/fusion_pass.cc
+++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/fusion_pass.cc
@@ -29,6 +29,7 @@
 #include "tools/common/graph_util.h"
 #include "include/errorcode.h"
 #include "schema/inner/model_generated.h"
+#include "src/ops/primitive_c.h"
 
 namespace mindspore {
 namespace lite {
@@ -263,7 +264,15 @@ bool FusionPass::MatchTree(schema::MetaGraphT *graph, size_t nodeIdx, const std:
     return true;
   }
   for (auto preNodeIdx : preNodeIdxes) {
-    MS_ASSERT(subGraph->nodes.size() > preNodeIdx);
+    MS_ASSERT(graph->nodes.size() > preNodeIdx);
+    // Case of multiple outputs is not supported.
+    if (GetInputNodeIdx(*graph, preNodeIdx).size() > kDoubleNum ||
+        GetOutputNodeIdx(*graph, preNodeIdx).size() > kSingleNum) {
+      sinkIdes.erase((sinkIdes.end() - 1));
+      pathSinkIdes.erase((pathSinkIdes.end() - 1));
+      target->UnSetPath();
+      return false;
+    }
     // match left
     if (MatchTree(graph, preNodeIdx, target->left, sinkIdes, pathSinkIdes)) {
       // match right
diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc
index c8e6293882..54d197baea 100644
--- a/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc
+++ b/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc
@@ -75,6 +75,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
   attr->dilateW = 1;
   attr->group = 1;
   attr->padMode = schema::PadMode_NOTSET;
+  attr->format = schema::Format::Format_NCHW;
   // set opdef each attr params
   for (const auto &onnx_node_attr : onnx_node.attribute()) {
     if (onnx_node_attr.name() == "group") {
@@ -161,7 +162,6 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
     attr->channelOut = dims[0];
     attr->channelIn = dims[3] * attr->group;
   }
-  attr->format = schema::Format::Format_NCHW;
   attr->hasBias = onnx_node.input().size() == 3;
   if (onnx_node.op_type() == "ConvRelu" || onnx_node.op_type() == "Int8ConvRelu") {
     attr->activationType = schema::ActivationType_RELU;
diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc
index 6dea5819df..be92193aca 100644
--- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc
+++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc
@@ -244,6 +244,16 @@ STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node,
       MS_LOG(ERROR) << "memcpy_s failed";
       return RET_ERROR;
     }
+    // set quantParams to Int8GivenTensor.
+    std::unique_ptr<schema::QuantParamT> quant_param = std::make_unique<schema::QuantParamT>();
+    for (const auto &onnx_node_attr : onnx_node.attribute()) {
+      if (onnx_node_attr.name() == "Y_scale") {
+        quant_param->scale = onnx_node_attr.f();
+      } else if (onnx_node_attr.name() == "Y_zero_point") {
+        quant_param->zeroPoint = static_cast<int32_t>(onnx_node_attr.i());
+      }
+    }
+    tensor->quantParams.emplace_back(std::move(quant_param));
   } else {
     MS_LOG(ERROR) << "unsupported data type " << tensor->dataType;
     return RET_ERROR;
   }
@@ -256,9 +266,8 @@ STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node,
 }
 
 STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
-                                             schema::CNodeT *dst_op, schema::TensorT *dst_tensor,
-                                             TensorCache *tensor_cache, const QuantType &quantType,
-                                             schema::MetaGraphT *dst_graph) {
+                                             schema::CNodeT *dst_op, TensorCache *tensor_cache,
+                                             const QuantType &quantType, schema::MetaGraphT *dst_graph) {
   // change op_type() to name(), that is unique
   static bool interrupt = false;
   dst_op->name = onnx_node.op_type() + "_" + onnx_node.output(0);
@@ -267,7 +276,6 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph,
   MS_LOG(DEBUG) << "onnx op name " << onnx_node.op_type() << ", dst op name: " << dst_op->name << ", input size "
                 << onnx_node.input_size();
   // get the real op type
-  SetOpQuantParams(onnx_graph, onnx_node, dst_op, dst_tensor, tensor_cache);
   if (onnx_node.op_type() == "Loop") {
     NoSupportOp::GetInstance()->InsertOp(onnx_node.op_type());
     interrupt = true;
@@ -305,6 +313,13 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph,
     MS_LOG(ERROR) << "SetOpInputIndex failed";
     return RET_ERROR;
   }
+  if (dst_op->primitive->value.type == schema::PrimitiveType_Conv2D) {
+    auto &weight_tensor = tensor_cache->GetCachedTensor().at(dst_op->inputIndex.at(kWeightIndex));
+    weight_tensor->format = dst_op->primitive->value.AsConv2D()->format;
+  } else if (dst_op->primitive->value.type == schema::PrimitiveType_DeConv2D) {
+    auto &weight_tensor = tensor_cache->GetCachedTensor().at(dst_op->inputIndex.at(kWeightIndex));
+    weight_tensor->format = dst_op->primitive->value.AsDeConv2D()->format;
+  }
   // set op output index
   std::vector<std::string> node_outputs;
   (void)node_outputs.insert(node_outputs.begin(), onnx_node.output().begin(), onnx_node.output().end());
@@ -314,6 +329,13 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph,
     MS_LOG(ERROR) << "SetOpOutputIndex failed";
     return RET_ERROR;
   }
+  auto &output_tensor = tensor_cache->GetCachedTensor().at(dst_op->outputIndex.front());
+  if (output_tensor == nullptr) {
+    interrupt = true;
+    MS_LOG(ERROR) << "Output tensor of node " << onnx_node.op_type() << " is nullptr.";
+    return RET_ERROR;
+  }
+  SetOpQuantParams(onnx_graph, onnx_node, dst_op, output_tensor, tensor_cache);
   return RET_OK;
 }
 
@@ -572,9 +594,7 @@ int OnnxModelParser::ParseGraph(schema::MetaGraphT *dst_graph, schema::SubGraphT
     }
 
     std::unique_ptr<schema::CNodeT> dst_op = std::make_unique<schema::CNodeT>();
-    std::unique_ptr<schema::TensorT> dst_tensor = std::make_unique<schema::TensorT>();
-    status_node =
-      ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), dst_tensor.get(), &tensor_cache, quantType, dst_graph);
+    status_node = ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), &tensor_cache, quantType, dst_graph);
     if (status_node != RET_OK) {
       status = (status == RET_OK ? status_node : status);
       continue;
diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h
index 8a970dee3e..42200806db 100644
--- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h
+++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h
@@ -66,8 +66,8 @@ class OnnxModelParser : public ModelParser {
                                TensorCache *tensor_cache, int *index);
 
   STATUS ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
-                              schema::CNodeT *dst_op, schema::TensorT *dst_tensor, TensorCache *tensor_cache,
-                              const QuantType &quantType, schema::MetaGraphT *dst_graph);
+                              schema::CNodeT *dst_op, TensorCache *tensor_cache, const QuantType &quantType,
+                              schema::MetaGraphT *dst_graph);
 
   void ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
                          schema::SubGraphT *sub_graph, schema::MetaGraphT *graph, TensorCache *tensor_cache,
diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h
index 9a8298beba..30d2db0033 100644
--- a/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h
+++ b/mindspore/lite/tools/converter/parser/onnx/onnx_node_parser.h
@@ -24,7 +24,7 @@
 #include "include/errorcode.h"
 #include "src/common/log_adapter.h"
 #include "schema/inner/model_generated.h"
-
+#include "ir/dtype/type_id.h"
 namespace mindspore {
 namespace lite {
 class OnnxNodeParser {
diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_unuseful_node_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_quantize_parser.cc
similarity index 57%
rename from mindspore/lite/tools/converter/parser/onnx/onnx_unuseful_node_parser.cc
rename to mindspore/lite/tools/converter/parser/onnx/onnx_quantize_parser.cc
index 9fb09ce987..bc7bb87931 100644
--- a/mindspore/lite/tools/converter/parser/onnx/onnx_unuseful_node_parser.cc
+++ b/mindspore/lite/tools/converter/parser/onnx/onnx_quantize_parser.cc
@@ -14,14 +14,14 @@
  * limitations under the License.
  */
 
-#include "tools/converter/parser/onnx/onnx_unuseful_node_parser.h"
+#include "tools/converter/parser/onnx/onnx_quantize_parser.h"
 #include <memory>
 
 namespace mindspore {
 namespace lite {
-STATUS OnnxUnusefulNodeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
-                                     schema::CNodeT *op) {
-  MS_LOG(DEBUG) << "onnx UnusefulNodeParser";
+STATUS OnnxQuantizeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
+                                 schema::CNodeT *op) {
+  MS_LOG(DEBUG) << "onnx QuantizeDequantizeParser";
   if (op == nullptr) {
     MS_LOG(ERROR) << "op is null";
     return RET_NULL_PTR;
   }
@@ -32,30 +32,27 @@ STATUS OnnxUnusefulNodeParser::Parse(const onnx::GraphProto &onnx_graph, const o
     return RET_NULL_PTR;
   }
 
+  std::unique_ptr<schema::QuantDTypeCastT> attr = std::make_unique<schema::QuantDTypeCastT>();
+  if (attr == nullptr) {
+    MS_LOG(ERROR) << "new op failed.";
+    return RET_NULL_PTR;
+  }
   if (onnx_node.op_type() == "Int8Quantize") {
-    std::unique_ptr<schema::OnnxInt8QuantizeT> attr = std::make_unique<schema::OnnxInt8QuantizeT>();
-    if (attr == nullptr) {
-      MS_LOG(ERROR) << "new op failed";
-      return RET_NULL_PTR;
-    }
-    op->primitive->value.type = schema::PrimitiveType_OnnxInt8Quantize;
-    op->primitive->value.value = attr.release();
+    attr->srcT = kNumberTypeFloat32;
+    attr->dstT = kNumberTypeInt8;
   } else if (onnx_node.op_type() == "Int8Dequantize") {
-    std::unique_ptr<schema::OnnxInt8DequantizeT> attr = std::make_unique<schema::OnnxInt8DequantizeT>();
-    if (attr == nullptr) {
-      MS_LOG(ERROR) << "new op failed";
-      return RET_NULL_PTR;
-    }
-    op->primitive->value.type = schema::PrimitiveType_OnnxInt8Dequantize;
-    op->primitive->value.value = attr.release();
+    attr->srcT = kNumberTypeInt8;
+    attr->dstT = kNumberTypeFloat32;
   } else {
     MS_LOG(ERROR) << "Unsupported nodeType: " << onnx_node.op_type().c_str();
     return RET_ERROR;
   }
+  op->primitive->value.type = schema::PrimitiveType_QuantDTypeCast;
+  op->primitive->value.value = attr.release();
   return RET_OK;
 }
 
-OnnxNodeRegistrar g_onnxInt8QuantizeParser("Int8Quantize", new OnnxUnusefulNodeParser());
-OnnxNodeRegistrar g_onnxInt8DequantizeParser("Int8Dequantize", new OnnxUnusefulNodeParser());
+OnnxNodeRegistrar g_onnxInt8QuantizeParser("Int8Quantize", new OnnxQuantizeParser());
+OnnxNodeRegistrar g_onnxInt8DequantizeParser("Int8Dequantize", new OnnxQuantizeParser());
 }  // namespace lite
 }  // namespace mindspore
diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_unuseful_node_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_quantize_parser.h
similarity index 71%
rename from mindspore/lite/tools/converter/parser/onnx/onnx_unuseful_node_parser.h
rename to mindspore/lite/tools/converter/parser/onnx/onnx_quantize_parser.h
index dd0c40372b..182a38b697 100644
--- a/mindspore/lite/tools/converter/parser/onnx/onnx_unuseful_node_parser.h
+++ b/mindspore/lite/tools/converter/parser/onnx/onnx_quantize_parser.h
@@ -14,21 +14,21 @@
  * limitations under the License.
  */
 
-#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX__UNUSEFUL_PARSER_H
-#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX__UNUSEFUL_PARSER_H
+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_QUANTIZE_PARSER_H
+#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_QUANTIZE_PARSER_H
 
 #include "tools/converter/parser/onnx/onnx_node_parser.h"
 #include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
 
 namespace mindspore {
 namespace lite {
-class OnnxUnusefulNodeParser : public OnnxNodeParser {
+class OnnxQuantizeParser : public OnnxNodeParser {
  public:
-  OnnxUnusefulNodeParser() : OnnxNodeParser("UnusefulNode") {}
-  ~OnnxUnusefulNodeParser() override = default;
+  OnnxQuantizeParser() : OnnxNodeParser("Quantize") {}
+  ~OnnxQuantizeParser() override = default;
 
   STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
 };
 }  // namespace lite
 }  // namespace mindspore
-#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX__UNUSEFUL_PARSER_H
+#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_QUANTIZE_PARSER_H
diff --git a/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc b/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc
index bb6406daa8..e161e49ddc 100644
--- a/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc
+++ b/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc
@@ -79,7 +79,11 @@ lite::STATUS WeightFormatHardCodePass::HardCodeONNX(const AnfNodePtr &conv_node,
   // dedepth (C x K/group x kH x kW) group = channelIn ==> (C, multiplier, H, W)
   if (op_type == schema::PrimitiveType_Conv2D || op_type == schema::PrimitiveType_DepthwiseConv2D ||
       op_type == schema::PrimitiveType_DeConv2D || op_type == schema::PrimitiveType_DeDepthwiseConv2D) {
-    param_value->set_format(schema::Format::Format_KCHW);
+    if (param_value->format() == schema::Format::Format_NHWC) {
+      param_value->set_format(schema::Format::Format_KHWC);
+    } else {
+      param_value->set_format(schema::Format::Format_KCHW);
+    }
   } else {
     MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type)
                   << ", node: " << conv_node->fullname_with_scope();