From a61b5b56d1f74922c92355656b7401573bd1eaef Mon Sep 17 00:00:00 2001 From: mengyuanli Date: Tue, 26 Jan 2021 22:00:04 +0800 Subject: [PATCH] fix bug of graphdef transform --- .../tools/converter/graphdef_transform.cc | 11 ++++++++++ .../graph/nested_loop_expand_pass.cc | 20 +++++++++++++------ 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/mindspore/lite/tools/converter/graphdef_transform.cc b/mindspore/lite/tools/converter/graphdef_transform.cc index d480152194..eeb552b7f5 100644 --- a/mindspore/lite/tools/converter/graphdef_transform.cc +++ b/mindspore/lite/tools/converter/graphdef_transform.cc @@ -42,6 +42,7 @@ #include "tools/converter/legacy_optimizer/graph/select_pass.h" #include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h" #include "tools/converter/legacy_optimizer/graph/subgraph_tensor_pass.h" +#include "tools/converter/legacy_optimizer/graph/nested_loop_expand_pass.h" using std::string; namespace mindspore::lite { @@ -276,6 +277,16 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { } } + { + Optimizer nestedLoopOptimizer; + nestedLoopOptimizer.AddPass(new (std::nothrow) NestedLoopExpandPass()); + status = nestedLoopOptimizer.Run(graphDefT); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "Run nestedLoopOptimizer graphPasses Failed"; + return status; + } + } + return RET_OK; } // namespace mindspore::lite } // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/nested_loop_expand_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/nested_loop_expand_pass.cc index 27d3afd9e0..bfb553020c 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/nested_loop_expand_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/nested_loop_expand_pass.cc @@ -77,6 +77,20 @@ STATUS NestedLoopExpandPass::Run(schema::MetaGraphT *graph) { graph_->subGraph.at(idx) = nullptr; } + + + for (auto &node_idx : main_graph->nodeIndices) { + auto &node = graph_->nodes.at(node_idx); + if (node->primitive->value.type == PrimitiveType_Partial) { + auto &subgraph_idx = ((schema::PartialT *)(node->primitive->value.value))->subGraphIndex; + for (auto i = 0; i < subgraph_idx; ++i) { + if (graph_->subGraph.at(subgraph_idx) == nullptr) { + subgraph_idx--; + } + } + } + } + for (auto it = graph_->subGraph.begin(); it != graph_->subGraph.end();) { if ((*it) == nullptr) { it = graph_->subGraph.erase(it); @@ -85,12 +99,6 @@ STATUS NestedLoopExpandPass::Run(schema::MetaGraphT *graph) { } } - for (auto &node : graph_->nodes) { - if (node->primitive->value.type == PrimitiveType_Partial) { - ((schema::PartialT *)(node->primitive->value.value))->subGraphIndex -= subgraph_to_drop_.size(); - } - } - return RET_OK; }