@ -41,6 +41,16 @@ inline size_t GetIndex(const AnfNodePtr &getitem_node) {
// Replaces the index operand of a tuple_getitem node with `index`.
//
// getitem_node: the node to update; it is cast to CNodePtr and its input at
//               kInputNodeOutputIndexInTupleGetItem is overwritten.
// index:        the new tuple index, stored in the graph as an int64 value node.
//
// NOTE(review): the cast result is dereferenced without a null check in the
// lines visible here -- confirm callers only ever pass CNodes.
void SetIndex(const AnfNodePtr &getitem_node, size_t index) {
auto getitem = getitem_node->cast<CNodePtr>();
// The index value node holds a signed 64-bit integer, hence the SizeToLong conversion.
auto idx_node = NewValueNode(MakeValue<int64_t>(SizeToLong(index)));
// NOTE(review): `abstract` is constructed but no visible line attaches it to
// idx_node -- confirm it is set on the value node before the input is replaced.
auto abstract = std::make_shared<abstract::AbstractScalar>(std::make_shared<Int64Imm>(index));
getitem->set_input(kInputNodeOutputIndexInTupleGetItem, idx_node);
bool GetGraphKernelGetitemList(const FuncGraphManagerPtr &mng, const AnfNodePtr &node, AnfNodePtrList *getitem_list,
bool merge_repeated_getitem = false) {
@ -85,7 +95,69 @@ AnfNodePtrList FindGraphKernelsWithMultiOutput(const FuncGraphPtr &func_graph) {
return result;
/* Merge the get_item nodes that have same index.
/* Unify the repeated output in a func_graph.
* %1 = call @graph_kernel(p1, p2)
* %2 = tuple_getitem(%1, 0)
* %3 = tuple_getitem(%1, 1)
* graph_kernel:
* %1 = TensorAdd(p1, p2)
* %2 = Reshape(%1)
* return make_tuple(%2, %2)
* -->
* %1 = call @graph_kernel(p1, p2)
* %2 = tuple_getitem(%1, 0)
* %3 = tuple_getitem(%1, 0) // changed the index to 0.
* graph_kernel:
* %1 = TensorAdd(p1, p2)
* %2 = Reshape(%1)
 * return make_tuple(%2, %2)
 */
// Pass that handles graph kernels whose sub-graph returns the same node at
// more than one position of its output MakeTuple. Every tuple_getitem that
// reads a repeated position is re-pointed (via SetIndex) at the first position
// holding that node. The sub-graph's MakeTuple itself is not modified here.
class UnifyRepeatedOutput : public Pass {
// Processes every node returned by FindGraphKernelsWithMultiOutput.
// Returns true if at least one of them had a repeated output.
bool Run(const FuncGraphPtr &func_graph) {
auto mng = func_graph->manager();
auto todos = FindGraphKernelsWithMultiOutput(func_graph);
bool changed = false;
for (auto node : todos) {
// CheckRepeatedOutput also fills index_map_ for this node as a side effect.
if (CheckRepeatedOutput(AnfAlgo::GetCNodeFuncGraphPtr(node))) {
// `changed` is set as soon as a repeat is detected, before any getitem is rewritten.
changed = true;
AnfNodePtrList getitem_list;
GetGraphKernelGetitemList(mng, node, &getitem_list, false);
// getitem_list and index_map_ are indexed together below, so their sizes must agree.
if (getitem_list.size() != index_map_.size()) {
MS_LOG(EXCEPTION) << "getitem_list.size (" << getitem_list.size() << ") should be equal to index_map.size ("
<< index_map_.size() << ").";
for (size_t i = 0; i < index_map_.size(); ++i) {
// Only rewrite positions that duplicate an earlier one; null entries in
// getitem_list are skipped (presumably outputs with no getitem user -- confirm).
if (index_map_[i] != i && getitem_list[i] != nullptr) {
SetIndex(getitem_list[i], index_map_[i]);
return changed;
// Fills index_map_ for the outputs of `sub_func_graph` and returns true if
// any output position repeats an earlier one.
bool CheckRepeatedOutput(const FuncGraphPtr &sub_func_graph) {
// The output is expected to be a MakeTuple.
// NOTE(review): the cast result is used without a null check in the visible lines.
auto maketuple = sub_func_graph->output()->cast<CNodePtr>();
// Skip input 0 of the MakeTuple; the remaining inputs are the outputs.
AnfNodePtrList outputs(maketuple->inputs().begin() + 1, maketuple->inputs().end());
bool found = false;
// NOTE(review): no visible line sizes index_map_ to outputs.size() before it
// is indexed below -- confirm it is resized for every call.
for (size_t i = 0; i < outputs.size(); ++i) {
// Search only the positions before i: the result is the first earlier
// position holding the same node, or i itself when there is none.
index_map_[i] = std::find(outputs.begin(), outputs.begin() + i, outputs[i]) - outputs.begin();
if (index_map_[i] != i) {
found = true;
return found;
// index_map_[i] is the first output position holding the same node as
// position i. It is overwritten by each CheckRepeatedOutput call.
std::vector<size_t> index_map_;
/* Unify the get_item nodes that have the same index.
* %1 = call @graph_kernel(p1, p2)
* %2 = tuple_getitem(%1, 0)
* %3 = tuple_getitem(%1, 0)
@ -95,13 +167,13 @@ AnfNodePtrList FindGraphKernelsWithMultiOutput(const FuncGraphPtr &func_graph) {
* %7 = user_z(%4)
* --->
* %1 = call @graph_kernel(p1, p2)
* %2 = tuple_getitem(%1, 0)
* %2 = tuple_getitem(%1, 0) // unify the original %2 and %3
* %3 = tuple_getitem(%1, 1)
* %4 = user_x(%2)
* %5 = user_y(%2)
 * %6 = user_z(%3)
 */
class MergeRepeatedGetitem : public Pass {
class UnifyRepeatedGetitem : public Pass {
bool Run(const FuncGraphPtr &func_graph) {
auto mng = func_graph->manager();
@ -237,26 +309,24 @@ bool EliminateRedundantOutput::Run(const FuncGraphPtr &func_graph) {
mng = Manage(func_graph, true);
bool changed = std::make_shared<MergeRepeatedGetitem>()->Run(func_graph);
bool changed = false;
changed = std::make_shared<UnifyRepeatedGetitem>()->Run(func_graph) || changed;
changed = std::make_shared<UnifyRepeatedOutput>()->Run(func_graph) || changed;
changed = std::make_shared<UnifyRepeatedGetitem>()->Run(func_graph) || changed;
changed = std::make_shared<EliminateGetitemForControlDepend>()->Run(func_graph) || changed;
changed = Process(func_graph) || changed;
return changed;
// Shifts the index of a tuple_getitem node down by `offset`.
//
// getitem: the tuple_getitem node whose index is read with GetIndex and rewritten.
// offset:  the amount to subtract from the index; 0 means nothing to do.
//
// Reports through MS_LOG(EXCEPTION) when `offset` exceeds the current index.
//
// NOTE(review): this excerpt interleaves two variants of several lines (a
// CNodePtr/int64_t signature and an AnfNodePtr/size_t one, two declarations of
// `index`, and both the inline value-node construction and the SetIndex call).
// They look like the before/after sides of a change -- confirm which set is live.
void EliminateRedundantOutput::UpdateGetitemIndex(const CNodePtr &getitem, int64_t offset) {
void EliminateRedundantOutput::UpdateGetitemIndex(const AnfNodePtr &getitem, size_t offset) {
// No shift requested: leave the node untouched.
if (offset == 0) return;
int64_t index = SizeToLong(GetIndex(getitem));
auto index = GetIndex(getitem);
// With the unsigned variant this guard must precede the subtraction,
// otherwise `index -= offset` would wrap around instead of going negative.
if (offset > index) {
MS_LOG(EXCEPTION) << "The offset is greater than the original index of GetItem: " << getitem->DebugString();
index -= offset;
auto idx_node = NewValueNode(MakeValue<int64_t>(index));
auto abstract = std::make_shared<abstract::AbstractScalar>(std::make_shared<Int64Imm>(index));
getitem->set_input(kInputNodeOutputIndexInTupleGetItem, idx_node);
// Writes the shifted index back into the getitem node.
SetIndex(getitem, index);
AnfNodePtr EliminateRedundantOutput::ReplaceMakeTuple(const AnfNodePtr &node, const AnfNodePtrList &getitems) {
@ -266,14 +336,14 @@ AnfNodePtr EliminateRedundantOutput::ReplaceMakeTuple(const AnfNodePtr &node, co
AnfNodePtrList new_maketuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
AbstractBasePtrList abstract_list;
int64_t offset = 0;
size_t offset = 0;
for (size_t i = 0; i < getitems.size(); ++i) {
if (getitems[i] == nullptr) {
} else {
new_maketuple_inputs.push_back(old_maketuple->input(i + 1));
abstract_list.push_back(old_maketuple->input(i + 1)->abstract());
UpdateGetitemIndex(getitems[i]->cast<CNodePtr>(), offset);
UpdateGetitemIndex(getitems[i], offset);
if (offset == 0) return nullptr;