fix bug of graphdef transform

pull/11685/head
mengyuanli 4 years ago
parent 49e9b80328
commit a61b5b56d1

@ -42,6 +42,7 @@
#include "tools/converter/legacy_optimizer/graph/select_pass.h" #include "tools/converter/legacy_optimizer/graph/select_pass.h"
#include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h" #include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h"
#include "tools/converter/legacy_optimizer/graph/subgraph_tensor_pass.h" #include "tools/converter/legacy_optimizer/graph/subgraph_tensor_pass.h"
#include "tools/converter/legacy_optimizer/graph/nested_loop_expand_pass.h"
using std::string; using std::string;
namespace mindspore::lite { namespace mindspore::lite {
@ -276,6 +277,16 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
} }
} }
{
Optimizer nestedLoopOptimizer;
nestedLoopOptimizer.AddPass(new (std::nothrow) NestedLoopExpandPass());
status = nestedLoopOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run nestedLoopOptimizer graphPasses Failed";
return status;
}
}
return RET_OK; return RET_OK;
} // namespace mindspore::lite } // namespace mindspore::lite
} // namespace mindspore::lite } // namespace mindspore::lite

@ -77,6 +77,20 @@ STATUS NestedLoopExpandPass::Run(schema::MetaGraphT *graph) {
graph_->subGraph.at(idx) = nullptr; graph_->subGraph.at(idx) = nullptr;
} }
for (auto &node_idx : main_graph->nodeIndices) {
auto &node = graph_->nodes.at(node_idx);
if (node->primitive->value.type == PrimitiveType_Partial) {
auto &subgraph_idx = ((schema::PartialT *)(node->primitive->value.value))->subGraphIndex;
for (auto i = 0; i < subgraph_idx; ++i) {
if (graph_->subGraph.at(subgraph_idx) == nullptr) {
subgraph_idx--;
}
}
}
}
for (auto it = graph_->subGraph.begin(); it != graph_->subGraph.end();) { for (auto it = graph_->subGraph.begin(); it != graph_->subGraph.end();) {
if ((*it) == nullptr) { if ((*it) == nullptr) {
it = graph_->subGraph.erase(it); it = graph_->subGraph.erase(it);
@ -85,12 +99,6 @@ STATUS NestedLoopExpandPass::Run(schema::MetaGraphT *graph) {
} }
} }
for (auto &node : graph_->nodes) {
if (node->primitive->value.type == PrimitiveType_Partial) {
((schema::PartialT *)(node->primitive->value.value))->subGraphIndex -= subgraph_to_drop_.size();
}
}
return RET_OK; return RET_OK;
} }

Loading…
Cancel
Save