Optimize the collectors of manager which listen to the graphs and nodes changes.

1. Remove the records of user graphs;
2. Remove the records of user value nodes;
3. Remove the records of user cnodes;
4. Add the records of users, and the API to access the users of graph, value node, and cnode;
5. Fix issue:User cnode record may point to its own graph, when combine the user(caller) and used one(callee);
6. Fix issue:User graphs never update itself after its first creation.
pull/847/head
Zhang Qinghua 5 years ago
parent 3dd369cefa
commit d43ad79b50

@ -263,18 +263,15 @@ const FuncGraphSet &FuncGraph::func_graphs_used_total() {
return used;
}
const FuncGraphCounterMap &FuncGraph::func_graph_users() {
auto mng = manager_.lock();
MS_EXCEPTION_IF_NULL(mng);
auto &users = mng->func_graph_users();
return users[shared_from_base<FuncGraph>()];
}
const AnfNodeCounterMap &FuncGraph::func_graph_user_cnodes() {
const CNodeIndexCounterMap &FuncGraph::func_graph_cnodes_index() {
auto mng = manager_.lock();
if (mng == nullptr) {
MS_LOG(EXCEPTION) << "BUG: no manager for this func graph: " << ToString()
<< " NodeInfo: " << trace::GetDebugInfo(debug_info());
}
MS_EXCEPTION_IF_NULL(mng);
auto &users = mng->func_graph_user_cnodes();
return users[shared_from_base<FuncGraph>()];
auto &cnode = mng->func_graph_cnodes_index();
return cnode[shared_from_base<FuncGraph>()];
}
FuncGraphPtr FuncGraph::parent() {

@ -37,6 +37,7 @@ namespace mindspore {
using BaseRefCounterMap = OrderedMap<BaseRef, int, BaseRefHash>;
using FuncGraphCounterMap = OrderedMap<FuncGraphPtr, int>;
using AnfNodeCounterMap = OrderedMap<AnfNodePtr, int>;
using CNodeIndexCounterMap = OrderedMap<CNodeIndexPairPtr, int, CNodeIndexHasher, CNodeIndexEqual>;
const char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values";
const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline";
@ -203,11 +204,8 @@ class FuncGraph : public FuncGraphBase {
// get all func graphs nested used by this func graph
const FuncGraphSet &func_graphs_used_total();
// get all users of this func graph
const FuncGraphCounterMap &func_graph_users();
// get all user cnodes of this func graph
const AnfNodeCounterMap &func_graph_user_cnodes();
// get all user value nodes of this func graph
const CNodeIndexCounterMap &func_graph_cnodes_index();
// Return the parent of this graph.
FuncGraphPtr parent();

@ -182,9 +182,11 @@ void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const Func
}
target_func_graph->set_return(return_node);
auto &value_nodes = manager_->func_graph_valuenodes()[func_graph];
for (auto &value_node : value_nodes) {
CloneValueNode(value_node.first, target_func_graph);
auto &cnodes = manager_->func_graph_cnodes_index()[func_graph];
for (auto &cnode : cnodes) {
auto parent = cnode.first->first->cast<CNodePtr>();
auto valuenode = parent->input(cnode.first->second);
CloneValueNode(valuenode, target_func_graph);
}
}
@ -386,8 +388,8 @@ void Cloner::LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraph
if (lift_params.empty()) {
return;
}
for (auto &user : func_graph_user->func_graph_users()) {
LiftParameters(user.first, func_graph_user, lift_params);
for (auto &cnode : func_graph_user->func_graph_cnodes_index()) {
LiftParameters(cnode.first->first->func_graph(), func_graph_user, lift_params);
}
}
@ -395,8 +397,8 @@ void Cloner::Lift() {
for (auto &func_graph_params : repl_func_graph_params_) {
auto &func_graph = func_graph_params.first;
auto &params = func_graph_params.second;
for (auto &user : func_graph->func_graph_users()) {
LiftParameters(user.first, func_graph, params);
for (auto &cnode : func_graph->func_graph_cnodes_index()) {
LiftParameters(cnode.first->first->func_graph(), func_graph, params);
}
}
}

@ -78,13 +78,16 @@ void FuncGraphManager::Reset() {
node_users_ = NodeUsersMap();
signals_ = std::make_shared<Signals>();
// FuncGraph --> AnfNode
nodes_ = std::make_shared<NodesCollector>(this);
// FuncGraph --> {AnfNode, Count}
valuenodes_ = std::make_shared<ValueNodesCollector>(this);
free_variables_direct_ = std::make_shared<FVDirectCollector>(this);
func_graph_valuenodes_ = std::make_shared<FuncGraphValueNodesCollector>(this);
func_graph_cnodes_index_ = std::make_shared<FuncGraphUsersCNodeIndexCollector>(this);
// FuncGraph --> {FuncGraph, Count}
func_graphs_used_ = std::make_shared<FuncGraphsUsedCollector>(this);
func_graph_users_ = std::make_shared<FuncGraphUsersCollector>(this);
func_graph_user_cnodes_ = std::make_shared<FuncGraphUserNodesCollector>(this);
func_graph_child_direct_ = std::make_shared<FuncGraphChildDirect>(this);
func_graph_parents_direct_ = std::make_shared<FuncGraphParentsDirectCollector>(this);
func_graph_j_direct_ = std::make_shared<FuncGraphJDirectCollector>(this);
@ -300,9 +303,9 @@ void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool
MS_LOG(DEBUG) << "Cannot drop as roots contains func graph: " << func_graph->ToString();
continue;
}
MS_EXCEPTION_IF_NULL(func_graph_users_);
auto &users = func_graph_users_->count_func_graphs_map()[func_graph];
if (!users.empty() && !ignore_users) {
MS_EXCEPTION_IF_NULL(func_graph_cnodes_index_);
auto &users_cnode_index = func_graph_cnodes_index_->count_nodes_map()[func_graph];
if (!users_cnode_index.empty() && !ignore_users) {
MS_LOG(DEBUG) << "Cannot drop as users not empty: " << func_graph->ToString();
continue;
}
@ -472,10 +475,6 @@ void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr t
node->set_scope(scope);
}
}
for (auto &used : source->func_graphs_used()) {
(void)func_graph_users_->Inc(used.first, target, used.second);
(void)this->func_graph_users()[used.first].erase(source);
}
for (auto &child : this->func_graph_child_direct()[source]) {
(void)func_graph_parents_direct_->Inc(child.first, target, child.second);
(void)this->func_graph_parents_direct()[child.first].erase(source);
@ -661,7 +660,9 @@ DepCollector::DepCollector(const FuncGraphManager *const manager) : FuncGraphAna
void DepCollector::OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kDecEdge); }
bool CounterAnfNodeCollector::Inc(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count = 1) {
template <typename ValueT, class CollectorHash, class CollectorEqual>
bool CounterAnfNodeCollector<ValueT, CollectorHash, CollectorEqual>::Inc(const FuncGraphPtr &func_graph,
const ValueT &key, int count) {
auto &d = count_nodes_map_[func_graph];
if (d.count(key) == 0) {
d[key] = count;
@ -672,7 +673,9 @@ bool CounterAnfNodeCollector::Inc(const FuncGraphPtr &func_graph, const AnfNodeP
return false;
}
bool CounterAnfNodeCollector::Dec(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count = 1) {
template <typename ValueT, class CollectorHash, class CollectorEqual>
bool CounterAnfNodeCollector<ValueT, CollectorHash, CollectorEqual>::Dec(const FuncGraphPtr &func_graph,
const ValueT &key, int count) {
MS_EXCEPTION_IF_NULL(func_graph);
auto &d = count_nodes_map_[func_graph];
if (d.count(key) != 0) {
@ -682,7 +685,7 @@ bool CounterAnfNodeCollector::Dec(const FuncGraphPtr &func_graph, const AnfNodeP
} else {
d[key] -= count;
if (d[key] < 0) {
MS_LOG(EXCEPTION) << "Count of key '" << key->ToString()
MS_LOG(EXCEPTION) << "Count of key '" << key
<< "' dec from 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
}
}
@ -690,52 +693,15 @@ bool CounterAnfNodeCollector::Dec(const FuncGraphPtr &func_graph, const AnfNodeP
return false;
}
bool CounterAnfNodeCollector::Mod(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count) {
template <typename ValueT, class CollectorHash, class CollectorEqual>
bool CounterAnfNodeCollector<ValueT, CollectorHash, CollectorEqual>::Mod(const FuncGraphPtr &func_graph,
const ValueT &key, int count) {
if (count > 0) {
return Inc(func_graph, key, count);
} else if (count < 0) {
return Dec(func_graph, key, -count);
} else {
MS_LOG(EXCEPTION) << "Count of key '" << key->ToString()
<< "' cannot be 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
}
}
bool CounterFuncGraphCollector::Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) {
auto &d = count_func_graphs_map_[func_graph];
if (d.count(key) == 0) {
d[key] = count;
return true;
} else {
d[key] += count;
}
return false;
}
bool CounterFuncGraphCollector::Dec(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) {
auto &d = count_func_graphs_map_[func_graph];
if (d.count(key) != 0) {
if (d[key] == count) {
(void)d.erase(key);
return true;
} else {
d[key] -= count;
if (d[key] < 0) {
MS_LOG(EXCEPTION) << "Count of key '" << key->ToString()
<< "' dec from 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
}
}
}
return false;
}
bool CounterFuncGraphCollector::Mod(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count) {
if (count > 0) {
return Inc(func_graph, key, count);
} else if (count < 0) {
return Dec(func_graph, key, -count);
} else {
MS_LOG(EXCEPTION) << "Count of key '" << key->ToString()
MS_LOG(EXCEPTION) << "Count of key '" << key
<< "' cannot be 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
}
}
@ -754,16 +720,21 @@ void ValueNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
(void)count_nodes_map_.erase(src);
}
// if inp is a graph ValueNode, this graph's FuncGraphValueNodesCollector's value is inp self
void FuncGraphValueNodesCollector::OnModEdge(AnfNodePtr, int, AnfNodePtr inp, EdgeProcessDirection direction) {
void FuncGraphUsersCNodeIndexCollector::OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp,
EdgeProcessDirection direction) {
MS_EXCEPTION_IF_NULL(node);
if (IsValueNode<FuncGraph>(inp)) {
(void)Mod(GetValueNode<FuncGraphPtr>(inp), inp, direction);
(void)Mod(GetValueNode<FuncGraphPtr>(inp), std::make_shared<CNodeIndexPair>(std::make_pair(node, index)),
direction);
}
}
void FuncGraphValueNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
void FuncGraphUsersCNodeIndexCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
for (auto &it : count_nodes_map_[src]) {
(void)Inc(dst, it.first, it.second);
// Ignore the user graph who may own itself.
if (dst != it.first->first->func_graph()) {
(void)Inc(dst, it.first, it.second);
}
}
(void)count_nodes_map_.erase(src);
}
@ -794,6 +765,45 @@ static FuncGraphPtr ParentProxy(const FuncGraphPtr &fg) {
return gn;
}
bool CounterFuncGraphCollector::Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) {
auto &d = count_func_graphs_map_[func_graph];
if (d.count(key) == 0) {
d[key] = count;
return true;
} else {
d[key] += count;
}
return false;
}
bool CounterFuncGraphCollector::Dec(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) {
auto &d = count_func_graphs_map_[func_graph];
if (d.count(key) != 0) {
if (d[key] == count) {
(void)d.erase(key);
return true;
} else {
d[key] -= count;
if (d[key] < 0) {
MS_LOG(EXCEPTION) << "Count of key '" << key->ToString()
<< "' dec from 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
}
}
}
return false;
}
bool CounterFuncGraphCollector::Mod(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count) {
if (count > 0) {
return Inc(func_graph, key, count);
} else if (count < 0) {
return Dec(func_graph, key, -count);
} else {
MS_LOG(EXCEPTION) << "Count of key '" << key->ToString()
<< "' cannot be 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
}
}
void FuncGraphChildDirect::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(inp);
@ -859,32 +869,6 @@ void FuncGraphsUsedCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst)
(void)count_func_graphs_map_.erase(src);
}
void FuncGraphUsersCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) {
MS_EXCEPTION_IF_NULL(node);
if (IsValueNode<FuncGraph>(inp)) {
(void)Mod(GetValueNode<FuncGraphPtr>(inp), node->func_graph(), direction);
}
}
void FuncGraphUsersCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr) {
// all graph use in src need to change to dst, so add dst user
(void)count_func_graphs_map_.erase(src);
}
void FuncGraphUserNodesCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) {
MS_EXCEPTION_IF_NULL(node);
if (IsValueNode<FuncGraph>(inp)) {
(void)Mod(GetValueNode<FuncGraphPtr>(inp), node, direction);
}
}
void FuncGraphUserNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
for (auto &it : count_nodes_map_[src]) {
(void)Inc(dst, it.first, it.second);
}
(void)count_nodes_map_.erase(src);
}
void FuncGraphJDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) {
if (IsValueNode<FuncGraph>(inp) && IsPrimitiveCNode(node, prim::kPrimJ)) {
(void)Mod(node->func_graph(), GetValueNode<FuncGraphPtr>(inp), direction);

@ -100,8 +100,12 @@ struct Signals {
enum EdgeProcessDirection { kDecEdge = -1, kIncEdge = 1 };
using CNodeIndexPair = std::pair<AnfNodePtr, int>;
using CNodeIndexPairPtr = std::shared_ptr<CNodeIndexPair>;
using FuncGraphToFuncGraphCounterMap = OrderedMap<FuncGraphPtr, OrderedMap<FuncGraphPtr, int>>;
using FuncGraphToAnfNodeCounterMap = OrderedMap<FuncGraphPtr, OrderedMap<AnfNodePtr, int>>;
template <typename ValueT, class CollectorHash = std::hash<ValueT>, class CollectorEqual = std::equal_to<ValueT>>
using FuncGraphToAnfNodeCounterMap = OrderedMap<FuncGraphPtr, OrderedMap<ValueT, int, CollectorHash, CollectorEqual>>;
// analysis base class
class FuncGraphAnalysis {
@ -174,46 +178,56 @@ class NodesCollector final : public DepCollector {
void OnDropNode(AnfNodePtr n) override;
};
class CounterFuncGraphCollector : public DepCollector {
public:
explicit CounterFuncGraphCollector(const FuncGraphManager *m) : DepCollector(m) {}
~CounterFuncGraphCollector() override = default;
FuncGraphToFuncGraphCounterMap &count_func_graphs_map() { return count_func_graphs_map_; }
// inherit from FuncGraphAnalysis
size_t size() const override { return count_func_graphs_map_.size(); }
void OnAddFuncGraph(FuncGraphPtr fg) final { count_func_graphs_map_[fg] = OrderedMap<FuncGraphPtr, int>(); }
void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_func_graphs_map_.erase(fg); }
bool Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count);
bool Dec(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count);
bool Mod(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count);
FuncGraphToFuncGraphCounterMap count_func_graphs_map_;
struct CNodeIndexHasher {
std::size_t operator()(const CNodeIndexPairPtr pair) const {
MS_EXCEPTION_IF_NULL(pair);
MS_EXCEPTION_IF_NULL(pair->first);
return hash_combine(pair->first->hash(), std::hash<int>()(pair->second));
}
};
protected:
void ExtraReset() override { count_func_graphs_map_.clear(); }
struct CNodeIndexEqual {
bool operator()(const CNodeIndexPairPtr lhs, const CNodeIndexPairPtr rhs) const {
if (lhs == nullptr || rhs == nullptr) {
return false;
}
if (lhs == rhs) {
return true;
}
if (lhs->first != rhs->first) {
return false;
}
if (lhs->second != rhs->second) {
return false;
}
return true;
}
};
template <typename ValueT, class CollectorHash = std::hash<ValueT>, class CollectorEqual = std::equal_to<ValueT>>
class CounterAnfNodeCollector : public DepCollector {
public:
explicit CounterAnfNodeCollector(const FuncGraphManager *m) : DepCollector(m) {}
~CounterAnfNodeCollector() override = default;
FuncGraphToAnfNodeCounterMap &count_nodes_map() { return count_nodes_map_; }
FuncGraphToAnfNodeCounterMap<ValueT, CollectorHash, CollectorEqual> &count_nodes_map() { return count_nodes_map_; }
size_t size() const override { return count_nodes_map_.size(); }
void OnAddFuncGraph(FuncGraphPtr fg) final { count_nodes_map_[fg] = OrderedMap<AnfNodePtr, int>(); }
void OnAddFuncGraph(FuncGraphPtr fg) final {
count_nodes_map_[fg] = OrderedMap<ValueT, int, CollectorHash, CollectorEqual>();
}
void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_nodes_map_.erase(fg); }
bool Inc(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count);
bool Dec(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count);
bool Mod(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count);
bool Inc(const FuncGraphPtr &func_graph, const ValueT &key, int count);
bool Dec(const FuncGraphPtr &func_graph, const ValueT &key, int count);
bool Mod(const FuncGraphPtr &func_graph, const ValueT &key, int count);
FuncGraphToAnfNodeCounterMap count_nodes_map_;
FuncGraphToAnfNodeCounterMap<ValueT, CollectorHash, CollectorEqual> count_nodes_map_;
protected:
void ExtraReset() override { count_nodes_map_.clear(); }
};
class ValueNodesCollector final : public CounterAnfNodeCollector {
class ValueNodesCollector final : public CounterAnfNodeCollector<AnfNodePtr> {
public:
explicit ValueNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {}
~ValueNodesCollector() override = default;
@ -223,17 +237,19 @@ class ValueNodesCollector final : public CounterAnfNodeCollector {
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override;
};
class FuncGraphValueNodesCollector final : public CounterAnfNodeCollector {
// Record the CNode and its input index, who points to the function graph.
class FuncGraphUsersCNodeIndexCollector final
: public CounterAnfNodeCollector<CNodeIndexPairPtr, CNodeIndexHasher, CNodeIndexEqual> {
public:
explicit FuncGraphValueNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {}
~FuncGraphValueNodesCollector() override = default;
explicit FuncGraphUsersCNodeIndexCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {}
~FuncGraphUsersCNodeIndexCollector() override = default;
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
protected:
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override;
};
class FVDirectCollector final : public CounterAnfNodeCollector {
class FVDirectCollector final : public CounterAnfNodeCollector<AnfNodePtr> {
public:
explicit FVDirectCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {}
~FVDirectCollector() override = default;
@ -243,6 +259,25 @@ class FVDirectCollector final : public CounterAnfNodeCollector {
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override;
};
class CounterFuncGraphCollector : public DepCollector {
public:
explicit CounterFuncGraphCollector(const FuncGraphManager *m) : DepCollector(m) {}
~CounterFuncGraphCollector() override = default;
FuncGraphToFuncGraphCounterMap &count_func_graphs_map() { return count_func_graphs_map_; }
// inherit from FuncGraphAnalysis
size_t size() const override { return count_func_graphs_map_.size(); }
void OnAddFuncGraph(FuncGraphPtr fg) final { count_func_graphs_map_[fg] = OrderedMap<FuncGraphPtr, int>(); }
void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_func_graphs_map_.erase(fg); }
bool Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count);
bool Dec(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count);
bool Mod(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count);
FuncGraphToFuncGraphCounterMap count_func_graphs_map_;
protected:
void ExtraReset() override { count_func_graphs_map_.clear(); }
};
class FuncGraphChildDirect final : public CounterFuncGraphCollector {
public:
explicit FuncGraphChildDirect(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {}
@ -279,28 +314,6 @@ class FuncGraphsUsedCollector final : public CounterFuncGraphCollector {
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override;
};
// graph's all user graphs: key is g, value is graphs who used g
class FuncGraphUsersCollector final : public CounterFuncGraphCollector {
public:
explicit FuncGraphUsersCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {}
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
~FuncGraphUsersCollector() override = default;
protected:
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override;
};
// graph's all user cnodes: key is g, value is cnodes who used g
class FuncGraphUserNodesCollector final : public CounterAnfNodeCollector {
public:
explicit FuncGraphUserNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {}
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
~FuncGraphUserNodesCollector() override = default;
protected:
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override;
};
class FuncGraphJDirectCollector final : public CounterFuncGraphCollector {
public:
explicit FuncGraphJDirectCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {}
@ -433,7 +446,9 @@ class ScopeComputer final : public DepComputer {
using FVTotalMap = OrderedMap<FuncGraphPtr, OrderedMap<BaseRef, int, BaseRefHash>>;
class FVTotalComputer final : public DepComputer, public CounterAnfNodeCollector, public CounterFuncGraphCollector {
class FVTotalComputer final : public DepComputer,
public CounterAnfNodeCollector<AnfNodePtr>,
public CounterFuncGraphCollector {
public:
explicit FVTotalComputer(const FuncGraphManager *m)
: DepComputer(m), CounterAnfNodeCollector(m), CounterFuncGraphCollector(m) {}
@ -549,18 +564,18 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
FuncGraphToAnfNodeMap &nodes() const { return nodes_->nodes_analysis_; }
FuncGraphToAnfNodeCounterMap &valuenodes() const { return valuenodes_->count_nodes_map_; }
FuncGraphToAnfNodeCounterMap<AnfNodePtr> &valuenodes() const { return valuenodes_->count_nodes_map_; }
FuncGraphToAnfNodeCounterMap &free_variables_direct() const { return free_variables_direct_->count_nodes_map_; }
FuncGraphToAnfNodeCounterMap<AnfNodePtr> &free_variables_direct() const {
return free_variables_direct_->count_nodes_map_;
}
FuncGraphToAnfNodeCounterMap &func_graph_valuenodes() const { return func_graph_valuenodes_->count_nodes_map_; }
FuncGraphToAnfNodeCounterMap<CNodeIndexPairPtr, CNodeIndexHasher, CNodeIndexEqual> &func_graph_cnodes_index() const {
return func_graph_cnodes_index_->count_nodes_map_;
}
FuncGraphToFuncGraphCounterMap &func_graphs_used() const { return func_graphs_used_->count_func_graphs_map_; }
FuncGraphToFuncGraphCounterMap &func_graph_users() const { return func_graph_users_->count_func_graphs_map_; }
FuncGraphToAnfNodeCounterMap &func_graph_user_cnodes() const { return func_graph_user_cnodes_->count_nodes_map_; }
FuncGraphToFuncGraphCounterMap &func_graph_child_direct() const {
return func_graph_child_direct_->count_func_graphs_map_;
}
@ -598,10 +613,8 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
std::shared_ptr<NodesCollector> nodes_;
std::shared_ptr<ValueNodesCollector> valuenodes_;
std::shared_ptr<FVDirectCollector> free_variables_direct_;
std::shared_ptr<FuncGraphValueNodesCollector> func_graph_valuenodes_;
std::shared_ptr<FuncGraphUsersCNodeIndexCollector> func_graph_cnodes_index_;
std::shared_ptr<FuncGraphsUsedCollector> func_graphs_used_;
std::shared_ptr<FuncGraphUsersCollector> func_graph_users_;
std::shared_ptr<FuncGraphUserNodesCollector> func_graph_user_cnodes_;
std::shared_ptr<FuncGraphChildDirect> func_graph_child_direct_;
std::shared_ptr<FuncGraphParentsDirectCollector> func_graph_parents_direct_;
std::shared_ptr<FuncGraphJDirectCollector> func_graph_j_direct_;

@ -81,10 +81,10 @@ bool IsTrivial(const FuncGraphPtr &fg, AnfNodePtr) {
}
bool IsUniqueUse(const FuncGraphPtr &fg, AnfNodePtr) {
auto &users = fg->func_graph_users();
auto &cnodes = fg->func_graph_cnodes_index();
int n_use =
std::accumulate(users.begin(), users.end(), 0,
[](int sum, const std::pair<const FuncGraphPtr, int> &item) { return sum + item.second; });
std::accumulate(cnodes.begin(), cnodes.end(), 0,
[](int sum, const std::pair<const CNodeIndexPairPtr, int> &item) { return sum + item.second; });
return n_use == 1;
}

@ -486,7 +486,8 @@ void CompileGraph::AddExternal(const LinConvertResult &result) {
}
void TraverseGraphMap(
const FuncGraphManagerPtr &manager_ptr, FuncGraphTransaction *const tr, const FuncGraphToAnfNodeCounterMap &cts,
const FuncGraphManagerPtr &manager_ptr, FuncGraphTransaction *const tr,
const FuncGraphToAnfNodeCounterMap<AnfNodePtr> &cts,
const std::function<std::shared_ptr<FuncGraph>(const PrimitivePtr, const AbstractFunctionPtr)> &get_prim_graph) {
MS_EXCEPTION_IF_NULL(manager_ptr);
MS_EXCEPTION_IF_NULL(tr);

Loading…
Cancel
Save