|
|
|
@ -36,50 +36,42 @@ namespace mindspore {
|
|
|
|
|
std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, const IncludeFunc &include) {
|
|
|
|
|
size_t seen = NewSeenGeneration();
|
|
|
|
|
std::deque<AnfNodePtr> todo(1024);
|
|
|
|
|
std::unordered_map<AnfNodePtr, size_t> rank;
|
|
|
|
|
std::vector<AnfNodePtr> res;
|
|
|
|
|
todo.clear();
|
|
|
|
|
todo.push_back(root);
|
|
|
|
|
|
|
|
|
|
while (!todo.empty()) {
|
|
|
|
|
AnfNodePtr node = todo.back();
|
|
|
|
|
if (node == nullptr || node->seen_ == seen) {
|
|
|
|
|
if (node == nullptr) {
|
|
|
|
|
todo.pop_back();
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (rank.find(node) != rank.end() && rank[node] != todo.size()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Graph exists cycle, node " << node->DebugString(2);
|
|
|
|
|
if (node->extra_seen_ == seen) { // We use extra_seen_ as finish flag
|
|
|
|
|
todo.pop_back();
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
rank[node] = todo.size();
|
|
|
|
|
bool cont = false;
|
|
|
|
|
auto incl = include(node);
|
|
|
|
|
if (incl == FOLLOW) {
|
|
|
|
|
auto succs = succ(node);
|
|
|
|
|
for (const auto i : succs) {
|
|
|
|
|
if ((i != nullptr && i->seen_ != seen)
|
|
|
|
|
// Handle the case for 2 subgraphs calls each other.
|
|
|
|
|
// If the ValueNodeGraph's return is already in the todo list, do not follow it.
|
|
|
|
|
&& !((std::find(todo.begin(), todo.end(), i) != todo.end()) && (i->func_graph() != nullptr) &&
|
|
|
|
|
(i->func_graph()->get_return() == i))) {
|
|
|
|
|
todo.push_back(i);
|
|
|
|
|
cont = true;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else if (incl == NOFOLLOW) {
|
|
|
|
|
// do nothing
|
|
|
|
|
} else if (incl == EXCLUDE) {
|
|
|
|
|
node->seen_ = seen;
|
|
|
|
|
if (node->seen_ == seen) { // We use seen_ as checking flag
|
|
|
|
|
todo.pop_back();
|
|
|
|
|
if (incl != EXCLUDE) {
|
|
|
|
|
res.push_back(node);
|
|
|
|
|
}
|
|
|
|
|
node->extra_seen_ = seen;
|
|
|
|
|
continue;
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(EXCEPTION) << "include(node) must return one of: \"follow\", \"nofollow\", \"exclude\"";
|
|
|
|
|
}
|
|
|
|
|
if (cont) {
|
|
|
|
|
continue;
|
|
|
|
|
if (node->seen_ == seen && node->extra_seen_ != seen) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Graph exists cycle, node " << node->DebugString(2);
|
|
|
|
|
}
|
|
|
|
|
node->seen_ = seen;
|
|
|
|
|
res.push_back(node);
|
|
|
|
|
todo.pop_back();
|
|
|
|
|
if (incl == FOLLOW) {
|
|
|
|
|
auto succs = succ(node);
|
|
|
|
|
(void)std::copy_if(succs.begin(), succs.end(), std::back_inserter(todo), [seen](const AnfNodePtr &next) {
|
|
|
|
|
return next != nullptr && next->seen_ != seen &&
|
|
|
|
|
(next->func_graph() == nullptr || next->func_graph()->get_return() != next);
|
|
|
|
|
});
|
|
|
|
|
} else if (incl > EXCLUDE) { // Not NOFOLLOW or EXCLUDE
|
|
|
|
|
MS_LOG(EXCEPTION) << "The result of include(node) must be one of: \"follow\", \"nofollow\", \"exclude\"";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return res;
|
|
|
|
|
}
|
|
|
|
|