|
|
|
@ -60,19 +60,36 @@ STATUS SingleSwitchPass::DoubleSwitchOutput() {
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SingleSwitchPass::DoubleIdx(uint32_t *idx) {
|
|
|
|
|
auto iter = std::find(switch_node_->outputIndex.begin(), switch_node_->outputIndex.end(), *idx);
|
|
|
|
|
if (iter != switch_node_->outputIndex.end()) {
|
|
|
|
|
int pos = iter - switch_node_->outputIndex.begin();
|
|
|
|
|
*idx = switch_node_->outputIndex.at(pos + switch_node_->outputIndex.size() / 2);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
STATUS SingleSwitchPass::UpdateSwitchUser() {
|
|
|
|
|
std::vector<CNodeT *> switch_users;
|
|
|
|
|
for (auto &node_idx : graph_->subGraph.at(this_subgraph_index_)->nodeIndices) {
|
|
|
|
|
auto &node = graph_->nodes.at(node_idx);
|
|
|
|
|
for (auto &idx : node->inputIndex) {
|
|
|
|
|
auto iter = std::find(switch_node_->outputIndex.begin(), switch_node_->outputIndex.end(), idx);
|
|
|
|
|
if (iter != switch_node_->outputIndex.end()) {
|
|
|
|
|
if (IsContain(switch_node_->outputIndex, idx)) {
|
|
|
|
|
switch_users.push_back(node.get());
|
|
|
|
|
int pos = iter - switch_node_->outputIndex.begin();
|
|
|
|
|
idx = switch_node_->outputIndex.at(pos + switch_node_->outputIndex.size() / 2);
|
|
|
|
|
}
|
|
|
|
|
DoubleIdx(&idx);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// update graph switch user
|
|
|
|
|
for (auto &subgraph : graph_->subGraph) {
|
|
|
|
|
for (auto &idx : subgraph->outputIndices) {
|
|
|
|
|
DoubleIdx(&idx);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (auto &idx : graph_->outputIndex) {
|
|
|
|
|
DoubleIdx(&idx);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -307,7 +324,7 @@ STATUS SingleSwitchPass::UpdateSubgraphInput(const size_t &subgraph_index, schem
|
|
|
|
|
// get parameter input index k. subgraph name + “_input_" + "k"
|
|
|
|
|
auto pos = subgraph->name.size() + sizeof("_input_");
|
|
|
|
|
auto pos2 = tensor->name.find('_', pos);
|
|
|
|
|
auto idx_str = tensor->name.substr(pos - 1, pos2);
|
|
|
|
|
auto idx_str = tensor->name.substr(pos - 1, pos2 - pos + 1);
|
|
|
|
|
partial_idx = std::stoi(idx_str);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -315,7 +332,7 @@ STATUS SingleSwitchPass::UpdateSubgraphInput(const size_t &subgraph_index, schem
|
|
|
|
|
// get parameter input index k. subgraph name + “_output_" + "k"
|
|
|
|
|
auto pos = subgraph->name.size() + sizeof("_output_");
|
|
|
|
|
auto pos2 = tensor->name.find('_', pos);
|
|
|
|
|
auto idx_str = tensor->name.substr(pos - 1, pos2);
|
|
|
|
|
auto idx_str = tensor->name.substr(pos - 1, pos2 - pos + 1);
|
|
|
|
|
partial_idx = std::stoi(idx_str);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|