!10256 [MS][LITE]fix bug of switch pass

From: @mengyuanli
Reviewed-by: @hangangqiang,@zhang_xue_tong
Signed-off-by: @hangangqiang
pull/10256/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit b52eab010a

@ -60,19 +60,36 @@ STATUS SingleSwitchPass::DoubleSwitchOutput() {
return RET_OK;
}
void SingleSwitchPass::DoubleIdx(uint32_t *idx) {
auto iter = std::find(switch_node_->outputIndex.begin(), switch_node_->outputIndex.end(), *idx);
if (iter != switch_node_->outputIndex.end()) {
int pos = iter - switch_node_->outputIndex.begin();
*idx = switch_node_->outputIndex.at(pos + switch_node_->outputIndex.size() / 2);
}
}
STATUS SingleSwitchPass::UpdateSwitchUser() {
std::vector<CNodeT *> switch_users;
for (auto &node_idx : graph_->subGraph.at(this_subgraph_index_)->nodeIndices) {
auto &node = graph_->nodes.at(node_idx);
for (auto &idx : node->inputIndex) {
auto iter = std::find(switch_node_->outputIndex.begin(), switch_node_->outputIndex.end(), idx);
if (iter != switch_node_->outputIndex.end()) {
if (IsContain(switch_node_->outputIndex, idx)) {
switch_users.push_back(node.get());
int pos = iter - switch_node_->outputIndex.begin();
idx = switch_node_->outputIndex.at(pos + switch_node_->outputIndex.size() / 2);
}
DoubleIdx(&idx);
}
}
// update graph switch user
for (auto &subgraph : graph_->subGraph) {
for (auto &idx : subgraph->outputIndices) {
DoubleIdx(&idx);
}
}
for (auto &idx : graph_->outputIndex) {
DoubleIdx(&idx);
}
return RET_OK;
}
@ -307,7 +324,7 @@ STATUS SingleSwitchPass::UpdateSubgraphInput(const size_t &subgraph_index, schem
// get parameter input index k. subgraph name + “_input_" + "k"
auto pos = subgraph->name.size() + sizeof("_input_");
auto pos2 = tensor->name.find('_', pos);
auto idx_str = tensor->name.substr(pos - 1, pos2);
auto idx_str = tensor->name.substr(pos - 1, pos2 - pos + 1);
partial_idx = std::stoi(idx_str);
}
@ -315,7 +332,7 @@ STATUS SingleSwitchPass::UpdateSubgraphInput(const size_t &subgraph_index, schem
// get parameter input index k. subgraph name + “_output_" + "k"
auto pos = subgraph->name.size() + sizeof("_output_");
auto pos2 = tensor->name.find('_', pos);
auto idx_str = tensor->name.substr(pos - 1, pos2);
auto idx_str = tensor->name.substr(pos - 1, pos2 - pos + 1);
partial_idx = std::stoi(idx_str);
}

@ -56,6 +56,7 @@ class SingleSwitchPass {
const std::vector<schema::CNodeT *> &subgraph_nodes);
std::unique_ptr<schema::TensorT> NewTensor(const std::unique_ptr<schema::TensorT> &in_tensor);
void RemoveUselessNode(schema::CNodeT *partial_node, schema::MetaGraphT *graph);
void DoubleIdx(uint32_t *idx);
const size_t kSwitchCondIndex = 0;
const size_t kSwitchBodyIndex = 1;

Loading…
Cancel
Save