From 077bde0767dd2b11b322332ef97cea03083115bf Mon Sep 17 00:00:00 2001 From: Zhang Qinghua Date: Thu, 22 Oct 2020 19:23:06 +0800 Subject: [PATCH] Return a new abstract without tracking_id for fg ValueNode in CSE. --- mindspore/ccsrc/frontend/optimizer/cse.cc | 16 ++++++++++------ mindspore/ccsrc/frontend/optimizer/cse.h | 2 +- mindspore/ccsrc/frontend/optimizer/cse_pass.h | 2 -- .../jit/static_analysis/static_analysis.cc | 1 - mindspore/core/abstract/abstract_function.h | 2 ++ 5 files changed, 13 insertions(+), 10 deletions(-) diff --git a/mindspore/ccsrc/frontend/optimizer/cse.cc b/mindspore/ccsrc/frontend/optimizer/cse.cc index 350e9fa7ed..621dc3e1ce 100644 --- a/mindspore/ccsrc/frontend/optimizer/cse.cc +++ b/mindspore/ccsrc/frontend/optimizer/cse.cc @@ -32,19 +32,23 @@ using mindspore::abstract::AbstractBase; using mindspore::abstract::AbstractFunction; using mindspore::abstract::AbstractFunctionPtr; -BasePtr AbsOf(const AnfNodePtr &node) { +BasePtr AbsOf(const AnfNodePtr &node, bool ignore_fg_abs_tracking_id) { MS_EXCEPTION_IF_NULL(node); auto node_abs = node->abstract(); - // in testcase: TestOptOpt.CSE, node->abstract() is null; + // In testcase: TestOptOpt.CSE, node->abstract() is null. if (node_abs == nullptr) { return kAnyValue; } - // Ignore the tracking_id and prim pointer hash; if (node_abs->isa()) { + // Ignore the tracking_id and prim pointer hash. auto prim_abs = node_abs->cast(); return prim_abs->prim(); + } else if (ignore_fg_abs_tracking_id && node_abs->isa()) { + // Ignore the tracking_id. + auto new_fg_abs = node_abs->cast()->Copy(); + new_fg_abs->set_tracking_id(nullptr); + return new_fg_abs; } - return node_abs; } @@ -68,7 +72,7 @@ bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const { ValueNodePtr value_node = node->cast(); auto value = value_node->value(); MS_EXCEPTION_IF_NULL(value); - h = hash_combine(value->hash(), (AbsOf(value_node)->hash())); + h = hash_combine(value->hash(), (AbsOf(value_node, true)->hash())); } else if (node->isa()) { auto cnode = node->cast(); auto &inputs = cnode->inputs(); @@ -134,7 +138,7 @@ bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool chec if (main->isa() && node->isa()) { auto main_value = GetValueNode(main); auto node_value = GetValueNode(node); - return (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value); + return (AbsOf(main, true) == AbsOf(node, true)) && (*main_value == *node_value); } else if (main->isa() && node->isa()) { auto c_main = main->cast(); auto c_node = node->cast(); diff --git a/mindspore/ccsrc/frontend/optimizer/cse.h b/mindspore/ccsrc/frontend/optimizer/cse.h index b35004c593..abfcd635e9 100644 --- a/mindspore/ccsrc/frontend/optimizer/cse.h +++ b/mindspore/ccsrc/frontend/optimizer/cse.h @@ -46,7 +46,7 @@ class CSE { std::unordered_map> *groups) const; }; -BasePtr AbsOf(const AnfNodePtr &node); +BasePtr AbsOf(const AnfNodePtr &node, bool ignore_fg_abs_tracking_id = false); } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/cse_pass.h b/mindspore/ccsrc/frontend/optimizer/cse_pass.h index 86b80e04d3..2e59e8356f 100644 --- a/mindspore/ccsrc/frontend/optimizer/cse_pass.h +++ b/mindspore/ccsrc/frontend/optimizer/cse_pass.h @@ -44,8 +44,6 @@ class CSEPass : public CSE { private: bool report_changes_; }; - -BasePtr AbsOf(const AnfNodePtr &node); } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc index 163f7aeb3d..eb0031a783 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc @@ -467,7 +467,6 @@ EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const AbstractFunctionPtr &func) { } else { MS_LOG(EXCEPTION) << "Cannot GetEvaluator from AbstractFunction"; } - return nullptr; } EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) { diff --git a/mindspore/core/abstract/abstract_function.h b/mindspore/core/abstract/abstract_function.h index 79a1d6c1d7..c0eea99d40 100644 --- a/mindspore/core/abstract/abstract_function.h +++ b/mindspore/core/abstract/abstract_function.h @@ -113,6 +113,8 @@ class FuncGraphAbstractClosure : public AbstractFuncAtom { AnfNodePtr tracking_id() const override { return tracking_id_.lock(); } + void set_tracking_id(AnfNodePtr node) override { tracking_id_ = AnfNodeWeakPtr(node); } + AbstractFunctionPtr Copy() const override { return std::make_shared(func_graph_, context_, tracking_id()); }