diff --git a/mindspore/ccsrc/frontend/optimizer/py_pass.cc b/mindspore/ccsrc/frontend/optimizer/py_pass.cc index c1bf40fcbb..34c46c6b66 100644 --- a/mindspore/ccsrc/frontend/optimizer/py_pass.cc +++ b/mindspore/ccsrc/frontend/optimizer/py_pass.cc @@ -52,9 +52,10 @@ std::string GetNodeRepr(AnfNodePtr node) { void ResolveFuncGraph_(const FuncGraphPtr &fg) { auto manager = Manage(fg, false); + auto use_sig = parse::python_adapter::UseSignatureInResolve(); parse::python_adapter::set_use_signature_in_resolve(false); parse::ResolveAll(manager); - parse::python_adapter::set_use_signature_in_resolve(true); + parse::python_adapter::set_use_signature_in_resolve(use_sig); } bool Match(const AnfNodePtr &pattern, const AnfNodePtr &node, const NodeEquivPtr &equiv_ptr) { diff --git a/mindspore/ccsrc/pipeline/jit/parse/function_block.cc b/mindspore/ccsrc/pipeline/jit/parse/function_block.cc index b52dddda66..14e9f739d5 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/function_block.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/function_block.cc @@ -145,6 +145,12 @@ AnfNodePtr FunctionBlock::MakeResolve(const NameSpacePtr &name_space, const Symb void FunctionBlock::SetPhiArgument(const ParameterPtr &phi) { std::string var = phi_nodes_[phi]; MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " set phi " << phi->ToString() << " for var " << var; + auto removable = CollectRemovablePhi(phi); + // If the phi node is not necessary, not need to add to jumps_ of the prev blocks. 
+ if (removable) { + MS_LOG(DEBUG) << "remove the phi when call graph " << func_graph_->ToString() << " var " << var; + return; + } for (auto &pred : prev_blocks_) { MS_EXCEPTION_IF_NULL(pred); MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " pred_blocks_ " << pred->func_graph_->ToString(); @@ -152,16 +158,6 @@ void FunctionBlock::SetPhiArgument(const ParameterPtr &phi) { CNodePtr jump = pred->jumps_[this]; jump->add_input(arg_node); } - // If the phi node in the body part of a for/while loop is being removed, - // then the closure convert phase will generate a cycle in graph if the - // loop is kept after specialization. This should be investigate further. - // Just now user has to set a flag on a function to indicate the for loop - // will definitely can be unroll as the sequence in for statement is fixed - // size in compile time. - if (parser_.func_graph()->has_flag(GRAPH_FLAG_LOOP_CAN_UNROLL) || - parser_.func_graph()->has_flag(GRAPH_FLAG_HAS_EFFECT)) { - CollectRemovablePhi(phi); - } } AnfNodePtr FunctionBlock::SearchReplaceNode(const std::string &var, const ParameterPtr &phi) { @@ -207,13 +203,13 @@ AnfNodePtr FunctionBlock::SearchReplaceNode(const std::string &var, const Parame // 2. it's costly to iterate the graph to replace the phi for each phi. // Args : // phi : This parameter node is functioning as a phi node. 
-void FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) { +bool FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) { MS_EXCEPTION_IF_NULL(phi); std::string var = phi_nodes_[phi]; - MS_LOG(DEBUG) << "check phi " << phi->ToString() << " for " << var << " in graph " << func_graph_->ToString(); + MS_LOG(DEBUG) << "check phi " << phi->DebugString() << " for " << var; if (prev_blocks_.size() == 0) { - MS_LOG(DEBUG) << "no phi " << phi->ToString() << " for var " << var << " in graph " << func_graph_->ToString(); - return; + MS_LOG(DEBUG) << "no phi " << phi->DebugString() << " for var " << var; + return false; } AnfNodePtr arg_node = SearchReplaceNode(var, phi); if (arg_node != nullptr) { @@ -235,13 +231,16 @@ void FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) { const auto &param = phi_iter.second->cast<ParameterPtr>(); if (param == phi) { MS_LOG(DEBUG) << "graph " << prev->func_graph_->ToString() << " var " << phi_iter.first->DebugString() - << " can be replaced from " << param->DebugString() << " with " << arg_node->DebugString(); + << " can be replaced from " << param->DebugString() << " with " << arg_node->DebugString() + << " in graph " << arg_node->func_graph()->ToString(); prev->removable_phis_[phi_iter.first] = arg_node; } } } } + return true; } + return false; } // A block should be marked matured if its predecessor blocks have been processed diff --git a/mindspore/ccsrc/pipeline/jit/parse/function_block.h b/mindspore/ccsrc/pipeline/jit/parse/function_block.h index cbf75a3dd8..2331eeca47 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/function_block.h +++ b/mindspore/ccsrc/pipeline/jit/parse/function_block.h @@ -52,7 +52,7 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> { AnfNodePtr ReadVariable(const std::string &var_name); void AddPrevBlock(const FunctionBlockPtr &block); void SetPhiArgument(const ParameterPtr &phi); - void CollectRemovablePhi(const ParameterPtr &phi); + bool CollectRemovablePhi(const ParameterPtr &phi); // A block
is matured if all its predecessors are generated void Mature(); CNodePtr ForceToBoolNode(const AnfNodePtr &cond); diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse.cc b/mindspore/ccsrc/pipeline/jit/parse/parse.cc index edc9a66594..b2e95c5070 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/parse.cc @@ -1436,6 +1436,15 @@ FunctionBlockPtr Parser::ParsePass(const FunctionBlockPtr &block, const py::obje return block; } +AnfNodePtr FindPhis(const std::unordered_map<ParameterPtr, AnfNodePtr> &removable_phis, const AnfNodePtr &node) { + const auto &inp = node->cast<ParameterPtr>(); + const auto &iter = removable_phis.find(inp); + if (iter == removable_phis.end()) { + return node; + } + return FindPhis(removable_phis, iter->second); +} + void Parser::RemoveUnnecessaryPhis() { // merge all removable phis to one map; std::unordered_map<ParameterPtr, AnfNodePtr> removable_phis; for (auto &block : func_block_list_) { MS_EXCEPTION_IF_NULL(block); removable_phis.insert(block->removable_phis().begin(), block->removable_phis().end()); } - if (removable_phis.size() == 0) { return; } - for (auto &node : DeepUsedGraphSearch(func_graph_->get_return())) { - if (node->isa<CNode>()) { - const auto &cnode = node->cast<CNodePtr>(); - auto &inputs = cnode->inputs(); - for (std::size_t i = 0; i < inputs.size(); i++) { - if (inputs[i]->isa<Parameter>()) { - const auto &inp = inputs[i]->cast<ParameterPtr>(); - const auto &iter = removable_phis.find(inp); - if (iter == removable_phis.end()) { - continue; - } - auto &argNode = iter->second; - MS_LOG(DEBUG) << "graph " << cnode->func_graph()->ToString() << " replace phi " << inp->ToString() << " in " - << cnode->DebugString() << " with " << argNode->DebugString(); - cnode->set_input(i, argNode); - } - } + + auto fg_name = func_graph_->ToString(); + auto mng = Manage(func_graph_, false); + // replace the nodes + for (auto iter : removable_phis) { + auto new_node = FindPhis(removable_phis, iter.first); + MS_LOG(DEBUG) << "phi " << iter.first->DebugString() << " to " <<
new_node->DebugString(); + mng->Replace(iter.first, new_node); + } + // remove the parameter + for (FunctionBlockPtr &block : func_block_list_) { + MS_EXCEPTION_IF_NULL(block); + auto &local_removable_phis = block->removable_phis(); + if (local_removable_phis.size() == 0) { + continue; } + auto func_graph = block->func_graph(); + auto &parameters = func_graph->parameters(); + std::vector<AnfNodePtr> new_parameters(parameters.size()); + auto it = std::copy_if( + parameters.begin(), parameters.end(), new_parameters.begin(), [&local_removable_phis](AnfNodePtr param) { + return local_removable_phis.find(param->cast<ParameterPtr>()) == local_removable_phis.end(); + }); + + // shrink container to new size + new_parameters.resize(std::distance(new_parameters.begin(), it)); + func_graph->set_parameters(new_parameters); + } + for (auto fg : mng->func_graphs()) { + fg->ClearAllManagerInfo(); } } diff --git a/mindspore/ccsrc/utils/graph_utils.cc b/mindspore/ccsrc/utils/graph_utils.cc index 03ac14573d..6689719fcc 100644 --- a/mindspore/ccsrc/utils/graph_utils.cc +++ b/mindspore/ccsrc/utils/graph_utils.cc @@ -111,6 +111,27 @@ std::vector<CNodePtr> BroadFirstSearchGraphCNodes(CNodePtr ret) { return sorted_nodes; } +std::vector<FuncGraphPtr> BroadFirstSearchGraphUsed(FuncGraphPtr root) { + std::deque<FuncGraphPtr> todo; + todo.push_back(root); + std::vector<FuncGraphPtr> sorted; + auto seen = NewSeenGeneration(); + while (!todo.empty()) { + FuncGraphPtr top = todo.front(); + todo.pop_front(); + sorted.push_back(top); + auto used = top->func_graphs_used(); + for (auto &item : used) { + if (item.first->seen_ == seen) { + continue; + } + todo.push_back(item.first); + item.first->seen_ = seen; + } + } + return sorted; +} + std::vector<AnfNodePtr> SuccDeeper(const AnfNodePtr &node) { std::vector<AnfNodePtr> vecs; if (node == nullptr) { diff --git a/mindspore/ccsrc/utils/graph_utils.h b/mindspore/ccsrc/utils/graph_utils.h index 2a9240ac84..8eb75f6799 100644 --- a/mindspore/ccsrc/utils/graph_utils.h +++ b/mindspore/ccsrc/utils/graph_utils.h @@ -70,6 +70,7 @@ std::vector<AnfNodePtr> TopoSort(const
AnfNodePtr &root, const SuccFunc &succ = const IncludeFunc &include = AlwaysInclude); std::vector BroadFirstSearchGraphCNodes(CNodePtr ret); +std::vector BroadFirstSearchGraphUsed(FuncGraphPtr root); class FuncGraphIndex { public: explicit FuncGraphIndex(const FuncGraphPtr &fg, const SearchFunc &search = DeepScopedGraphSearch, diff --git a/mindspore/core/ir/anf.cc b/mindspore/core/ir/anf.cc index 0d96ddf263..275bd3b206 100644 --- a/mindspore/core/ir/anf.cc +++ b/mindspore/core/ir/anf.cc @@ -77,6 +77,17 @@ std::string CNode::DebugString(int recursive_level) const { return buffer.str(); } +std::string Parameter::DebugString(int recursive_level) const { + std::ostringstream buffer; + if (recursive_level > 0) { + if (func_graph() != nullptr) { + buffer << func_graph()->ToString() << ":"; + } + } + buffer << ToString(); + return buffer.str(); +} + std::string ValueNode::ToString() const { MS_EXCEPTION_IF_NULL(value_); if (value_->isa()) { diff --git a/mindspore/core/ir/anf.h b/mindspore/core/ir/anf.h index c1a28d57f1..961dcde8a7 100644 --- a/mindspore/core/ir/anf.h +++ b/mindspore/core/ir/anf.h @@ -249,7 +249,7 @@ class Parameter : public ANode { MS_DECLARE_PARENT(Parameter, ANode); void accept(AnfVisitor *v) override; - + std::string DebugString(int recursive_level = 1) const override; std::string name() const { return name_; } void set_name(const std::string &name) { name_ = name; } std::string fullname_with_scope() override { return name(); }; diff --git a/mindspore/core/ir/func_graph.cc b/mindspore/core/ir/func_graph.cc index fabdd3e7d3..570ed61f96 100644 --- a/mindspore/core/ir/func_graph.cc +++ b/mindspore/core/ir/func_graph.cc @@ -417,6 +417,15 @@ std::shared_ptr> FuncGraph::recursive_graphs() { return mng->recursive_graphs(shared_from_base()); } +void FuncGraph::ClearAllManagerInfo() { + ClearNodes(); + ClearValueNodes(); + ClearFuncGraphCNodesIndex(); + ClearFreeVariables(); + ClearFuncGraphsUsed(); + ClearJFuncGraphs(); +} + AnfNodePtr 
FuncGraph::GetDefaultValueByName(const std::string &name) { auto itr = this->parameter_default_value_.find(name); if (itr == parameter_default_value_.end()) { diff --git a/mindspore/core/ir/func_graph.h b/mindspore/core/ir/func_graph.h index 712c75b431..fd7f5d9d48 100644 --- a/mindspore/core/ir/func_graph.h +++ b/mindspore/core/ir/func_graph.h @@ -229,7 +229,8 @@ class FuncGraph : public FuncGraphBase { } this->debug_info_ = info; } - + // clear all info from manager + void ClearAllManagerInfo(); // get all nodes belonging to this func graph const AnfNodeSet &nodes(); void CopyNodes(const FuncGraphPtr &source); diff --git a/mindspore/core/ir/func_graph_cloner.cc b/mindspore/core/ir/func_graph_cloner.cc index 0857770cad..432a924b1e 100644 --- a/mindspore/core/ir/func_graph_cloner.cc +++ b/mindspore/core/ir/func_graph_cloner.cc @@ -25,6 +25,7 @@ #include "utils/log_adapter.h" #include "utils/profile.h" #include "utils/context/ms_context.h" +#include "utils/graph_utils.h" // namespace to support intermediate representation definition namespace mindspore { @@ -400,11 +401,16 @@ void Cloner::LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraph } void Cloner::Lift() { - for (auto &func_graph_params : repl_func_graph_params_) { - auto &func_graph = func_graph_params.first; - auto ¶ms = func_graph_params.second; - for (auto &cnode : func_graph->func_graph_cnodes_index()) { - LiftParameters(cnode.first->first->func_graph(), func_graph, params); + // lift inner graph first + auto sorted = BroadFirstSearchGraphUsed(*(manager_->roots().begin())); + for (auto r_iter = sorted.rbegin(); r_iter != sorted.rend(); ++r_iter) { + auto func_graph = *r_iter; + auto iter = repl_func_graph_params_.find(func_graph); + if (iter != repl_func_graph_params_.end()) { + auto ¶ms = iter->second; + for (auto &cnode : func_graph->func_graph_cnodes_index()) { + LiftParameters(cnode.first->first->func_graph(), func_graph, params); + } } } } diff --git a/mindspore/core/ir/manager.cc 
b/mindspore/core/ir/manager.cc index 00c39679cd..5c996bcdab 100644 --- a/mindspore/core/ir/manager.cc +++ b/mindspore/core/ir/manager.cc @@ -520,12 +520,7 @@ void FuncGraphManager::MoveAllNodes(FuncGraphPtr source, FuncGraphPtr target) { target->CopyFuncGraphsUsed(source); target->CopyJFuncGraphs(source); signals_->InvalidateComputer(); - source->ClearNodes(); - source->ClearValueNodes(); - source->ClearFuncGraphCNodesIndex(); - source->ClearFreeVariables(); - source->ClearFuncGraphsUsed(); - source->ClearJFuncGraphs(); + source->ClearAllManagerInfo(); } FuncGraphTransaction FuncGraphManager::Transact() { diff --git a/tests/ut/cpp/common/py_func_graph_fetcher.h b/tests/ut/cpp/common/py_func_graph_fetcher.h index d864842760..ae9467cef1 100644 --- a/tests/ut/cpp/common/py_func_graph_fetcher.h +++ b/tests/ut/cpp/common/py_func_graph_fetcher.h @@ -72,6 +72,7 @@ class PyFuncGraphFetcher { mindspore::FuncGraphPtr func_graph = mindspore::parse::ParsePythonCode(fn); if (doResolve_) { std::shared_ptr manager = mindspore::Manage(func_graph, false); + mindspore::parse::python_adapter::set_use_signature_in_resolve(false); mindspore::parse::ResolveAll(manager); } return func_graph; diff --git a/tests/ut/python/pynative_mode/test_multigraph_sink.py b/tests/ut/python/pynative_mode/test_multigraph_sink.py index e8ebe03797..c4ef44ef5a 100644 --- a/tests/ut/python/pynative_mode/test_multigraph_sink.py +++ b/tests/ut/python/pynative_mode/test_multigraph_sink.py @@ -131,3 +131,26 @@ def test_while_in_while(): output = while_in_while(c1, c2, c3) expect = Tensor([1274], mstype.int32) assert output == expect + + +@ms_function +def while_by_while_in_while(x, y, z): + out = c4 + while x < c2: + y = c4 + c4 + while y < c2: + y = y + 1 + out = out + y + z = c4 + c4 + while z < c2: + z = z + 1 + out = out + z + x = x + 1 + out = out + x + return out + + +def test_while_by_while_in_while(): + output = while_by_while_in_while(c1, c2, c3) + expect = Tensor([350], mstype.int32) + assert output == 
expect