|
|
|
@ -58,14 +58,32 @@ std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, c
|
|
|
|
|
node->extra_seen_ = seen;
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (node->seen_ == seen && node->extra_seen_ != seen) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Graph exists cycle, node " << node->DebugString(2);
|
|
|
|
|
}
|
|
|
|
|
node->seen_ = seen;
|
|
|
|
|
if (incl == FOLLOW) {
|
|
|
|
|
auto succs = succ(node);
|
|
|
|
|
(void)std::copy_if(succs.begin(), succs.end(), std::back_inserter(todo),
|
|
|
|
|
[seen](const AnfNodePtr &next) { return next != nullptr && next->seen_ != seen; });
|
|
|
|
|
(void)std::copy_if(succs.begin(), succs.end(), std::back_inserter(todo), [seen, todo](const AnfNodePtr &next) {
|
|
|
|
|
if (next == nullptr || next->extra_seen_ == seen) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
if (next->seen_ != seen) {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
if (next->func_graph()->get_return() == next) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
// To dump all nodes in a circle.
|
|
|
|
|
MS_LOG(ERROR) << "Graph cycle exists. Circle is: ";
|
|
|
|
|
size_t pos = 0;
|
|
|
|
|
auto circle_node_it = std::find(todo.begin(), todo.end(), next);
|
|
|
|
|
for (; circle_node_it != todo.end(); circle_node_it++) {
|
|
|
|
|
auto circle_node = *circle_node_it;
|
|
|
|
|
if (circle_node->seen_) {
|
|
|
|
|
MS_LOG(ERROR) << "#" << pos << ": " << circle_node->DebugString();
|
|
|
|
|
pos++;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(EXCEPTION) << "Graph cycle exists, node " << next->DebugString(2);
|
|
|
|
|
});
|
|
|
|
|
} else if (incl > EXCLUDE) { // Not NOFOLLOW or EXCLUDE
|
|
|
|
|
MS_LOG(EXCEPTION) << "The result of include(node) must be one of: \"follow\", \"nofollow\", \"exclude\"";
|
|
|
|
|
}
|
|
|
|
@ -138,10 +156,6 @@ std::vector<AnfNodePtr> SuccDeeper(const AnfNodePtr &node) {
|
|
|
|
|
auto &inputs = node->cast<CNodePtr>()->inputs();
|
|
|
|
|
(void)vecs.insert(vecs.end(), inputs.begin(), inputs.end());
|
|
|
|
|
}
|
|
|
|
|
auto graph = node->func_graph();
|
|
|
|
|
if (graph->get_return() != nullptr) {
|
|
|
|
|
vecs.push_back(graph->get_return());
|
|
|
|
|
}
|
|
|
|
|
return vecs;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|