From daecc79e918249b7f7a7e77cd474d39afb324755 Mon Sep 17 00:00:00 2001 From: Margaret_wangrui Date: Tue, 30 Mar 2021 11:52:22 +0800 Subject: [PATCH] fix load eliminate bug --- .../optimizer/irpass/load_eliminate.cc | 28 +++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/load_eliminate.cc b/mindspore/ccsrc/frontend/optimizer/irpass/load_eliminate.cc index ec962cea25..74e7454f1b 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/load_eliminate.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass/load_eliminate.cc @@ -24,6 +24,18 @@ #include "frontend/operator/ops.h" namespace mindspore::opt::irpass { +// Covert: +// load1 = load(para1, u1) +// u2 = UpdateState(u1, load1) +// ... +// load2 = load(load1, u3) +// u4 = UpdateState(u3, load2) +// To: +// load1 = load(para1, u1) +// u2 = UpdateState(u1, load1) +// ... +// load2 = load(para1, u3) # load1 replaced by para1 +// u4 = UpdateState(u3, load2) AnfNodePtr LoadEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) { auto load_node = dyn_cast(node); if (load_node == nullptr || load_node->inputs().empty()) { @@ -32,8 +44,20 @@ AnfNodePtr LoadEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &no } auto load_cnode = load_node->cast(); constexpr size_t kFirstInputIndex = 1; - if (IsPrimitiveCNode(load_cnode->input(kFirstInputIndex), prim::kPrimLoad)) { - return load_cnode->input(kFirstInputIndex); + constexpr size_t kSecondInputIndex = 2; + auto &input_load = load_cnode->input(kFirstInputIndex); + if (IsPrimitiveCNode(input_load, prim::kPrimLoad)) { + auto load_prim = NewValueNode(prim::kPrimLoad); + auto input_load_cnode = input_load->cast(); + auto replace_input = input_load_cnode->input(kFirstInputIndex); + auto monad = load_cnode->input(kSecondInputIndex); + std::vector new_load_inputs = {load_prim, replace_input, monad}; + auto fg = load_cnode->func_graph(); + MS_EXCEPTION_IF_NULL(fg); + auto new_load = fg->NewCNode(new_load_inputs); + new_load->set_abstract(load_cnode->abstract()); + new_load->set_scope(load_cnode->scope()); + return new_load; } return nullptr; }