|
|
|
@ -41,6 +41,16 @@ inline size_t GetIndex(const AnfNodePtr &getitem_node) {
|
|
|
|
|
getitem_node->cast<CNodePtr>()->input(kInputNodeOutputIndexInTupleGetItem)->cast<ValueNodePtr>()->value()));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SetIndex(const AnfNodePtr &getitem_node, size_t index) {
|
|
|
|
|
auto getitem = getitem_node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(getitem);
|
|
|
|
|
auto idx_node = NewValueNode(MakeValue<int64_t>(SizeToLong(index)));
|
|
|
|
|
auto abstract = std::make_shared<abstract::AbstractScalar>(std::make_shared<Int64Imm>(index));
|
|
|
|
|
idx_node->set_abstract(abstract);
|
|
|
|
|
idx_node->set_kernel_info(std::make_shared<device::KernelInfo>());
|
|
|
|
|
getitem->set_input(kInputNodeOutputIndexInTupleGetItem, idx_node);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool GetGraphKernelGetitemList(const FuncGraphManagerPtr &mng, const AnfNodePtr &node, AnfNodePtrList *getitem_list,
|
|
|
|
|
bool merge_repeated_getitem = false) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(mng);
|
|
|
|
@ -85,7 +95,69 @@ AnfNodePtrList FindGraphKernelsWithMultiOutput(const FuncGraphPtr &func_graph) {
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/* Merge the get_item nodes that have same index.
|
|
|
|
|
/* Unify the repeated output in a func_graph.
|
|
|
|
|
* %1 = call @graph_kernel(p1, p2)
|
|
|
|
|
* %2 = tuple_getitem(%1, 0)
|
|
|
|
|
* %3 = tuple_getitem(%1, 1)
|
|
|
|
|
* graph_kernel:
|
|
|
|
|
* %1 = TensorAdd(p1, p2)
|
|
|
|
|
* %2 = Reshape(%1)
|
|
|
|
|
* return make_tuple(%2, %2)
|
|
|
|
|
* -->
|
|
|
|
|
* %1 = call @graph_kernel(p1, p2)
|
|
|
|
|
* %2 = tuple_getitem(%1, 0)
|
|
|
|
|
* %3 = tuple_getitem(%1, 0) // changed the index to 0.
|
|
|
|
|
* graph_kernel:
|
|
|
|
|
* %1 = TensorAdd(p1, p2)
|
|
|
|
|
* %2 = Reshape(%1)
|
|
|
|
|
* return make_tuple(%2, %2)
|
|
|
|
|
*/
|
|
|
|
|
class UnifyRepeatedOutput : public Pass {
|
|
|
|
|
public:
|
|
|
|
|
bool Run(const FuncGraphPtr &func_graph) {
|
|
|
|
|
auto mng = func_graph->manager();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(mng);
|
|
|
|
|
auto todos = FindGraphKernelsWithMultiOutput(func_graph);
|
|
|
|
|
bool changed = false;
|
|
|
|
|
for (auto node : todos) {
|
|
|
|
|
if (CheckRepeatedOutput(AnfAlgo::GetCNodeFuncGraphPtr(node))) {
|
|
|
|
|
changed = true;
|
|
|
|
|
AnfNodePtrList getitem_list;
|
|
|
|
|
GetGraphKernelGetitemList(mng, node, &getitem_list, false);
|
|
|
|
|
if (getitem_list.size() != index_map_.size()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "getitem_list.size (" << getitem_list.size() << ") should be equal to index_map.size ("
|
|
|
|
|
<< index_map_.size() << ").";
|
|
|
|
|
}
|
|
|
|
|
for (size_t i = 0; i < index_map_.size(); ++i) {
|
|
|
|
|
if (index_map_[i] != i && getitem_list[i] != nullptr) {
|
|
|
|
|
SetIndex(getitem_list[i], index_map_[i]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return changed;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
bool CheckRepeatedOutput(const FuncGraphPtr &sub_func_graph) {
|
|
|
|
|
// the output should be a MakeTuple.
|
|
|
|
|
auto maketuple = sub_func_graph->output()->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(maketuple);
|
|
|
|
|
AnfNodePtrList outputs(maketuple->inputs().begin() + 1, maketuple->inputs().end());
|
|
|
|
|
index_map_.resize(outputs.size());
|
|
|
|
|
bool found = false;
|
|
|
|
|
for (size_t i = 0; i < outputs.size(); ++i) {
|
|
|
|
|
index_map_[i] = std::find(outputs.begin(), outputs.begin() + i, outputs[i]) - outputs.begin();
|
|
|
|
|
if (index_map_[i] != i) {
|
|
|
|
|
found = true;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return found;
|
|
|
|
|
}
|
|
|
|
|
std::vector<size_t> index_map_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
/* Unify the get_item nodes that have same index.
|
|
|
|
|
* %1 = call @graph_kernel(p1, p2)
|
|
|
|
|
* %2 = tuple_getitem(%1, 0)
|
|
|
|
|
* %3 = tuple_getitem(%1, 0)
|
|
|
|
@ -95,13 +167,13 @@ AnfNodePtrList FindGraphKernelsWithMultiOutput(const FuncGraphPtr &func_graph) {
|
|
|
|
|
* %7 = user_z(%4)
|
|
|
|
|
* --->
|
|
|
|
|
* %1 = call @graph_kernel(p1, p2)
|
|
|
|
|
* %2 = tuple_getitem(%1, 0)
|
|
|
|
|
* %2 = tuple_getitem(%1, 0) // unify the original %2 and %3
|
|
|
|
|
* %3 = tuple_getitem(%1, 1)
|
|
|
|
|
* %4 = user_x(%2)
|
|
|
|
|
* %5 = user_y(%2)
|
|
|
|
|
* %6 = user_z(%3)
|
|
|
|
|
*/
|
|
|
|
|
class MergeRepeatedGetitem : public Pass {
|
|
|
|
|
class UnifyRepeatedGetitem : public Pass {
|
|
|
|
|
public:
|
|
|
|
|
bool Run(const FuncGraphPtr &func_graph) {
|
|
|
|
|
auto mng = func_graph->manager();
|
|
|
|
@ -237,26 +309,24 @@ bool EliminateRedundantOutput::Run(const FuncGraphPtr &func_graph) {
|
|
|
|
|
mng = Manage(func_graph, true);
|
|
|
|
|
func_graph->set_manager(mng);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool changed = std::make_shared<MergeRepeatedGetitem>()->Run(func_graph);
|
|
|
|
|
bool changed = false;
|
|
|
|
|
changed = std::make_shared<UnifyRepeatedGetitem>()->Run(func_graph) || changed;
|
|
|
|
|
changed = std::make_shared<UnifyRepeatedOutput>()->Run(func_graph) || changed;
|
|
|
|
|
changed = std::make_shared<UnifyRepeatedGetitem>()->Run(func_graph) || changed;
|
|
|
|
|
changed = std::make_shared<EliminateGetitemForControlDepend>()->Run(func_graph) || changed;
|
|
|
|
|
changed = Process(func_graph) || changed;
|
|
|
|
|
return changed;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void EliminateRedundantOutput::UpdateGetitemIndex(const CNodePtr &getitem, int64_t offset) {
|
|
|
|
|
void EliminateRedundantOutput::UpdateGetitemIndex(const AnfNodePtr &getitem, size_t offset) {
|
|
|
|
|
if (offset == 0) return;
|
|
|
|
|
MS_EXCEPTION_IF_NULL(getitem);
|
|
|
|
|
int64_t index = SizeToLong(GetIndex(getitem));
|
|
|
|
|
auto index = GetIndex(getitem);
|
|
|
|
|
if (offset > index) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The offset is greater than the original index of GetItem: " << getitem->DebugString();
|
|
|
|
|
}
|
|
|
|
|
index -= offset;
|
|
|
|
|
auto idx_node = NewValueNode(MakeValue<int64_t>(index));
|
|
|
|
|
auto abstract = std::make_shared<abstract::AbstractScalar>(std::make_shared<Int64Imm>(index));
|
|
|
|
|
idx_node->set_abstract(abstract);
|
|
|
|
|
idx_node->set_kernel_info(std::make_shared<device::KernelInfo>());
|
|
|
|
|
getitem->set_input(kInputNodeOutputIndexInTupleGetItem, idx_node);
|
|
|
|
|
SetIndex(getitem, index);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AnfNodePtr EliminateRedundantOutput::ReplaceMakeTuple(const AnfNodePtr &node, const AnfNodePtrList &getitems) {
|
|
|
|
@ -266,14 +336,14 @@ AnfNodePtr EliminateRedundantOutput::ReplaceMakeTuple(const AnfNodePtr &node, co
|
|
|
|
|
MS_EXCEPTION_IF_NULL(old_maketuple);
|
|
|
|
|
AnfNodePtrList new_maketuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
|
|
|
|
|
AbstractBasePtrList abstract_list;
|
|
|
|
|
int64_t offset = 0;
|
|
|
|
|
size_t offset = 0;
|
|
|
|
|
for (size_t i = 0; i < getitems.size(); ++i) {
|
|
|
|
|
if (getitems[i] == nullptr) {
|
|
|
|
|
offset++;
|
|
|
|
|
} else {
|
|
|
|
|
new_maketuple_inputs.push_back(old_maketuple->input(i + 1));
|
|
|
|
|
abstract_list.push_back(old_maketuple->input(i + 1)->abstract());
|
|
|
|
|
UpdateGetitemIndex(getitems[i]->cast<CNodePtr>(), offset);
|
|
|
|
|
UpdateGetitemIndex(getitems[i], offset);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (offset == 0) return nullptr;
|
|
|
|
|