Pre Merge pull request !14592 from zhangzhaoju/ms_master_weak_ptr

pull/14592/MERGE
zhangzhaoju 4 years ago committed by Gitee
commit 94c2b3ffdd

@ -829,7 +829,10 @@ static std::vector<std::pair<CNodePtr, CNodePtr>> FindPrimalJPair(const FuncGrap
auto &node_user_map = manager->node_users();
// Search primal graph user cnodes.
for (auto &entry : primal_graph->func_graph_cnodes_index()) {
auto cnode = entry.first->first->cast<CNodePtr>();
auto cnode = dyn_cast<CNode>(entry.first->first.lock());
if (cnode == nullptr) {
continue;
}
auto index = entry.first->second;
if (index == 0) {
// To find real calling.

@ -758,7 +758,8 @@ class SideEffectFinder {
continue;
}
// Caller cnode.
auto cnode = dyn_cast<CNode>(user.first->first);
auto cnode = dyn_cast<CNode>(user.first->first.lock());
// if cnode was free, ignore the using info
if (cnode && input_index < cnode->size()) {
handler(cnode->input(input_index));
}

@ -359,7 +359,12 @@ void FuncGraph::CopyFuncGraphCNodesIndex(const FuncGraphPtr &source) {
auto &others = source->func_graph_cnodes_index();
for (auto it = others.begin(); it != others.end(); it++) {
// Ignore the user graph who may own itself.
auto fg = it->first->first->func_graph();
auto anf_node = it->first->first.lock();
// cnode was free, so ignore the using info
if (anf_node == nullptr) {
continue;
}
auto fg = anf_node->func_graph();
MS_EXCEPTION_IF_NULL(fg);
if (fg.get() != this) {
AddFuncGraphCNodeIndex(it->first, it->second);
@ -384,7 +389,7 @@ void FuncGraph::DropFuncGraphCNodeIndex(CNodeIndexPairPtr pair) {
} else {
func_graph_cnodes_index_[pair]--;
if (func_graph_cnodes_index_[pair] < 0) {
MS_LOG(EXCEPTION) << "Count of CNode/Index '" << pair->first << "/" << pair->second
MS_LOG(EXCEPTION) << "Count of CNode/Index '" << pair->first.lock() << "/" << pair->second
<< "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info());
}
}

@ -42,12 +42,15 @@
namespace mindspore {
using BaseRefCounterMap = OrderedMap<BaseRef, int, BaseRefHash>;
using FuncGraphCounterMap = OrderedMap<FuncGraphPtr, int>;
using CNodeIndexPair = std::pair<AnfNodeWeakPtr, int>;
using CNodeIndexPairPtr = std::shared_ptr<CNodeIndexPair>;
struct CNodeIndexHasher {
std::size_t operator()(const CNodeIndexPairPtr pair) const {
MS_EXCEPTION_IF_NULL(pair);
MS_EXCEPTION_IF_NULL(pair->first);
return hash_combine(pair->first->hash(), std::hash<int>()(pair->second));
auto node = pair->first.lock();
MS_EXCEPTION_IF_NULL(node);
return hash_combine(node->hash(), std::hash<int>()(pair->second));
}
};
@ -59,7 +62,7 @@ struct CNodeIndexEqual {
if (lhs == rhs) {
return true;
}
if (lhs->first != rhs->first) {
if (lhs->first.lock() != rhs->first.lock()) {
return false;
}
if (lhs->second != rhs->second) {

@ -201,7 +201,11 @@ void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const Func
auto &cnodes = func_graph->func_graph_cnodes_index();
for (auto &cnode : cnodes) {
auto parent = cnode.first->first->cast<CNodePtr>();
auto parent = dyn_cast<CNode>(cnode.first->first.lock());
// cnode was free, so ignore the using info
if (parent == nullptr) {
continue;
}
auto valuenode = parent->input(cnode.first->second);
CloneValueNode(valuenode, target_func_graph);
}
@ -415,7 +419,12 @@ void Cloner::LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraph
return;
}
for (auto &cnode : func_graph_user->func_graph_cnodes_index()) {
LiftParameters(cnode.first->first->func_graph(), func_graph_user, lift_params);
auto anf_node = cnode.first->first.lock();
// cnode was free, so ignore the using info
if (anf_node == nullptr) {
continue;
}
LiftParameters(anf_node->func_graph(), func_graph_user, lift_params);
}
}
@ -428,7 +437,12 @@ void Cloner::Lift() {
if (iter != repl_func_graph_params_.end()) {
auto &params = iter->second;
for (auto &cnode : func_graph->func_graph_cnodes_index()) {
LiftParameters(cnode.first->first->func_graph(), func_graph, params);
auto anf_node = cnode.first->first.lock();
// cnode was free, so ignore the using info
if (anf_node == nullptr) {
continue;
}
LiftParameters(anf_node->func_graph(), func_graph, params);
}
}
}

@ -92,9 +92,6 @@ struct Signals {
};
enum EdgeProcessDirection { kDecEdge = -1, kIncEdge = 1 };
using CNodeIndexPair = std::pair<AnfNodePtr, int>;
using CNodeIndexPairPtr = std::shared_ptr<CNodeIndexPair>;
using FuncGraphToFuncGraphSetMap = OrderedMap<FuncGraphPtr, FuncGraphSet>;
// analysis base class, graphs analysis which need dynamic compute by DepCollector in each read

Loading…
Cancel
Save