Optimize the collectors of manager which listen to the graphs and nodes changes.

1. Remove the records of user graphs;
2. Remove the records of user value nodes;
3. Remove the records of user cnodes;
4. Add the records of users, and the API to access the users of graph, value node, and cnode;
5. Fix issue:User cnode record may point to its own graph, when combine the user(caller) and used one(callee);
6. Fix issue:User graphs never update itself after its first creation.
pull/847/head
Zhang Qinghua 5 years ago
parent 3dd369cefa
commit d43ad79b50

@ -263,18 +263,15 @@ const FuncGraphSet &FuncGraph::func_graphs_used_total() {
return used; return used;
} }
const FuncGraphCounterMap &FuncGraph::func_graph_users() { const CNodeIndexCounterMap &FuncGraph::func_graph_cnodes_index() {
auto mng = manager_.lock();
MS_EXCEPTION_IF_NULL(mng);
auto &users = mng->func_graph_users();
return users[shared_from_base<FuncGraph>()];
}
const AnfNodeCounterMap &FuncGraph::func_graph_user_cnodes() {
auto mng = manager_.lock(); auto mng = manager_.lock();
if (mng == nullptr) {
MS_LOG(EXCEPTION) << "BUG: no manager for this func graph: " << ToString()
<< " NodeInfo: " << trace::GetDebugInfo(debug_info());
}
MS_EXCEPTION_IF_NULL(mng); MS_EXCEPTION_IF_NULL(mng);
auto &users = mng->func_graph_user_cnodes(); auto &cnode = mng->func_graph_cnodes_index();
return users[shared_from_base<FuncGraph>()]; return cnode[shared_from_base<FuncGraph>()];
} }
FuncGraphPtr FuncGraph::parent() { FuncGraphPtr FuncGraph::parent() {

@ -37,6 +37,7 @@ namespace mindspore {
using BaseRefCounterMap = OrderedMap<BaseRef, int, BaseRefHash>; using BaseRefCounterMap = OrderedMap<BaseRef, int, BaseRefHash>;
using FuncGraphCounterMap = OrderedMap<FuncGraphPtr, int>; using FuncGraphCounterMap = OrderedMap<FuncGraphPtr, int>;
using AnfNodeCounterMap = OrderedMap<AnfNodePtr, int>; using AnfNodeCounterMap = OrderedMap<AnfNodePtr, int>;
using CNodeIndexCounterMap = OrderedMap<CNodeIndexPairPtr, int, CNodeIndexHasher, CNodeIndexEqual>;
const char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values"; const char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values";
const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline"; const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline";
@ -203,11 +204,8 @@ class FuncGraph : public FuncGraphBase {
// get all func graphs nested used by this func graph // get all func graphs nested used by this func graph
const FuncGraphSet &func_graphs_used_total(); const FuncGraphSet &func_graphs_used_total();
// get all users of this func graph // get all user value nodes of this func graph
const FuncGraphCounterMap &func_graph_users(); const CNodeIndexCounterMap &func_graph_cnodes_index();
// get all user cnodes of this func graph
const AnfNodeCounterMap &func_graph_user_cnodes();
// Return the parent of this graph. // Return the parent of this graph.
FuncGraphPtr parent(); FuncGraphPtr parent();

@ -182,9 +182,11 @@ void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const Func
} }
target_func_graph->set_return(return_node); target_func_graph->set_return(return_node);
auto &value_nodes = manager_->func_graph_valuenodes()[func_graph]; auto &cnodes = manager_->func_graph_cnodes_index()[func_graph];
for (auto &value_node : value_nodes) { for (auto &cnode : cnodes) {
CloneValueNode(value_node.first, target_func_graph); auto parent = cnode.first->first->cast<CNodePtr>();
auto valuenode = parent->input(cnode.first->second);
CloneValueNode(valuenode, target_func_graph);
} }
} }
@ -386,8 +388,8 @@ void Cloner::LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraph
if (lift_params.empty()) { if (lift_params.empty()) {
return; return;
} }
for (auto &user : func_graph_user->func_graph_users()) { for (auto &cnode : func_graph_user->func_graph_cnodes_index()) {
LiftParameters(user.first, func_graph_user, lift_params); LiftParameters(cnode.first->first->func_graph(), func_graph_user, lift_params);
} }
} }
@ -395,8 +397,8 @@ void Cloner::Lift() {
for (auto &func_graph_params : repl_func_graph_params_) { for (auto &func_graph_params : repl_func_graph_params_) {
auto &func_graph = func_graph_params.first; auto &func_graph = func_graph_params.first;
auto &params = func_graph_params.second; auto &params = func_graph_params.second;
for (auto &user : func_graph->func_graph_users()) { for (auto &cnode : func_graph->func_graph_cnodes_index()) {
LiftParameters(user.first, func_graph, params); LiftParameters(cnode.first->first->func_graph(), func_graph, params);
} }
} }
} }

@ -78,13 +78,16 @@ void FuncGraphManager::Reset() {
node_users_ = NodeUsersMap(); node_users_ = NodeUsersMap();
signals_ = std::make_shared<Signals>(); signals_ = std::make_shared<Signals>();
// FuncGraph --> AnfNode
nodes_ = std::make_shared<NodesCollector>(this); nodes_ = std::make_shared<NodesCollector>(this);
// FuncGraph --> {AnfNode, Count}
valuenodes_ = std::make_shared<ValueNodesCollector>(this); valuenodes_ = std::make_shared<ValueNodesCollector>(this);
free_variables_direct_ = std::make_shared<FVDirectCollector>(this); free_variables_direct_ = std::make_shared<FVDirectCollector>(this);
func_graph_valuenodes_ = std::make_shared<FuncGraphValueNodesCollector>(this); func_graph_cnodes_index_ = std::make_shared<FuncGraphUsersCNodeIndexCollector>(this);
// FuncGraph --> {FuncGraph, Count}
func_graphs_used_ = std::make_shared<FuncGraphsUsedCollector>(this); func_graphs_used_ = std::make_shared<FuncGraphsUsedCollector>(this);
func_graph_users_ = std::make_shared<FuncGraphUsersCollector>(this);
func_graph_user_cnodes_ = std::make_shared<FuncGraphUserNodesCollector>(this);
func_graph_child_direct_ = std::make_shared<FuncGraphChildDirect>(this); func_graph_child_direct_ = std::make_shared<FuncGraphChildDirect>(this);
func_graph_parents_direct_ = std::make_shared<FuncGraphParentsDirectCollector>(this); func_graph_parents_direct_ = std::make_shared<FuncGraphParentsDirectCollector>(this);
func_graph_j_direct_ = std::make_shared<FuncGraphJDirectCollector>(this); func_graph_j_direct_ = std::make_shared<FuncGraphJDirectCollector>(this);
@ -300,9 +303,9 @@ void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool
MS_LOG(DEBUG) << "Cannot drop as roots contains func graph: " << func_graph->ToString(); MS_LOG(DEBUG) << "Cannot drop as roots contains func graph: " << func_graph->ToString();
continue; continue;
} }
MS_EXCEPTION_IF_NULL(func_graph_users_); MS_EXCEPTION_IF_NULL(func_graph_cnodes_index_);
auto &users = func_graph_users_->count_func_graphs_map()[func_graph]; auto &users_cnode_index = func_graph_cnodes_index_->count_nodes_map()[func_graph];
if (!users.empty() && !ignore_users) { if (!users_cnode_index.empty() && !ignore_users) {
MS_LOG(DEBUG) << "Cannot drop as users not empty: " << func_graph->ToString(); MS_LOG(DEBUG) << "Cannot drop as users not empty: " << func_graph->ToString();
continue; continue;
} }
@ -472,10 +475,6 @@ void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr t
node->set_scope(scope); node->set_scope(scope);
} }
} }
for (auto &used : source->func_graphs_used()) {
(void)func_graph_users_->Inc(used.first, target, used.second);
(void)this->func_graph_users()[used.first].erase(source);
}
for (auto &child : this->func_graph_child_direct()[source]) { for (auto &child : this->func_graph_child_direct()[source]) {
(void)func_graph_parents_direct_->Inc(child.first, target, child.second); (void)func_graph_parents_direct_->Inc(child.first, target, child.second);
(void)this->func_graph_parents_direct()[child.first].erase(source); (void)this->func_graph_parents_direct()[child.first].erase(source);
@ -661,7 +660,9 @@ DepCollector::DepCollector(const FuncGraphManager *const manager) : FuncGraphAna
void DepCollector::OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kDecEdge); } void DepCollector::OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kDecEdge); }
bool CounterAnfNodeCollector::Inc(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count = 1) { template <typename ValueT, class CollectorHash, class CollectorEqual>
bool CounterAnfNodeCollector<ValueT, CollectorHash, CollectorEqual>::Inc(const FuncGraphPtr &func_graph,
const ValueT &key, int count) {
auto &d = count_nodes_map_[func_graph]; auto &d = count_nodes_map_[func_graph];
if (d.count(key) == 0) { if (d.count(key) == 0) {
d[key] = count; d[key] = count;
@ -672,7 +673,9 @@ bool CounterAnfNodeCollector::Inc(const FuncGraphPtr &func_graph, const AnfNodeP
return false; return false;
} }
bool CounterAnfNodeCollector::Dec(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count = 1) { template <typename ValueT, class CollectorHash, class CollectorEqual>
bool CounterAnfNodeCollector<ValueT, CollectorHash, CollectorEqual>::Dec(const FuncGraphPtr &func_graph,
const ValueT &key, int count) {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
auto &d = count_nodes_map_[func_graph]; auto &d = count_nodes_map_[func_graph];
if (d.count(key) != 0) { if (d.count(key) != 0) {
@ -682,7 +685,7 @@ bool CounterAnfNodeCollector::Dec(const FuncGraphPtr &func_graph, const AnfNodeP
} else { } else {
d[key] -= count; d[key] -= count;
if (d[key] < 0) { if (d[key] < 0) {
MS_LOG(EXCEPTION) << "Count of key '" << key->ToString() MS_LOG(EXCEPTION) << "Count of key '" << key
<< "' dec from 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info()); << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
} }
} }
@ -690,52 +693,15 @@ bool CounterAnfNodeCollector::Dec(const FuncGraphPtr &func_graph, const AnfNodeP
return false; return false;
} }
bool CounterAnfNodeCollector::Mod(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count) { template <typename ValueT, class CollectorHash, class CollectorEqual>
bool CounterAnfNodeCollector<ValueT, CollectorHash, CollectorEqual>::Mod(const FuncGraphPtr &func_graph,
const ValueT &key, int count) {
if (count > 0) { if (count > 0) {
return Inc(func_graph, key, count); return Inc(func_graph, key, count);
} else if (count < 0) { } else if (count < 0) {
return Dec(func_graph, key, -count); return Dec(func_graph, key, -count);
} else { } else {
MS_LOG(EXCEPTION) << "Count of key '" << key->ToString() MS_LOG(EXCEPTION) << "Count of key '" << key
<< "' cannot be 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
}
}
bool CounterFuncGraphCollector::Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) {
auto &d = count_func_graphs_map_[func_graph];
if (d.count(key) == 0) {
d[key] = count;
return true;
} else {
d[key] += count;
}
return false;
}
bool CounterFuncGraphCollector::Dec(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) {
auto &d = count_func_graphs_map_[func_graph];
if (d.count(key) != 0) {
if (d[key] == count) {
(void)d.erase(key);
return true;
} else {
d[key] -= count;
if (d[key] < 0) {
MS_LOG(EXCEPTION) << "Count of key '" << key->ToString()
<< "' dec from 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
}
}
}
return false;
}
bool CounterFuncGraphCollector::Mod(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count) {
if (count > 0) {
return Inc(func_graph, key, count);
} else if (count < 0) {
return Dec(func_graph, key, -count);
} else {
MS_LOG(EXCEPTION) << "Count of key '" << key->ToString()
<< "' cannot be 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info()); << "' cannot be 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
} }
} }
@ -754,16 +720,21 @@ void ValueNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
(void)count_nodes_map_.erase(src); (void)count_nodes_map_.erase(src);
} }
// if inp is a graph ValueNode, this graph's FuncGraphValueNodesCollector's value is inp self void FuncGraphUsersCNodeIndexCollector::OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp,
void FuncGraphValueNodesCollector::OnModEdge(AnfNodePtr, int, AnfNodePtr inp, EdgeProcessDirection direction) { EdgeProcessDirection direction) {
MS_EXCEPTION_IF_NULL(node);
if (IsValueNode<FuncGraph>(inp)) { if (IsValueNode<FuncGraph>(inp)) {
(void)Mod(GetValueNode<FuncGraphPtr>(inp), inp, direction); (void)Mod(GetValueNode<FuncGraphPtr>(inp), std::make_shared<CNodeIndexPair>(std::make_pair(node, index)),
direction);
} }
} }
void FuncGraphValueNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { void FuncGraphUsersCNodeIndexCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
for (auto &it : count_nodes_map_[src]) { for (auto &it : count_nodes_map_[src]) {
(void)Inc(dst, it.first, it.second); // Ignore the user graph who may own itself.
if (dst != it.first->first->func_graph()) {
(void)Inc(dst, it.first, it.second);
}
} }
(void)count_nodes_map_.erase(src); (void)count_nodes_map_.erase(src);
} }
@ -794,6 +765,45 @@ static FuncGraphPtr ParentProxy(const FuncGraphPtr &fg) {
return gn; return gn;
} }
bool CounterFuncGraphCollector::Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) {
auto &d = count_func_graphs_map_[func_graph];
if (d.count(key) == 0) {
d[key] = count;
return true;
} else {
d[key] += count;
}
return false;
}
bool CounterFuncGraphCollector::Dec(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) {
auto &d = count_func_graphs_map_[func_graph];
if (d.count(key) != 0) {
if (d[key] == count) {
(void)d.erase(key);
return true;
} else {
d[key] -= count;
if (d[key] < 0) {
MS_LOG(EXCEPTION) << "Count of key '" << key->ToString()
<< "' dec from 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
}
}
}
return false;
}
bool CounterFuncGraphCollector::Mod(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count) {
if (count > 0) {
return Inc(func_graph, key, count);
} else if (count < 0) {
return Dec(func_graph, key, -count);
} else {
MS_LOG(EXCEPTION) << "Count of key '" << key->ToString()
<< "' cannot be 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
}
}
void FuncGraphChildDirect::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) { void FuncGraphChildDirect::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(inp); MS_EXCEPTION_IF_NULL(inp);
@ -859,32 +869,6 @@ void FuncGraphsUsedCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst)
(void)count_func_graphs_map_.erase(src); (void)count_func_graphs_map_.erase(src);
} }
void FuncGraphUsersCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) {
MS_EXCEPTION_IF_NULL(node);
if (IsValueNode<FuncGraph>(inp)) {
(void)Mod(GetValueNode<FuncGraphPtr>(inp), node->func_graph(), direction);
}
}
void FuncGraphUsersCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr) {
// all graph use in src need to change to dst, so add dst user
(void)count_func_graphs_map_.erase(src);
}
void FuncGraphUserNodesCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) {
MS_EXCEPTION_IF_NULL(node);
if (IsValueNode<FuncGraph>(inp)) {
(void)Mod(GetValueNode<FuncGraphPtr>(inp), node, direction);
}
}
void FuncGraphUserNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
for (auto &it : count_nodes_map_[src]) {
(void)Inc(dst, it.first, it.second);
}
(void)count_nodes_map_.erase(src);
}
void FuncGraphJDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) { void FuncGraphJDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) {
if (IsValueNode<FuncGraph>(inp) && IsPrimitiveCNode(node, prim::kPrimJ)) { if (IsValueNode<FuncGraph>(inp) && IsPrimitiveCNode(node, prim::kPrimJ)) {
(void)Mod(node->func_graph(), GetValueNode<FuncGraphPtr>(inp), direction); (void)Mod(node->func_graph(), GetValueNode<FuncGraphPtr>(inp), direction);

@ -100,8 +100,12 @@ struct Signals {
enum EdgeProcessDirection { kDecEdge = -1, kIncEdge = 1 }; enum EdgeProcessDirection { kDecEdge = -1, kIncEdge = 1 };
using CNodeIndexPair = std::pair<AnfNodePtr, int>;
using CNodeIndexPairPtr = std::shared_ptr<CNodeIndexPair>;
using FuncGraphToFuncGraphCounterMap = OrderedMap<FuncGraphPtr, OrderedMap<FuncGraphPtr, int>>; using FuncGraphToFuncGraphCounterMap = OrderedMap<FuncGraphPtr, OrderedMap<FuncGraphPtr, int>>;
using FuncGraphToAnfNodeCounterMap = OrderedMap<FuncGraphPtr, OrderedMap<AnfNodePtr, int>>; template <typename ValueT, class CollectorHash = std::hash<ValueT>, class CollectorEqual = std::equal_to<ValueT>>
using FuncGraphToAnfNodeCounterMap = OrderedMap<FuncGraphPtr, OrderedMap<ValueT, int, CollectorHash, CollectorEqual>>;
// analysis base class // analysis base class
class FuncGraphAnalysis { class FuncGraphAnalysis {
@ -174,46 +178,56 @@ class NodesCollector final : public DepCollector {
void OnDropNode(AnfNodePtr n) override; void OnDropNode(AnfNodePtr n) override;
}; };
class CounterFuncGraphCollector : public DepCollector { struct CNodeIndexHasher {
public: std::size_t operator()(const CNodeIndexPairPtr pair) const {
explicit CounterFuncGraphCollector(const FuncGraphManager *m) : DepCollector(m) {} MS_EXCEPTION_IF_NULL(pair);
~CounterFuncGraphCollector() override = default; MS_EXCEPTION_IF_NULL(pair->first);
FuncGraphToFuncGraphCounterMap &count_func_graphs_map() { return count_func_graphs_map_; } return hash_combine(pair->first->hash(), std::hash<int>()(pair->second));
// inherit from FuncGraphAnalysis }
size_t size() const override { return count_func_graphs_map_.size(); } };
void OnAddFuncGraph(FuncGraphPtr fg) final { count_func_graphs_map_[fg] = OrderedMap<FuncGraphPtr, int>(); }
void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_func_graphs_map_.erase(fg); }
bool Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count);
bool Dec(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count);
bool Mod(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count);
FuncGraphToFuncGraphCounterMap count_func_graphs_map_;
protected: struct CNodeIndexEqual {
void ExtraReset() override { count_func_graphs_map_.clear(); } bool operator()(const CNodeIndexPairPtr lhs, const CNodeIndexPairPtr rhs) const {
if (lhs == nullptr || rhs == nullptr) {
return false;
}
if (lhs == rhs) {
return true;
}
if (lhs->first != rhs->first) {
return false;
}
if (lhs->second != rhs->second) {
return false;
}
return true;
}
}; };
template <typename ValueT, class CollectorHash = std::hash<ValueT>, class CollectorEqual = std::equal_to<ValueT>>
class CounterAnfNodeCollector : public DepCollector { class CounterAnfNodeCollector : public DepCollector {
public: public:
explicit CounterAnfNodeCollector(const FuncGraphManager *m) : DepCollector(m) {} explicit CounterAnfNodeCollector(const FuncGraphManager *m) : DepCollector(m) {}
~CounterAnfNodeCollector() override = default; ~CounterAnfNodeCollector() override = default;
FuncGraphToAnfNodeCounterMap &count_nodes_map() { return count_nodes_map_; } FuncGraphToAnfNodeCounterMap<ValueT, CollectorHash, CollectorEqual> &count_nodes_map() { return count_nodes_map_; }
size_t size() const override { return count_nodes_map_.size(); } size_t size() const override { return count_nodes_map_.size(); }
void OnAddFuncGraph(FuncGraphPtr fg) final { count_nodes_map_[fg] = OrderedMap<AnfNodePtr, int>(); } void OnAddFuncGraph(FuncGraphPtr fg) final {
count_nodes_map_[fg] = OrderedMap<ValueT, int, CollectorHash, CollectorEqual>();
}
void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_nodes_map_.erase(fg); } void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_nodes_map_.erase(fg); }
bool Inc(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count); bool Inc(const FuncGraphPtr &func_graph, const ValueT &key, int count);
bool Dec(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count); bool Dec(const FuncGraphPtr &func_graph, const ValueT &key, int count);
bool Mod(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count); bool Mod(const FuncGraphPtr &func_graph, const ValueT &key, int count);
FuncGraphToAnfNodeCounterMap count_nodes_map_; FuncGraphToAnfNodeCounterMap<ValueT, CollectorHash, CollectorEqual> count_nodes_map_;
protected: protected:
void ExtraReset() override { count_nodes_map_.clear(); } void ExtraReset() override { count_nodes_map_.clear(); }
}; };
class ValueNodesCollector final : public CounterAnfNodeCollector { class ValueNodesCollector final : public CounterAnfNodeCollector<AnfNodePtr> {
public: public:
explicit ValueNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} explicit ValueNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {}
~ValueNodesCollector() override = default; ~ValueNodesCollector() override = default;
@ -223,17 +237,19 @@ class ValueNodesCollector final : public CounterAnfNodeCollector {
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override;
}; };
class FuncGraphValueNodesCollector final : public CounterAnfNodeCollector { // Record the CNode and its input index, who points to the function graph.
class FuncGraphUsersCNodeIndexCollector final
: public CounterAnfNodeCollector<CNodeIndexPairPtr, CNodeIndexHasher, CNodeIndexEqual> {
public: public:
explicit FuncGraphValueNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} explicit FuncGraphUsersCNodeIndexCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {}
~FuncGraphValueNodesCollector() override = default; ~FuncGraphUsersCNodeIndexCollector() override = default;
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
protected: protected:
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override;
}; };
class FVDirectCollector final : public CounterAnfNodeCollector { class FVDirectCollector final : public CounterAnfNodeCollector<AnfNodePtr> {
public: public:
explicit FVDirectCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} explicit FVDirectCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {}
~FVDirectCollector() override = default; ~FVDirectCollector() override = default;
@ -243,6 +259,25 @@ class FVDirectCollector final : public CounterAnfNodeCollector {
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override;
}; };
class CounterFuncGraphCollector : public DepCollector {
public:
explicit CounterFuncGraphCollector(const FuncGraphManager *m) : DepCollector(m) {}
~CounterFuncGraphCollector() override = default;
FuncGraphToFuncGraphCounterMap &count_func_graphs_map() { return count_func_graphs_map_; }
// inherit from FuncGraphAnalysis
size_t size() const override { return count_func_graphs_map_.size(); }
void OnAddFuncGraph(FuncGraphPtr fg) final { count_func_graphs_map_[fg] = OrderedMap<FuncGraphPtr, int>(); }
void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_func_graphs_map_.erase(fg); }
bool Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count);
bool Dec(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count);
bool Mod(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count);
FuncGraphToFuncGraphCounterMap count_func_graphs_map_;
protected:
void ExtraReset() override { count_func_graphs_map_.clear(); }
};
class FuncGraphChildDirect final : public CounterFuncGraphCollector { class FuncGraphChildDirect final : public CounterFuncGraphCollector {
public: public:
explicit FuncGraphChildDirect(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} explicit FuncGraphChildDirect(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {}
@ -279,28 +314,6 @@ class FuncGraphsUsedCollector final : public CounterFuncGraphCollector {
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override;
}; };
// graph's all user graphs: key is g, value is graphs who used g
class FuncGraphUsersCollector final : public CounterFuncGraphCollector {
public:
explicit FuncGraphUsersCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {}
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
~FuncGraphUsersCollector() override = default;
protected:
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override;
};
// graph's all user cnodes: key is g, value is cnodes who used g
class FuncGraphUserNodesCollector final : public CounterAnfNodeCollector {
public:
explicit FuncGraphUserNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {}
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
~FuncGraphUserNodesCollector() override = default;
protected:
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override;
};
class FuncGraphJDirectCollector final : public CounterFuncGraphCollector { class FuncGraphJDirectCollector final : public CounterFuncGraphCollector {
public: public:
explicit FuncGraphJDirectCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} explicit FuncGraphJDirectCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {}
@ -433,7 +446,9 @@ class ScopeComputer final : public DepComputer {
using FVTotalMap = OrderedMap<FuncGraphPtr, OrderedMap<BaseRef, int, BaseRefHash>>; using FVTotalMap = OrderedMap<FuncGraphPtr, OrderedMap<BaseRef, int, BaseRefHash>>;
class FVTotalComputer final : public DepComputer, public CounterAnfNodeCollector, public CounterFuncGraphCollector { class FVTotalComputer final : public DepComputer,
public CounterAnfNodeCollector<AnfNodePtr>,
public CounterFuncGraphCollector {
public: public:
explicit FVTotalComputer(const FuncGraphManager *m) explicit FVTotalComputer(const FuncGraphManager *m)
: DepComputer(m), CounterAnfNodeCollector(m), CounterFuncGraphCollector(m) {} : DepComputer(m), CounterAnfNodeCollector(m), CounterFuncGraphCollector(m) {}
@ -549,18 +564,18 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
FuncGraphToAnfNodeMap &nodes() const { return nodes_->nodes_analysis_; } FuncGraphToAnfNodeMap &nodes() const { return nodes_->nodes_analysis_; }
FuncGraphToAnfNodeCounterMap &valuenodes() const { return valuenodes_->count_nodes_map_; } FuncGraphToAnfNodeCounterMap<AnfNodePtr> &valuenodes() const { return valuenodes_->count_nodes_map_; }
FuncGraphToAnfNodeCounterMap &free_variables_direct() const { return free_variables_direct_->count_nodes_map_; } FuncGraphToAnfNodeCounterMap<AnfNodePtr> &free_variables_direct() const {
return free_variables_direct_->count_nodes_map_;
}
FuncGraphToAnfNodeCounterMap &func_graph_valuenodes() const { return func_graph_valuenodes_->count_nodes_map_; } FuncGraphToAnfNodeCounterMap<CNodeIndexPairPtr, CNodeIndexHasher, CNodeIndexEqual> &func_graph_cnodes_index() const {
return func_graph_cnodes_index_->count_nodes_map_;
}
FuncGraphToFuncGraphCounterMap &func_graphs_used() const { return func_graphs_used_->count_func_graphs_map_; } FuncGraphToFuncGraphCounterMap &func_graphs_used() const { return func_graphs_used_->count_func_graphs_map_; }
FuncGraphToFuncGraphCounterMap &func_graph_users() const { return func_graph_users_->count_func_graphs_map_; }
FuncGraphToAnfNodeCounterMap &func_graph_user_cnodes() const { return func_graph_user_cnodes_->count_nodes_map_; }
FuncGraphToFuncGraphCounterMap &func_graph_child_direct() const { FuncGraphToFuncGraphCounterMap &func_graph_child_direct() const {
return func_graph_child_direct_->count_func_graphs_map_; return func_graph_child_direct_->count_func_graphs_map_;
} }
@ -598,10 +613,8 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
std::shared_ptr<NodesCollector> nodes_; std::shared_ptr<NodesCollector> nodes_;
std::shared_ptr<ValueNodesCollector> valuenodes_; std::shared_ptr<ValueNodesCollector> valuenodes_;
std::shared_ptr<FVDirectCollector> free_variables_direct_; std::shared_ptr<FVDirectCollector> free_variables_direct_;
std::shared_ptr<FuncGraphValueNodesCollector> func_graph_valuenodes_; std::shared_ptr<FuncGraphUsersCNodeIndexCollector> func_graph_cnodes_index_;
std::shared_ptr<FuncGraphsUsedCollector> func_graphs_used_; std::shared_ptr<FuncGraphsUsedCollector> func_graphs_used_;
std::shared_ptr<FuncGraphUsersCollector> func_graph_users_;
std::shared_ptr<FuncGraphUserNodesCollector> func_graph_user_cnodes_;
std::shared_ptr<FuncGraphChildDirect> func_graph_child_direct_; std::shared_ptr<FuncGraphChildDirect> func_graph_child_direct_;
std::shared_ptr<FuncGraphParentsDirectCollector> func_graph_parents_direct_; std::shared_ptr<FuncGraphParentsDirectCollector> func_graph_parents_direct_;
std::shared_ptr<FuncGraphJDirectCollector> func_graph_j_direct_; std::shared_ptr<FuncGraphJDirectCollector> func_graph_j_direct_;

@ -81,10 +81,10 @@ bool IsTrivial(const FuncGraphPtr &fg, AnfNodePtr) {
} }
bool IsUniqueUse(const FuncGraphPtr &fg, AnfNodePtr) { bool IsUniqueUse(const FuncGraphPtr &fg, AnfNodePtr) {
auto &users = fg->func_graph_users(); auto &cnodes = fg->func_graph_cnodes_index();
int n_use = int n_use =
std::accumulate(users.begin(), users.end(), 0, std::accumulate(cnodes.begin(), cnodes.end(), 0,
[](int sum, const std::pair<const FuncGraphPtr, int> &item) { return sum + item.second; }); [](int sum, const std::pair<const CNodeIndexPairPtr, int> &item) { return sum + item.second; });
return n_use == 1; return n_use == 1;
} }

@ -486,7 +486,8 @@ void CompileGraph::AddExternal(const LinConvertResult &result) {
} }
void TraverseGraphMap( void TraverseGraphMap(
const FuncGraphManagerPtr &manager_ptr, FuncGraphTransaction *const tr, const FuncGraphToAnfNodeCounterMap &cts, const FuncGraphManagerPtr &manager_ptr, FuncGraphTransaction *const tr,
const FuncGraphToAnfNodeCounterMap<AnfNodePtr> &cts,
const std::function<std::shared_ptr<FuncGraph>(const PrimitivePtr, const AbstractFunctionPtr)> &get_prim_graph) { const std::function<std::shared_ptr<FuncGraph>(const PrimitivePtr, const AbstractFunctionPtr)> &get_prim_graph) {
MS_EXCEPTION_IF_NULL(manager_ptr); MS_EXCEPTION_IF_NULL(manager_ptr);
MS_EXCEPTION_IF_NULL(tr); MS_EXCEPTION_IF_NULL(tr);

Loading…
Cancel
Save