|
|
|
@ -66,36 +66,23 @@ IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const AnfNode
|
|
|
|
|
return EXCLUDE;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// The GetItem node should be fused with its real input and users.
|
|
|
|
|
// The GetItem node should be fused with its real input.
|
|
|
|
|
// If its real input is not in the fuse_list, the GetItem should be excluded.
|
|
|
|
|
AnfNodePtrList RemoveWildGetitem(const AnfNodePtrList &fused_op) {
|
|
|
|
|
if (fused_op.empty()) return AnfNodePtrList();
|
|
|
|
|
std::set<AnfNodePtr> fused_op_set(fused_op.begin(), fused_op.end());
|
|
|
|
|
auto check_include = [&fused_op_set](const AnfNodePtr &node) { return fused_op_set.count(node) ? FOLLOW : EXCLUDE; };
|
|
|
|
|
|
|
|
|
|
auto mng = fused_op[0]->func_graph()->manager();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(mng);
|
|
|
|
|
bool changed = true;
|
|
|
|
|
while (changed) {
|
|
|
|
|
changed = false;
|
|
|
|
|
AnfNodePtrList remove_list;
|
|
|
|
|
for (auto getitem : fused_op_set) {
|
|
|
|
|
if (!AnfAlgo::CheckPrimitiveType(getitem, prim::kPrimTupleGetItem)) continue;
|
|
|
|
|
|
|
|
|
|
// GetItem should be fused with its real input.
|
|
|
|
|
auto prev_node = getitem->cast<CNodePtr>()->input(kRealInputNodeIndexInTupleGetItem);
|
|
|
|
|
if (check_include(prev_node) == EXCLUDE) {
|
|
|
|
|
remove_list.push_back(getitem);
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// GetItem should be fused with its all users.
|
|
|
|
|
const auto &users = mng->node_users()[getitem];
|
|
|
|
|
if (std::any_of(users.begin(), users.end(), [check_include](const std::pair<AnfNodePtr, int> &user) {
|
|
|
|
|
return check_include(user.first) == EXCLUDE;
|
|
|
|
|
})) {
|
|
|
|
|
remove_list = DeepLinkedGraphSearch(getitem, check_include);
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (!remove_list.empty()) {
|
|
|
|
|