From 81259d733969adff39317ddb81fbe14646ab2d2b Mon Sep 17 00:00:00 2001 From: Margaret_wangrui Date: Fri, 19 Feb 2021 03:14:00 +0800 Subject: [PATCH] fix load input is depend node when ReplaceUpdateStateForLoad --- mindspore/ccsrc/frontend/optimizer/cse.cc | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/mindspore/ccsrc/frontend/optimizer/cse.cc b/mindspore/ccsrc/frontend/optimizer/cse.cc index d227386212..c8305219db 100644 --- a/mindspore/ccsrc/frontend/optimizer/cse.cc +++ b/mindspore/ccsrc/frontend/optimizer/cse.cc @@ -133,7 +133,8 @@ std::vector> GenerateLoadGroups(const FuncGraphPtr &fg, cons } if (!IsPrimitiveCNode(cnode, prim::kPrimLoad)) { for (const auto &input : cnode->inputs()) { - if (input->isa()) { + if (input->isa() || + (IsPrimitiveCNode(input, prim::kPrimDepend) && input->cast()->input(1)->isa())) { unload_users_record.insert(input); } } @@ -317,7 +318,10 @@ AnfNodePtr GetFirstMonad(const FuncGraphPtr &fg) { // To: // u1 = UpdateState(u, c) // p1 = Load(para1, u') // u' is first monad in graph or new monad -void ReplaceUpdateStateForLoad(const FuncGraphPtr &fg, const std::vector &need_replace_loads) { +bool ReplaceUpdateStateForLoad(const FuncGraphPtr &fg, const std::vector &need_replace_loads) { + if (need_replace_loads.size() == 0) { + return false; + } constexpr size_t second_input_index = 2; auto monad = GetFirstMonad(fg); for (const auto &load_node : need_replace_loads) { @@ -331,6 +335,7 @@ void ReplaceUpdateStateForLoad(const FuncGraphPtr &fg, const std::vectormanager(); mgr->SetEdge(load_node, second_input_index, monad); } + return true; } // Node1{primLoad,X,Y1},...,Node{Node's input != X},...,Node2{primLoad,X,Y2},... => @@ -341,7 +346,10 @@ bool CSE::ReplaceAutoMonadNode(const FuncGraphManagerPtr &manager) const { std::vector toposet = TopoSort(fg->get_return()); std::vector need_replace_loads; std::vector> load_groups = GenerateLoadGroups(fg, toposet, &need_replace_loads); - ReplaceUpdateStateForLoad(fg, need_replace_loads); + const bool update_state_replaced = ReplaceUpdateStateForLoad(fg, need_replace_loads); + if (update_state_replaced) { + changed = true; + } // split group if there is no-load node between two load nodes. std::vector> need_merge_loads; for (auto &group : load_groups) {