fix load input is depend node when ReplaceUpdateStateForLoad

pull/12410/head
Margaret_wangrui 4 years ago
parent 9f05cc1351
commit 81259d7339

@ -133,7 +133,8 @@ std::vector<std::vector<size_t>> GenerateLoadGroups(const FuncGraphPtr &fg, cons
} }
if (!IsPrimitiveCNode(cnode, prim::kPrimLoad)) { if (!IsPrimitiveCNode(cnode, prim::kPrimLoad)) {
for (const auto &input : cnode->inputs()) { for (const auto &input : cnode->inputs()) {
if (input->isa<Parameter>()) { if (input->isa<Parameter>() ||
(IsPrimitiveCNode(input, prim::kPrimDepend) && input->cast<CNodePtr>()->input(1)->isa<Parameter>())) {
unload_users_record.insert(input); unload_users_record.insert(input);
} }
} }
@ -317,7 +318,10 @@ AnfNodePtr GetFirstMonad(const FuncGraphPtr &fg) {
// To: // To:
// u1 = UpdateState(u, c) // u1 = UpdateState(u, c)
// p1 = Load(para1, u') // u' is first monad in graph or new monad // p1 = Load(para1, u') // u' is first monad in graph or new monad
void ReplaceUpdateStateForLoad(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &need_replace_loads) { bool ReplaceUpdateStateForLoad(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &need_replace_loads) {
if (need_replace_loads.size() == 0) {
return false;
}
constexpr size_t second_input_index = 2; constexpr size_t second_input_index = 2;
auto monad = GetFirstMonad(fg); auto monad = GetFirstMonad(fg);
for (const auto &load_node : need_replace_loads) { for (const auto &load_node : need_replace_loads) {
@ -331,6 +335,7 @@ void ReplaceUpdateStateForLoad(const FuncGraphPtr &fg, const std::vector<AnfNode
auto mgr = fg->manager(); auto mgr = fg->manager();
mgr->SetEdge(load_node, second_input_index, monad); mgr->SetEdge(load_node, second_input_index, monad);
} }
return true;
} }
// Node1{primLoad,X,Y1},...,Node{Node's input != X},...,Node2{primLoad,X,Y2},... => // Node1{primLoad,X,Y1},...,Node{Node's input != X},...,Node2{primLoad,X,Y2},... =>
@ -341,7 +346,10 @@ bool CSE::ReplaceAutoMonadNode(const FuncGraphManagerPtr &manager) const {
std::vector<AnfNodePtr> toposet = TopoSort(fg->get_return()); std::vector<AnfNodePtr> toposet = TopoSort(fg->get_return());
std::vector<AnfNodePtr> need_replace_loads; std::vector<AnfNodePtr> need_replace_loads;
std::vector<std::vector<size_t>> load_groups = GenerateLoadGroups(fg, toposet, &need_replace_loads); std::vector<std::vector<size_t>> load_groups = GenerateLoadGroups(fg, toposet, &need_replace_loads);
ReplaceUpdateStateForLoad(fg, need_replace_loads); const bool update_state_replaced = ReplaceUpdateStateForLoad(fg, need_replace_loads);
if (update_state_replaced) {
changed = true;
}
// split group if there is no-load node between two load nodes. // split group if there is no-load node between two load nodes.
std::vector<std::vector<size_t>> need_merge_loads; std::vector<std::vector<size_t>> need_merge_loads;
for (auto &group : load_groups) { for (auto &group : load_groups) {

Loading…
Cancel
Save