diff --git a/mindspore/ccsrc/optimizer/opt.cc b/mindspore/ccsrc/optimizer/opt.cc index e6addae76e..5e893cf1aa 100644 --- a/mindspore/ccsrc/optimizer/opt.cc +++ b/mindspore/ccsrc/optimizer/opt.cc @@ -92,18 +92,16 @@ AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNode return result; } -static bool inline isTraversable(const AnfNodePtr &node, const AnfNodeSet &all_nodes) { - if (node->isa() || node->isa()) { +static bool isTraversable(const AnfNodePtr &node) { + if (node == nullptr) { return false; } - + if (node->isa() || node->isa()) { + return true; + } if (IsValueNode(node) || IsValueNode(node)) { - if (!all_nodes.contains(node)) { - return false; - } return true; } - return false; } @@ -126,15 +124,9 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNo todo.pop_front(); // check whether this node has been matched. - if (node == nullptr || node->seen_ == seen) { - continue; - } - - auto fg = node->func_graph(); - if (!(fg != nullptr && fg->manager() != nullptr) && !isTraversable(node, all_nodes)) { + if (node == nullptr || node->seen_ == seen || !isTraversable(node) || !all_nodes.contains(node)) { continue; } - node->seen_ = seen; // select nodes that this transform can be applied. diff --git a/tests/ut/cpp/common/py_func_graph_fetcher.h b/tests/ut/cpp/common/py_func_graph_fetcher.h index 9d374fcd60..98552a96b5 100644 --- a/tests/ut/cpp/common/py_func_graph_fetcher.h +++ b/tests/ut/cpp/common/py_func_graph_fetcher.h @@ -22,7 +22,6 @@ #include "ir/primitive.h" #include "ir/manager.h" #include "ir/func_graph.h" -#include "ir/func_graph_cloner.h" #include "pipeline/parse/parse_base.h" #include "pipeline/parse/parse.h" #include "./common.h" @@ -48,10 +47,9 @@ class PyFuncGraphFetcher { py::function fn = mindspore::parse::python_adapter::CallPyFn(model_path_.c_str(), func_name.c_str(), args...); mindspore::FuncGraphPtr func_graph = mindspore::parse::ParsePythonCode(fn); if (doResolve_) { - std::shared_ptr manager = mindspore::Manage(func_graph, true); + std::shared_ptr manager = mindspore::Manage(func_graph, false); mindspore::parse::python_adapter::set_use_signature_in_resolve(false); mindspore::parse::ResolveAll(manager); - func_graph = BasicClone(func_graph); } return func_graph; } catch (py::error_already_set& e) { @@ -73,9 +71,8 @@ class PyFuncGraphFetcher { py::function fn = mindspore::parse::python_adapter::GetPyFn(path.c_str(), func_name.c_str()); mindspore::FuncGraphPtr func_graph = mindspore::parse::ParsePythonCode(fn); if (doResolve_) { - std::shared_ptr manager = mindspore::Manage(func_graph, true); + std::shared_ptr manager = mindspore::Manage(func_graph, false); mindspore::parse::ResolveAll(manager); - func_graph = BasicClone(func_graph); } return func_graph; } catch (py::error_already_set& e) {