!8709 Optimize TopoSort() and Infer() performance.

From: @zh_qh
Reviewed-by: 
Signed-off-by:
pull/8709/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 0b3aa904c0

@ -97,7 +97,12 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
<< MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH) << MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)
<< ", please call 'context.set_context(max_call_depth=value)' to adjust this value."; << ", please call 'context.set_context(max_call_depth=value)' to adjust this value.";
} }
const std::vector<AnfNodePtr> &all_nodes = TopoSort(func_node); const auto &all_nodes = TopoSort(func_node, SuccIncoming, [&fg](const AnfNodePtr &node) -> IncludeType {
if (node->func_graph() != fg || node->isa<ValueNode>()) {
return EXCLUDE;
}
return FOLLOW;
});
for (const auto &node : all_nodes) { for (const auto &node : all_nodes) {
AnfNodeConfigPtr node_conf = engine->MakeConfig(node, graph_context_); AnfNodeConfigPtr node_conf = engine->MakeConfig(node, graph_context_);
MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg.get() << fg->ToString() MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg.get() << fg->ToString()

@ -158,6 +158,7 @@ class AnfNode : public Base {
return os; return os;
} }
size_t seen_{0}; size_t seen_{0};
size_t extra_seen_{0};
template <typename T> template <typename T>
void set_user_data(const std::string &key, const std::shared_ptr<T> &value) { void set_user_data(const std::string &key, const std::shared_ptr<T> &value) {

@ -36,50 +36,42 @@ namespace mindspore {
std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, const IncludeFunc &include) { std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, const IncludeFunc &include) {
size_t seen = NewSeenGeneration(); size_t seen = NewSeenGeneration();
std::deque<AnfNodePtr> todo(1024); std::deque<AnfNodePtr> todo(1024);
std::unordered_map<AnfNodePtr, size_t> rank;
std::vector<AnfNodePtr> res; std::vector<AnfNodePtr> res;
todo.clear(); todo.clear();
todo.push_back(root); todo.push_back(root);
while (!todo.empty()) { while (!todo.empty()) {
AnfNodePtr node = todo.back(); AnfNodePtr node = todo.back();
if (node == nullptr || node->seen_ == seen) { if (node == nullptr) {
todo.pop_back(); todo.pop_back();
continue; continue;
} }
if (rank.find(node) != rank.end() && rank[node] != todo.size()) { if (node->extra_seen_ == seen) { // We use extra_seen_ as finish flag
MS_LOG(EXCEPTION) << "Graph exists cycle, node " << node->DebugString(2); todo.pop_back();
continue;
} }
rank[node] = todo.size();
bool cont = false;
auto incl = include(node); auto incl = include(node);
if (incl == FOLLOW) { if (node->seen_ == seen) { // We use seen_ as checking flag
auto succs = succ(node);
for (const auto i : succs) {
if ((i != nullptr && i->seen_ != seen)
// Handle the case for 2 subgraphs calls each other.
// If the ValueNodeGraph's return is already in the todo list, do not follow it.
&& !((std::find(todo.begin(), todo.end(), i) != todo.end()) && (i->func_graph() != nullptr) &&
(i->func_graph()->get_return() == i))) {
todo.push_back(i);
cont = true;
}
}
} else if (incl == NOFOLLOW) {
// do nothing
} else if (incl == EXCLUDE) {
node->seen_ = seen;
todo.pop_back(); todo.pop_back();
if (incl != EXCLUDE) {
res.push_back(node);
}
node->extra_seen_ = seen;
continue; continue;
} else {
MS_LOG(EXCEPTION) << "include(node) must return one of: \"follow\", \"nofollow\", \"exclude\"";
} }
if (cont) { if (node->seen_ == seen && node->extra_seen_ != seen) {
continue; MS_LOG(EXCEPTION) << "Graph exists cycle, node " << node->DebugString(2);
} }
node->seen_ = seen; node->seen_ = seen;
res.push_back(node); if (incl == FOLLOW) {
todo.pop_back(); auto succs = succ(node);
(void)std::copy_if(succs.begin(), succs.end(), std::back_inserter(todo), [seen](const AnfNodePtr &next) {
return next != nullptr && next->seen_ != seen &&
(next->func_graph() == nullptr || next->func_graph()->get_return() != next);
});
} else if (incl > EXCLUDE) { // Not NOFOLLOW or EXCLUDE
MS_LOG(EXCEPTION) << "The result of include(node) must be one of: \"follow\", \"nofollow\", \"exclude\"";
}
} }
return res; return res;
} }

Loading…
Cancel
Save