|
|
|
@ -71,7 +71,7 @@ AnfNodePtr EliminateUpdateStateOnlyUsedNode(const CNodePtr &update_state, const
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
// Replace UpdateState with the input monad.
|
|
|
|
|
return update_state->inputs().at(kInputIndex);
|
|
|
|
|
return update_state->input(kInputIndex);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Eliminate UpdateState that attaches a pure (no-side-effect) node.
|
|
|
|
@ -100,7 +100,7 @@ AnfNodePtr EliminateUpdateStateForPureNode(const CNodePtr &update_state, const A
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// Remove UpdateState by replace it with its input monad.
|
|
|
|
|
return update_state->inputs().at(kInputIndex);
|
|
|
|
|
return update_state->input(kInputIndex);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Eliminate redundant UpdateState/Depend pair nodes caused by inline.
|
|
|
|
@ -118,7 +118,7 @@ AnfNodePtr EliminateUpdateStateWithDepend(const CNodePtr &update_state, const CN
|
|
|
|
|
// Skip if Depend attach input is not a monad.
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
auto update_monad = update_state->inputs().at(kInputIndex);
|
|
|
|
|
auto update_monad = update_state->input(kInputIndex);
|
|
|
|
|
if (!HasAbstractMonad(update_monad)) {
|
|
|
|
|
// Skip if UpdateState input is not a monad.
|
|
|
|
|
MS_LOG(WARNING) << "Not a monad input: " << update_state->DebugString();
|
|
|
|
@ -139,7 +139,7 @@ AnfNodePtr EliminateUpdateStateWithDepend(const CNodePtr &update_state, const CN
|
|
|
|
|
}
|
|
|
|
|
// Replace Depend with its input.
|
|
|
|
|
if (depend->size() == kMinDependSize) {
|
|
|
|
|
auto depend_input = depend->inputs().at(kInputIndex);
|
|
|
|
|
auto depend_input = depend->input(kInputIndex);
|
|
|
|
|
mgr->Replace(depend, depend_input);
|
|
|
|
|
} else {
|
|
|
|
|
auto inputs = depend->inputs();
|
|
|
|
@ -163,7 +163,7 @@ AnfNodePtr EliminateMakeTupleWithDeadNode(const CNodePtr &update_state, const CN
|
|
|
|
|
if (make_tuple->size() != kMakeTupleSize) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
auto &node = make_tuple->inputs().at(kAttachIndex);
|
|
|
|
|
auto &node = make_tuple->input(kAttachIndex);
|
|
|
|
|
auto node_abs = node->abstract();
|
|
|
|
|
if (node_abs == nullptr || !node_abs->isa<abstract::AbstractError>()) {
|
|
|
|
|
return nullptr;
|
|
|
|
@ -173,7 +173,7 @@ AnfNodePtr EliminateMakeTupleWithDeadNode(const CNodePtr &update_state, const CN
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
// Create a new UpdateState to replace the old one.
|
|
|
|
|
const auto &attach = make_tuple->inputs().at(kInputIndex);
|
|
|
|
|
const auto &attach = make_tuple->input(kInputIndex);
|
|
|
|
|
auto new_update_state = fg->NewCNode({update_state->input(0), update_state->input(1), attach});
|
|
|
|
|
new_update_state->set_abstract(update_state->abstract());
|
|
|
|
|
new_update_state->set_scope(update_state->scope());
|
|
|
|
@ -206,42 +206,47 @@ AnfNodePtr EliminateUpdateStateWithMakeTupleFunc(const CNodePtr &update_state, c
|
|
|
|
|
if (make_tuple->size() != kMakeTupleSize) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
auto &first_input = make_tuple->inputs().at(kInputIndex);
|
|
|
|
|
auto &first_input = make_tuple->input(kInputIndex);
|
|
|
|
|
if (IsValueNode<FuncGraph>(first_input) && OnlyMakeTupleUseFunc(make_tuple, first_input)) {
|
|
|
|
|
return update_state->input(1);
|
|
|
|
|
}
|
|
|
|
|
auto &second_input = make_tuple->inputs().at(kAttachIndex);
|
|
|
|
|
auto &second_input = make_tuple->input(kAttachIndex);
|
|
|
|
|
if (IsValueNode<FuncGraph>(second_input) && OnlyMakeTupleUseFunc(make_tuple, second_input)) {
|
|
|
|
|
return update_state->input(1);
|
|
|
|
|
}
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t GetLoadsFollowLoad(const CNodePtr &load, std::vector<CNodePtr> *loads);
|
|
|
|
|
size_t GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple, std::vector<CNodePtr> *loads);
|
|
|
|
|
size_t GetLoadsFollowLoad(const CNodePtr &load, std::vector<CNodePtr> *update_states, std::vector<CNodePtr> *loads);
|
|
|
|
|
size_t GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple,
|
|
|
|
|
std::vector<CNodePtr> *update_states, std::vector<CNodePtr> *loads);
|
|
|
|
|
|
|
|
|
|
// Search consecutive load nodes from UpdateState node.
|
|
|
|
|
size_t GetLoadsFromUpdateState(const CNodePtr &update_state, std::vector<CNodePtr> *loads) {
|
|
|
|
|
auto &attach = update_state->inputs().at(kAttachIndex);
|
|
|
|
|
size_t GetLoadsFromUpdateState(const CNodePtr &update_state, std::vector<CNodePtr> *update_states,
|
|
|
|
|
std::vector<CNodePtr> *loads) {
|
|
|
|
|
auto &attach = update_state->input(kAttachIndex);
|
|
|
|
|
if (IsPrimitiveCNode(attach, prim::kPrimLoad)) {
|
|
|
|
|
return GetLoadsFollowLoad(attach->cast<CNodePtr>(), loads);
|
|
|
|
|
update_states->emplace_back(update_state);
|
|
|
|
|
return GetLoadsFollowLoad(attach->cast<CNodePtr>(), update_states, loads);
|
|
|
|
|
}
|
|
|
|
|
if (IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) {
|
|
|
|
|
return GetLoadsFollowTuple(update_state, attach->cast<CNodePtr>(), loads);
|
|
|
|
|
update_states->emplace_back(update_state);
|
|
|
|
|
return GetLoadsFollowTuple(update_state, attach->cast<CNodePtr>(), update_states, loads);
|
|
|
|
|
}
|
|
|
|
|
return 0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t GetLoadsFollowLoad(const CNodePtr &load, std::vector<CNodePtr> *loads) {
|
|
|
|
|
loads->push_back(load);
|
|
|
|
|
auto &load_attach = load->inputs().at(kAttachIndex);
|
|
|
|
|
size_t GetLoadsFollowLoad(const CNodePtr &load, std::vector<CNodePtr> *update_states, std::vector<CNodePtr> *loads) {
|
|
|
|
|
loads->emplace_back(load);
|
|
|
|
|
auto &load_attach = load->input(kAttachIndex);
|
|
|
|
|
if (IsPrimitiveCNode(load_attach, prim::kPrimUpdateState)) {
|
|
|
|
|
return GetLoadsFromUpdateState(load_attach->cast<CNodePtr>(), loads) + 1;
|
|
|
|
|
return GetLoadsFromUpdateState(load_attach->cast<CNodePtr>(), update_states, loads) + 1;
|
|
|
|
|
}
|
|
|
|
|
return 1;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple, std::vector<CNodePtr> *loads) {
|
|
|
|
|
size_t GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple,
|
|
|
|
|
std::vector<CNodePtr> *update_states, std::vector<CNodePtr> *loads) {
|
|
|
|
|
if (!OnlyUpdateStateUse(update_state, make_tuple)) {
|
|
|
|
|
// UpdateState should be the only user of
|
|
|
|
|
return 0;
|
|
|
|
@ -256,12 +261,12 @@ size_t GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tu
|
|
|
|
|
// Add load nodes from tuple elements.
|
|
|
|
|
for (size_t i = 1; i < inputs.size(); ++i) {
|
|
|
|
|
auto &element = inputs.at(i);
|
|
|
|
|
loads->push_back(element->cast<CNodePtr>());
|
|
|
|
|
loads->emplace_back(element->cast<CNodePtr>());
|
|
|
|
|
}
|
|
|
|
|
// Follow prev update state if found.
|
|
|
|
|
auto prev_node = update_state->inputs().at(kInputIndex);
|
|
|
|
|
auto prev_node = update_state->input(kInputIndex);
|
|
|
|
|
if (IsPrimitiveCNode(prev_node, prim::kPrimUpdateState)) {
|
|
|
|
|
return GetLoadsFromUpdateState(prev_node->cast<CNodePtr>(), loads) + 1;
|
|
|
|
|
return GetLoadsFromUpdateState(prev_node->cast<CNodePtr>(), update_states, loads) + 1;
|
|
|
|
|
}
|
|
|
|
|
return 1;
|
|
|
|
|
}
|
|
|
|
@ -301,7 +306,8 @@ AnfNodePtr MakeTupleForSameNodes(const FuncGraphPtr &fg, const CNodePtr &old_upd
|
|
|
|
|
// xN = Load(xN, u)
|
|
|
|
|
// t = make_tuple(x1, x2, ... , xN)
|
|
|
|
|
// u1 = UpdateState(u, t)
|
|
|
|
|
AnfNodePtr EliminateUpdateStateForLoads(const CNodePtr &old_update_state, const std::vector<CNodePtr> &loads) {
|
|
|
|
|
AnfNodePtr EliminateUpdateStateForLoads(const CNodePtr &old_update_state, const std::vector<CNodePtr> &update_states,
|
|
|
|
|
const std::vector<CNodePtr> &loads) {
|
|
|
|
|
auto fg = old_update_state->func_graph();
|
|
|
|
|
if (fg == nullptr) {
|
|
|
|
|
return nullptr;
|
|
|
|
@ -315,20 +321,24 @@ AnfNodePtr EliminateUpdateStateForLoads(const CNodePtr &old_update_state, const
|
|
|
|
|
std::set<AnfNodePtr> loaded_para_set;
|
|
|
|
|
make_tuple_inputs.reserve(loads.size() + 1);
|
|
|
|
|
make_tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
|
|
|
|
|
auto input_monad = loads.back()->inputs().at(kAttachIndex);
|
|
|
|
|
auto input_monad = loads.back()->input(kAttachIndex);
|
|
|
|
|
for (auto iter = loads.rbegin(); iter != loads.rend(); ++iter) {
|
|
|
|
|
auto &load = *iter;
|
|
|
|
|
auto result = loaded_para_set.emplace(load->inputs().at(kInputIndex));
|
|
|
|
|
auto result = loaded_para_set.emplace(load->input(kInputIndex));
|
|
|
|
|
const bool is_new_load = result.second;
|
|
|
|
|
if (is_new_load) {
|
|
|
|
|
// Put Load node as a tuple element, if the parameter is not loaded by other Load.
|
|
|
|
|
make_tuple_inputs.emplace_back(load);
|
|
|
|
|
}
|
|
|
|
|
if (load->inputs().at(kAttachIndex) != input_monad) {
|
|
|
|
|
if (load->input(kAttachIndex) != input_monad) {
|
|
|
|
|
// Set all load use same input monad.
|
|
|
|
|
mgr->SetEdge(load, kAttachIndex, input_monad);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
for (auto i = update_states.size() - 1; i > 0; i--) {
|
|
|
|
|
auto &us = update_states[i];
|
|
|
|
|
mgr->Replace(us, us->input(kInputIndex));
|
|
|
|
|
}
|
|
|
|
|
if (make_tuple_inputs.size() == 1) {
|
|
|
|
|
// This should not happen.
|
|
|
|
|
MS_LOG(WARNING) << "No loads for " << old_update_state->DebugString(2);
|
|
|
|
@ -538,7 +548,7 @@ AnfNodePtr UpdatestateEliminater::operator()(const OptimizerPtr &, const AnfNode
|
|
|
|
|
MS_LOG(WARNING) << "UpdatestateEliminater encounter invalid node: " << node->DebugString();
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
auto &attach = update_state_node->inputs().at(kAttachIndex);
|
|
|
|
|
auto &attach = update_state_node->input(kAttachIndex);
|
|
|
|
|
if (IsPrimitiveCNode(attach, prim::kPrimDepend)) {
|
|
|
|
|
return EliminateUpdateStateWithDepend(update_state_node, attach->cast<CNodePtr>());
|
|
|
|
|
}
|
|
|
|
@ -586,9 +596,10 @@ AnfNodePtr UpdatestateEliminater::operator()(const OptimizerPtr &, const AnfNode
|
|
|
|
|
return new_node;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
std::vector<CNodePtr> update_states;
|
|
|
|
|
std::vector<CNodePtr> loads;
|
|
|
|
|
if (GetLoadsFromUpdateState(update_state_node, &loads) > 1 && loads.size() > 1) {
|
|
|
|
|
return EliminateUpdateStateForLoads(update_state_node, loads);
|
|
|
|
|
if (GetLoadsFromUpdateState(update_state_node, &update_states, &loads) > 1 && loads.size() > 1) {
|
|
|
|
|
return EliminateUpdateStateForLoads(update_state_node, update_states, loads);
|
|
|
|
|
}
|
|
|
|
|
// Eliminate UpdateStates that attaches a no-side-effect node.
|
|
|
|
|
if (!attach_is_load && !attach_is_tuple) {
|
|
|
|
|