|
|
@ -133,7 +133,8 @@ std::vector<std::vector<size_t>> GenerateLoadGroups(const FuncGraphPtr &fg, cons
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (!IsPrimitiveCNode(cnode, prim::kPrimLoad)) {
|
|
|
|
if (!IsPrimitiveCNode(cnode, prim::kPrimLoad)) {
|
|
|
|
for (const auto &input : cnode->inputs()) {
|
|
|
|
for (const auto &input : cnode->inputs()) {
|
|
|
|
if (input->isa<Parameter>()) {
|
|
|
|
if (input->isa<Parameter>() ||
|
|
|
|
|
|
|
|
(IsPrimitiveCNode(input, prim::kPrimDepend) && input->cast<CNodePtr>()->input(1)->isa<Parameter>())) {
|
|
|
|
unload_users_record.insert(input);
|
|
|
|
unload_users_record.insert(input);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -317,7 +318,10 @@ AnfNodePtr GetFirstMonad(const FuncGraphPtr &fg) {
|
|
|
|
// To:
|
|
|
|
// To:
|
|
|
|
// u1 = UpdateState(u, c)
|
|
|
|
// u1 = UpdateState(u, c)
|
|
|
|
// p1 = Load(para1, u') // u' is first monad in graph or new monad
|
|
|
|
// p1 = Load(para1, u') // u' is first monad in graph or new monad
|
|
|
|
void ReplaceUpdateStateForLoad(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &need_replace_loads) {
|
|
|
|
bool ReplaceUpdateStateForLoad(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &need_replace_loads) {
|
|
|
|
|
|
|
|
if (need_replace_loads.size() == 0) {
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
}
|
|
|
|
constexpr size_t second_input_index = 2;
|
|
|
|
constexpr size_t second_input_index = 2;
|
|
|
|
auto monad = GetFirstMonad(fg);
|
|
|
|
auto monad = GetFirstMonad(fg);
|
|
|
|
for (const auto &load_node : need_replace_loads) {
|
|
|
|
for (const auto &load_node : need_replace_loads) {
|
|
|
@ -331,6 +335,7 @@ void ReplaceUpdateStateForLoad(const FuncGraphPtr &fg, const std::vector<AnfNode
|
|
|
|
auto mgr = fg->manager();
|
|
|
|
auto mgr = fg->manager();
|
|
|
|
mgr->SetEdge(load_node, second_input_index, monad);
|
|
|
|
mgr->SetEdge(load_node, second_input_index, monad);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Node1{primLoad,X,Y1},...,Node{Node's input != X},...,Node2{primLoad,X,Y2},... =>
|
|
|
|
// Node1{primLoad,X,Y1},...,Node{Node's input != X},...,Node2{primLoad,X,Y2},... =>
|
|
|
@ -341,7 +346,10 @@ bool CSE::ReplaceAutoMonadNode(const FuncGraphManagerPtr &manager) const {
|
|
|
|
std::vector<AnfNodePtr> toposet = TopoSort(fg->get_return());
|
|
|
|
std::vector<AnfNodePtr> toposet = TopoSort(fg->get_return());
|
|
|
|
std::vector<AnfNodePtr> need_replace_loads;
|
|
|
|
std::vector<AnfNodePtr> need_replace_loads;
|
|
|
|
std::vector<std::vector<size_t>> load_groups = GenerateLoadGroups(fg, toposet, &need_replace_loads);
|
|
|
|
std::vector<std::vector<size_t>> load_groups = GenerateLoadGroups(fg, toposet, &need_replace_loads);
|
|
|
|
ReplaceUpdateStateForLoad(fg, need_replace_loads);
|
|
|
|
const bool update_state_replaced = ReplaceUpdateStateForLoad(fg, need_replace_loads);
|
|
|
|
|
|
|
|
if (update_state_replaced) {
|
|
|
|
|
|
|
|
changed = true;
|
|
|
|
|
|
|
|
}
|
|
|
|
// split group if there is no-load node between two load nodes.
|
|
|
|
// split group if there is no-load node between two load nodes.
|
|
|
|
std::vector<std::vector<size_t>> need_merge_loads;
|
|
|
|
std::vector<std::vector<size_t>> need_merge_loads;
|
|
|
|
for (auto &group : load_groups) {
|
|
|
|
for (auto &group : load_groups) {
|
|
|
|