|
|
|
@ -43,6 +43,35 @@ STATUS SwitchPass::Run(mindspore::schema::MetaGraphT *graph) {
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// remove empty subgraphs
|
|
|
|
|
std::vector<std::unique_ptr<SubGraphT>> new_sub_graphs;
|
|
|
|
|
std::map<uint32_t, uint32_t> sub_graph_index_map;
|
|
|
|
|
for (size_t i = 0; i < graph->subGraph.size(); ++i) {
|
|
|
|
|
auto &sub_graph = graph->subGraph.at(i);
|
|
|
|
|
if (!sub_graph->nodeIndices.empty()) {
|
|
|
|
|
new_sub_graphs.emplace_back(std::move(sub_graph));
|
|
|
|
|
sub_graph_index_map.emplace(std::make_pair(i, new_sub_graphs.size() - 1));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
graph->subGraph.swap(new_sub_graphs);
|
|
|
|
|
for (size_t i = 0; i < graph->nodes.size(); ++i) {
|
|
|
|
|
auto &node = graph->nodes.at(i);
|
|
|
|
|
auto type = node->primitive->value.type;
|
|
|
|
|
if (type != schema::PrimitiveType_Partial) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
MS_ASSERT(node->primitive != nullptr);
|
|
|
|
|
MS_ASSERT(node->primitive->value.AsPartial() != nullptr);
|
|
|
|
|
auto partial_prim = node->primitive->value.AsPartial();
|
|
|
|
|
if (partial_prim->subGraphIndex == -1) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (sub_graph_index_map.find(partial_prim->subGraphIndex) == sub_graph_index_map.end()) {
|
|
|
|
|
MS_LOG(ERROR) << "subGraphIndex is illegal";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
partial_prim->subGraphIndex = sub_graph_index_map[partial_prim->subGraphIndex];
|
|
|
|
|
}
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -308,7 +337,7 @@ STATUS SingleSwitchPass::InsertPartialAndMergeAfterSwitch() {
|
|
|
|
|
auto origin_switch_outputs = switch_node_->outputIndex;
|
|
|
|
|
switch_node_->outputIndex.clear();
|
|
|
|
|
for (size_t i = 3; i < switch_node_->inputIndex.size(); i++) {
|
|
|
|
|
auto &switch_in_tensor = graph_->allTensors.at(i);
|
|
|
|
|
auto &switch_in_tensor = graph_->allTensors.at(switch_node_->inputIndex[i]);
|
|
|
|
|
auto tensor = NewTensor(switch_in_tensor);
|
|
|
|
|
graph_->allTensors.push_back(std::move(tensor));
|
|
|
|
|
switch_node_->outputIndex.push_back(graph_->allTensors.size() - 1);
|
|
|
|
@ -581,6 +610,9 @@ STATUS SingleSwitchPass::UpdateSubgraphOutput(const size_t &subgraph_index, sche
|
|
|
|
|
|
|
|
|
|
STATUS SingleSwitchPass::ConcatCondSubgraphInputAndOutput() {
|
|
|
|
|
if (first_subgraph_index_ == -1) {
|
|
|
|
|
MS_ASSERT(first_partial_node_->primitive != nullptr);
|
|
|
|
|
MS_ASSERT(first_partial_node_->primitive->value.AsPartial() != nullptr);
|
|
|
|
|
first_partial_node_->primitive->value.AsPartial()->subGraphIndex = -1;
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
int ret = UpdateSubgraphInput(first_subgraph_index_, first_partial_node_, first_graph_nodes_);
|
|
|
|
@ -599,6 +631,9 @@ STATUS SingleSwitchPass::ConcatCondSubgraphInputAndOutput() {
|
|
|
|
|
|
|
|
|
|
STATUS SingleSwitchPass::ConcatBodySubgraphInputAndOutput() {
|
|
|
|
|
if (second_subgraph_index_ == -1) {
|
|
|
|
|
MS_ASSERT(first_partial_node_->primitive != nullptr);
|
|
|
|
|
MS_ASSERT(first_partial_node_->primitive->value.AsPartial() != nullptr);
|
|
|
|
|
first_partial_node_->primitive->value.AsPartial()->subGraphIndex = -1;
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
int ret = UpdateSubgraphInput(second_subgraph_index_, second_partial_node_, second_graph_nodes_);
|
|
|
|
|