From 446dce0096d5c3dbb2b2e00e3775a515b40747f0 Mon Sep 17 00:00:00 2001 From: biffex Date: Tue, 5 May 2020 16:38:44 +0800 Subject: [PATCH] [ir] add seen generation to accelerate traversing the whole graph --- mindspore/ccsrc/ir/anf.cc | 6 ++++++ mindspore/ccsrc/ir/anf.h | 4 ++++ mindspore/ccsrc/optimizer/opt.cc | 27 +++++++++++++++++++-------- mindspore/ccsrc/optimizer/opt.h | 4 ++-- mindspore/ccsrc/utils/graph_utils.cc | 17 +++++++++-------- 5 files changed, 40 insertions(+), 18 deletions(-) diff --git a/mindspore/ccsrc/ir/anf.cc b/mindspore/ccsrc/ir/anf.cc index dd86e46713..50fe184d3f 100644 --- a/mindspore/ccsrc/ir/anf.cc +++ b/mindspore/ccsrc/ir/anf.cc @@ -227,6 +227,12 @@ bool IsPrimitive(const AnfNodePtr &node, const PrimitivePtr &value) { } return false; } + +size_t NewSeenGeneration() { + static size_t seen_generation = 0; + return ++seen_generation; +} + namespace id_generator { static std::unordered_map node_ids; std::string get_id(const AnfNodePtr &node) { diff --git a/mindspore/ccsrc/ir/anf.h b/mindspore/ccsrc/ir/anf.h index 16ccb15c43..d3da155b50 100644 --- a/mindspore/ccsrc/ir/anf.h +++ b/mindspore/ccsrc/ir/anf.h @@ -155,6 +155,7 @@ class AnfNode : public Base { os << node.ToString(); return os; } + size_t seen_{0}; protected: // Hold a weak ref to Graph as Graph also hold ref to AnfNode. 
@@ -429,6 +430,9 @@ inline S GetValueNode(const AnfNodePtr &node) { auto s = value->cast(); return s; } + +size_t NewSeenGeneration(); + namespace id_generator { std::string get_id(const AnfNodePtr &node); void reset_id(); diff --git a/mindspore/ccsrc/optimizer/opt.cc b/mindspore/ccsrc/optimizer/opt.cc index 0dbaf1107f..987c3c27bc 100644 --- a/mindspore/ccsrc/optimizer/opt.cc +++ b/mindspore/ccsrc/optimizer/opt.cc @@ -90,20 +90,26 @@ AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNode bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNodePtr &root_node, const SubstitutionPtr &transform) const { +#ifdef ENABLE_PROFILE + double start = GetTime(); +#endif FuncGraphManagerPtr manager = optimizer->manager(); - std::unordered_set seen_node; - std::deque todo{root_node}; + auto seen = NewSeenGeneration(); + // NOTE(review): std::deque has no capacity; todo(1024) pre-fills the queue with 1024 default-constructed (null) entries, which the nullptr check in the loop below pops and skips + std::deque todo(1024); + todo.push_back(root_node); bool changes = false; + auto &all_nodes = manager->all_nodes(); while (!todo.empty()) { AnfNodePtr node = todo.front(); todo.pop_front(); // check whether this node has been matched. - if (seen_node.find(node) != seen_node.end() || !manager->all_nodes().contains(node)) { + if (node == nullptr || node->seen_ == seen || !all_nodes.contains(node)) { continue; } - (void)seen_node.insert(node); + node->seen_ = seen; // select nodes that this transform can be applied. 
bool is_match = transform->predicate_(node); @@ -114,6 +120,7 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNo auto ret = (*transform)(optimizer, node); if (ret != nullptr && ret != node) { change = true; + changes = true; #ifdef ENABLE_PROFILE double t = GetTime(); #endif @@ -139,16 +146,20 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNo if (change && node_users.find(node) != node_users.end()) { for (auto &use : node_users[node]) { auto use_node = use.first; + if (use_node == nullptr) { + continue; + } todo.push_back(use_node); - if (seen_node.find(use_node) != seen_node.end()) { - (void)seen_node.erase(use_node); + if (use_node->seen_ == seen) { + use_node->seen_--; } } } - - changes = changes || change; } +#ifdef ENABLE_PROFILE + MsProfile::StatTime("opt.transform", GetTime() - start); +#endif return changes; } diff --git a/mindspore/ccsrc/optimizer/opt.h b/mindspore/ccsrc/optimizer/opt.h index 24191998e8..fb0bdc58be 100644 --- a/mindspore/ccsrc/optimizer/opt.h +++ b/mindspore/ccsrc/optimizer/opt.h @@ -48,8 +48,8 @@ class Substitution { PredicateFuncType predicate_{nullptr}; // an enum to mark this Substitution relation to renormalize pass RenormAction renorm_action_; - explicit Substitution(const TransformFuncType &transform, const std::string &name, const PredicateFuncType &predicate, - const RenormAction &renorm_action) + Substitution(const TransformFuncType &transform, const std::string &name, const PredicateFuncType &predicate, + const RenormAction &renorm_action) : transform_(transform), name_(name), predicate_(predicate), renorm_action_(renorm_action) {} ~Substitution() = default; AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) const; diff --git a/mindspore/ccsrc/utils/graph_utils.cc b/mindspore/ccsrc/utils/graph_utils.cc index 0801622549..6a4ee58e30 100644 --- a/mindspore/ccsrc/utils/graph_utils.cc +++ b/mindspore/ccsrc/utils/graph_utils.cc @@ -46,17 
+46,18 @@ class DeepFirstSearcher : public AnfVisitor { if (root == nullptr) { return res_; } + seen_ = NewSeenGeneration(); Visit(root); return res_; } void Visit(const AnfNodePtr &node) override { MS_EXCEPTION_IF_NULL(node); - if (seen_.count(node) != 0) { + if (node->seen_ == seen_) { return; } - (void)seen_.insert(node); + node->seen_ = seen_; auto incl = include_(node); if (incl == EXCLUDE) { @@ -70,9 +71,9 @@ class DeepFirstSearcher : public AnfVisitor { } private: + size_t seen_{0}; IncludeFunc include_; std::vector res_{}; - std::set seen_{}; }; class DeepScopedGraphSearcher : public DeepFirstSearcher { @@ -174,14 +175,14 @@ std::vector DeepLinkedGraphSearch(const AnfNodePtr &root, const Incl } std::vector TopoSort(const AnfNodePtr &root, const SuccFunc &succ, const IncludeFunc &include) { - std::unordered_set done; + size_t seen = NewSeenGeneration(); std::list todo(1, root); std::unordered_map rank; std::vector res; while (!todo.empty()) { AnfNodePtr node = todo.back(); - if (done.find(node) != done.end()) { + if (node == nullptr || node->seen_ == seen) { todo.pop_back(); continue; } @@ -194,7 +195,7 @@ std::vector TopoSort(const AnfNodePtr &root, const SuccFunc &succ, c if (incl == FOLLOW) { auto succs = succ(node); for (const auto i : succs) { - if ((done.find(i) == done.end()) + if ((i != nullptr && i->seen_ != seen) // Handle the case for 2 subgraphs calls each other. // If the ValueNodeGraph's return is already in the todo list, do not follow it. 
&& !((std::find(todo.begin(), todo.end(), i) != todo.end()) && (i->func_graph() != nullptr) && @@ -206,7 +207,7 @@ std::vector TopoSort(const AnfNodePtr &root, const SuccFunc &succ, c } else if (incl == NOFOLLOW) { // do nothing } else if (incl == EXCLUDE) { - (void)done.insert(node); + node->seen_ = seen; todo.pop_back(); continue; } else { @@ -215,7 +216,7 @@ std::vector TopoSort(const AnfNodePtr &root, const SuccFunc &succ, c if (cont) { continue; } - (void)done.insert(node); + node->seen_ = seen; res.push_back(node); todo.pop_back(); }