diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/switch_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/switch_pass.cc index ae86c8692f..d296919264 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/switch_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/switch_pass.cc @@ -60,19 +60,36 @@ STATUS SingleSwitchPass::DoubleSwitchOutput() { return RET_OK; } +void SingleSwitchPass::DoubleIdx(uint32_t *idx) { + auto iter = std::find(switch_node_->outputIndex.begin(), switch_node_->outputIndex.end(), *idx); + if (iter != switch_node_->outputIndex.end()) { + int pos = iter - switch_node_->outputIndex.begin(); + *idx = switch_node_->outputIndex.at(pos + switch_node_->outputIndex.size() / 2); + } +} + STATUS SingleSwitchPass::UpdateSwitchUser() { std::vector switch_users; for (auto &node_idx : graph_->subGraph.at(this_subgraph_index_)->nodeIndices) { auto &node = graph_->nodes.at(node_idx); for (auto &idx : node->inputIndex) { - auto iter = std::find(switch_node_->outputIndex.begin(), switch_node_->outputIndex.end(), idx); - if (iter != switch_node_->outputIndex.end()) { + if (IsContain(switch_node_->outputIndex, idx)) { switch_users.push_back(node.get()); - int pos = iter - switch_node_->outputIndex.begin(); - idx = switch_node_->outputIndex.at(pos + switch_node_->outputIndex.size() / 2); } + DoubleIdx(&idx); + } + } + // update graph switch user + for (auto &subgraph : graph_->subGraph) { + for (auto &idx : subgraph->outputIndices) { + DoubleIdx(&idx); } } + + for (auto &idx : graph_->outputIndex) { + DoubleIdx(&idx); + } + return RET_OK; } @@ -307,7 +324,7 @@ STATUS SingleSwitchPass::UpdateSubgraphInput(const size_t &subgraph_index, schem // get parameter input index k. subgraph name + “_input_" + "k" auto pos = subgraph->name.size() + sizeof("_input_"); auto pos2 = tensor->name.find('_', pos); - auto idx_str = tensor->name.substr(pos - 1, pos2); + auto idx_str = tensor->name.substr(pos - 1, pos2 - pos + 1); partial_idx = std::stoi(idx_str); } @@ -315,7 +332,7 @@ STATUS SingleSwitchPass::UpdateSubgraphInput(const size_t &subgraph_index, schem // get parameter input index k. subgraph name + “_output_" + "k" auto pos = subgraph->name.size() + sizeof("_output_"); auto pos2 = tensor->name.find('_', pos); - auto idx_str = tensor->name.substr(pos - 1, pos2); + auto idx_str = tensor->name.substr(pos - 1, pos2 - pos + 1); partial_idx = std::stoi(idx_str); } diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/switch_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/switch_pass.h index 12360701c9..f65eab862a 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/switch_pass.h +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/switch_pass.h @@ -56,6 +56,7 @@ class SingleSwitchPass { const std::vector &subgraph_nodes); std::unique_ptr NewTensor(const std::unique_ptr &in_tensor); void RemoveUselessNode(schema::CNodeT *partial_node, schema::MetaGraphT *graph); + void DoubleIdx(uint32_t *idx); const size_t kSwitchCondIndex = 0; const size_t kSwitchBodyIndex = 1;