From 067beb25908f41fd1ba46b1fe3ea73db4ad9ffad Mon Sep 17 00:00:00 2001 From: wsc Date: Thu, 13 Aug 2020 10:10:24 +0800 Subject: [PATCH] Modify the method for getting output index of metagraph. --- mindspore/lite/src/ops/reduce.cc | 2 +- .../src/runtime/kernel/arm/fp32/reduce.cc | 2 +- .../parser/tflite/tflite_model_parser.cc | 31 +++++++++++++------ .../parser/tflite/tflite_model_parser.h | 5 ++- 4 files changed, 28 insertions(+), 12 deletions(-) diff --git a/mindspore/lite/src/ops/reduce.cc b/mindspore/lite/src/ops/reduce.cc index 76ce819977..0c1e8d8925 100644 --- a/mindspore/lite/src/ops/reduce.cc +++ b/mindspore/lite/src/ops/reduce.cc @@ -58,7 +58,7 @@ int Reduce::InferShape(std::vector inputs_, std::vector((*axes)[idx]) == i) { + if (static_cast((*axes)[idx]) == i || static_cast((*axes)[idx] + in_shape.size()) == i) { reduce_axis = true; break; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/reduce.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/reduce.cc index 256369f218..e9929aced4 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/reduce.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/reduce.cc @@ -71,7 +71,7 @@ int ReduceCPUKernel::CheckParameters() { return RET_ERROR; } for (auto i = 0; i < num_axes_; i++) { - if (axes_[i] < -static_cast(input_rank) || static_cast(axes_[i]) >= input_rank) { + if (axes_[i] < -static_cast(input_rank) || axes_[i] >= static_cast(input_rank)) { MS_LOG(ERROR) << "Reduce got invalid axis " << axes_[i] << ", axis should be in [" << -static_cast(input_rank) << ", " << input_rank - 1 << "]."; return RET_ERROR; diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc index 8c7163496f..d240a9f295 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc @@ -236,18 +236,31 @@ void TfliteModelParser::SetInputTensor(const std::unique_ptr } } -void TfliteModelParser::SetGraphTensorIndex(const mindspore::lite::TensorCache &tensorCache, +void TfliteModelParser::SetGraphTensorIndex(const std::unique_ptr &tflite_subgraph, + const std::unique_ptr &tflite_model, + const mindspore::lite::TensorCache &tensorCache, schema::MetaGraphT *subGraphDef) { - auto opGraph = OpGraphT::Build(subGraphDef); auto graphInputs = tensorCache.GetGraphInputs(); - auto graphOutputs = opGraph->GetOutputNode(); - subGraphDef->inputIndex.assign(graphInputs.begin(), graphInputs.end()); - for (const auto &output : graphOutputs) { - auto op = opMap[output->ID()]; - for (auto outputIndex : op->outputIndex) { - subGraphDef->outputIndex.emplace_back(outputIndex); + for (auto outputIndex : tflite_subgraph->outputs) { + int i = 0; + bool found = false; + for (const auto &tfliteOp : tflite_subgraph->operators) { + int j = 0; + auto opType = GetTfliteNodeType(tfliteOp, tflite_model); + std::string opName = opType + "-" + std::to_string(i++); + for (auto opOutputIndex : tfliteOp->outputs) { + if (outputIndex == opOutputIndex) { + subGraphDef->outputIndex.emplace_back(opMap[opName]->outputIndex[j]); + found = true; + break; + } + j++; + } + if (found) { + break; + } } } } @@ -284,7 +297,7 @@ MetaGraphT *TfliteModelParser::Parse(const std::string &modelFile, const std::st return nullptr; } - SetGraphTensorIndex(tensorCache, subGraph.get()); + SetGraphTensorIndex(tflite_subgraph, tflite_model, tensorCache, subGraph.get()); SetAllTensors(tensorCache, subGraph.get()); return subGraph.release(); } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h index 2379bd2632..ed861bb866 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h @@ -50,7 +50,10 @@ class TfliteModelParser : public ModelParser { void SetInputTensor(const std::unique_ptr &tflite_subgraph, TensorCache *tensor_cache); - void SetGraphTensorIndex(const mindspore::lite::TensorCache &tensorCache, schema::MetaGraphT *subGraphDef); + void SetGraphTensorIndex(const std::unique_ptr &tflite_subgraph, + const std::unique_ptr &tflite_model, + const mindspore::lite::TensorCache &tensorCache, + schema::MetaGraphT *subGraphDef); STATUS ParseOp(const std::unique_ptr &tflite_model, const std::unique_ptr &tflite_subgraph, schema::MetaGraphT *sub_graph,