From a4aa06db534e6078d0d7933e11ab503c7ea97573 Mon Sep 17 00:00:00 2001 From: xuanyue Date: Fri, 11 Sep 2020 13:08:58 +0800 Subject: [PATCH] mslite fix onnx conv\reshape\dropout\matmul parser, fix infershape_pass bug --- mindspore/lite/schema/ops.fbs | 1 + mindspore/lite/src/ops/matmul.cc | 2 +- mindspore/lite/src/ops/primitive_c.cc | 5 ++++ .../legacy_optimizer/graph/infershape_pass.cc | 30 ++++++++----------- .../converter/parser/onnx/onnx_conv_parser.cc | 10 +++++-- .../parser/onnx/onnx_dropout_parser.cc | 2 +- .../parser/onnx/onnx_matmul_parser.cc | 3 ++ .../parser/onnx/onnx_reshape_parser.cc | 29 ++++++------------ .../parser/tflite/tflite_arithmetic_parser.cc | 2 +- 9 files changed, 42 insertions(+), 42 deletions(-) diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index 02496c582a..53a0a7212f 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -549,6 +549,7 @@ table NetOutput { } table MatMul { + broadcast : bool = false; transposeA : bool = false; transposeB : bool = false; } diff --git a/mindspore/lite/src/ops/matmul.cc b/mindspore/lite/src/ops/matmul.cc index 65fb9a4e4f..fc6f01e043 100644 --- a/mindspore/lite/src/ops/matmul.cc +++ b/mindspore/lite/src/ops/matmul.cc @@ -80,7 +80,7 @@ int MatMul::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers: MS_LOG(ERROR) << "value_as_MatMul return nullptr"; return RET_ERROR; } - auto val_offset = schema::CreateMatMul(*fbb, attr->transposeA(), attr->transposeB()); + auto val_offset = schema::CreateMatMul(*fbb, attr->broadcast(), attr->transposeA(), attr->transposeB()); auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_MatMul, val_offset.o); fbb->Finish(prim_offset); return RET_OK; diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index 33dfd55ff7..186b427290 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -122,6 +122,7 @@ #include "src/ops/l2_norm.h" #include "src/ops/sparse_to_dense.h" #include "src/ops/detection_post_process.h" +#include "src/ops/dropout.h" #ifdef PRIMITIVE_WRITEABLE #include "tools/converter/quantizer/quantize_util.h" #endif @@ -617,6 +618,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { return new SparseToDense(primitive); case schema::PrimitiveType_DetectionPostProcess: return new DetectionPostProcess(primitive); + case schema::PrimitiveType_Dropout: + return new Dropout(primitive); #ifdef SUPPORT_TRAIN case schema::PrimitiveType_ActivationGrad: @@ -866,6 +869,8 @@ PrimitiveC *PrimitiveC::Create(const schema::Primitive *primitive) { return NewPrimitiveC(primitive); case schema::PrimitiveType_DetectionPostProcess: return NewPrimitiveC(primitive); + case schema::PrimitiveType_Dropout: + return NewPrimitiveC(primitive); #ifdef SUPPORT_TRAIN case schema::PrimitiveType_ActivationGrad: diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc index ca857fc4a0..6bb15bcbfa 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc @@ -38,26 +38,22 @@ std::vector ConvertTensorToLiteTensor(MetaGraphT *graph, const std::ve MS_LOG(ERROR) << "lite tensor is nullptr"; return std::vector(); } - // reshape op must get tensor data to infershape - if (node_type == schema::PrimitiveType_Reshape && i == 1 && tensorT->nodeType == NodeType_ValueNode) { - auto lite_tensor_size = tensorT->data.size() * sizeof(uint8_t); - // when tensorT as param input - if (lite_tensor_size == 0) { - return std::vector(); - } - auto ret = lite_tensor->MallocData(); - if (ret != 0) { - MS_LOG(ERROR) << "Malloc tensor data failed"; - return std::vector(); - } - ret = memcpy_s(lite_tensor->MutableData(), lite_tensor->Size(), tensorT->data.data(), lite_tensor_size); - if (ret != EOK) { - MS_LOG(ERROR) << "memcpy error: " << ret; - return std::vector(); - } + auto lite_tensor_size = tensorT->data.size() * sizeof(uint8_t); + // when tensorT as param input + if (lite_tensor_size == 0) { lite_tensors.emplace_back(lite_tensor.release()); continue; } + auto ret = lite_tensor->MallocData(); + if (ret != 0) { + MS_LOG(ERROR) << "Malloc tensor data failed"; + return std::vector(); + } + ret = memcpy_s(lite_tensor->MutableData(), lite_tensor->Size(), tensorT->data.data(), lite_tensor_size); + if (ret != EOK) { + MS_LOG(ERROR) << "memcpy error: " << ret; + return std::vector(); + } lite_tensors.emplace_back(lite_tensor.release()); } return lite_tensors; diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc index fcb258fb2e..5a67a11abc 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_conv_parser.cc @@ -70,7 +70,13 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod MS_LOG(ERROR) << "new op failed"; return RET_NULL_PTR; } - + // set default params + attr->strideH = 1; + attr->strideW = 1; + attr->dilateH = 1; + attr->dilateW = 1; + attr->group = 1; + attr->padMode = schema::PadMode_NOTSET; // set opdef each attr params for (const auto &onnx_node_attr : onnx_node.attribute()) { if (onnx_node_attr.name() == "group") { @@ -165,7 +171,7 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod attr->activationType = schema::ActivationType_NO_ACTIVATION; } - if (attr->group != 1) { + if (attr->group == attr->channelOut) { if (!ParseGroupConvolution(attr, op)) { MS_LOG(ERROR) << "Convert Convolution to Depthwise failed"; return RET_ERROR; diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_dropout_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_dropout_parser.cc index 7c0f593191..a6bdf6cbd1 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_dropout_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_dropout_parser.cc @@ -41,7 +41,7 @@ STATUS OnnxDropoutParser::Parse(const onnx::GraphProto &onnx_graph, const onnx:: for (const auto &onnx_node_attr : onnx_node.attribute()) { const auto &attribute_name = onnx_node_attr.name(); if (attribute_name == "ratio") { - attr->ratio = static_cast(onnx_node_attr.i()); + attr->ratio = static_cast(onnx_node_attr.f()); } } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.cc index 4c4d8c0e4c..435520c70d 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_matmul_parser.cc @@ -42,6 +42,9 @@ STATUS OnnxMatmulParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N float beta = 1.0f; for (const auto &onnx_node_attr : onnx_node.attribute()) { const auto &attribute_name = onnx_node_attr.name(); + if (attribute_name == "broadcast") { + attr->broadcast = static_cast(onnx_node_attr.i()); + } if (attribute_name == "transA") { attr->transposeA = static_cast(onnx_node_attr.i()); } else if (attribute_name == "transB") { diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.cc index 56b06a2f50..b976fc474a 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_reshape_parser.cc @@ -39,29 +39,18 @@ STATUS OnnxReshapeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx:: return RET_NULL_PTR; } - attr->format = schema::Format::Format_NCHW; - std::vector params; - for (int i = 0; i < onnx_node.input_size(); ++i) { - const auto &input_name = onnx_node.input(i); - for (const auto &it : onnx_graph.initializer()) { - if (it.name() == input_name) { - params.emplace_back(it); - break; + attr->format = schema::Format_NCHW; + std::vector shape; + shape.clear(); + for (const auto &onnx_node_attr : onnx_node.attribute()) { + const auto &attribute_name = onnx_node_attr.name(); + if (attribute_name == "shape") { + for (int i = 0; i < onnx_node_attr.ints_size(); ++i) { + shape.push_back(static_cast(onnx_node_attr.ints(i))); } } } - if (params.empty()) { - MS_LOG(DEBUG) << "shape from another op other than const initializer"; - } else { - if (params.size() != 1) { - MS_LOG(ERROR) << "shape param num is " << params.size() << ", not equal to 1"; - return RET_ERROR; - } - - for (int i = 0; i < params[0].int64_data_size(); ++i) { - attr->shape.emplace_back(params[0].int64_data(i)); - } - } + attr->shape = shape; op->primitive->value.type = schema::PrimitiveType_Reshape; op->primitive->value.value = attr.release(); diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc index 4cb191f7b4..34b2c71905 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc @@ -144,7 +144,7 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr MS_LOG(ERROR) << "new op failed"; return RET_NULL_PTR; } - attr->power = 0.0f; + attr->power = 1.0f; attr->scale = 1.0f; attr->shift = 0.0f; op->primitive->value.type = schema::PrimitiveType_Power;