From d79cfdbc69a00d228bb593297180a0cfdf56a4bf Mon Sep 17 00:00:00 2001 From: Zhang Qinghua Date: Fri, 18 Dec 2020 10:08:15 +0800 Subject: [PATCH] Check if circle exists in a graph. --- mindspore/core/ir/graph_utils.cc | 32 +++++++++++++++++++++++--------- 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/mindspore/core/ir/graph_utils.cc b/mindspore/core/ir/graph_utils.cc index cff06ee891..6d7e251dc1 100644 --- a/mindspore/core/ir/graph_utils.cc +++ b/mindspore/core/ir/graph_utils.cc @@ -58,14 +58,32 @@ std::vector TopoSort(const AnfNodePtr &root, const SuccFunc &succ, c node->extra_seen_ = seen; continue; } - if (node->seen_ == seen && node->extra_seen_ != seen) { - MS_LOG(EXCEPTION) << "Graph exists cycle, node " << node->DebugString(2); - } node->seen_ = seen; if (incl == FOLLOW) { auto succs = succ(node); - (void)std::copy_if(succs.begin(), succs.end(), std::back_inserter(todo), - [seen](const AnfNodePtr &next) { return next != nullptr && next->seen_ != seen; }); + (void)std::copy_if(succs.begin(), succs.end(), std::back_inserter(todo), [seen, todo](const AnfNodePtr &next) { + if (next == nullptr || next->extra_seen_ == seen) { + return false; + } + if (next->seen_ != seen) { + return true; + } + if (next->func_graph()->get_return() == next) { + return false; + } + // To dump all nodes in a circle. + MS_LOG(ERROR) << "Graph cycle exists. Circle is: "; + size_t pos = 0; + auto circle_node_it = std::find(todo.begin(), todo.end(), next); + for (; circle_node_it != todo.end(); circle_node_it++) { + auto circle_node = *circle_node_it; + if (circle_node->seen_) { + MS_LOG(ERROR) << "#" << pos << ": " << circle_node->DebugString(); + pos++; + } + } + MS_LOG(EXCEPTION) << "Graph cycle exists, node " << next->DebugString(2); + }); } else if (incl > EXCLUDE) { // Not NOFOLLOW or EXCLUDE MS_LOG(EXCEPTION) << "The result of include(node) must be one of: \"follow\", \"nofollow\", \"exclude\""; } @@ -138,10 +156,6 @@ std::vector SuccDeeper(const AnfNodePtr &node) { auto &inputs = node->cast()->inputs(); (void)vecs.insert(vecs.end(), inputs.begin(), inputs.end()); } - auto graph = node->func_graph(); - if (graph->get_return() != nullptr) { - vecs.push_back(graph->get_return()); - } return vecs; }