diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/shape_ops_splitter.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/shape_ops_splitter.cc index f669792d12..eb1de63835 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/shape_ops_splitter.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/shape_ops_splitter.cc @@ -16,6 +16,7 @@ #include "backend/optimizer/graph_kernel/shape_ops_splitter.h" #include #include +#include #include #include #include @@ -50,18 +51,24 @@ AnfNodePtr CloneCNode(const AnfNodePtr &anf_node) { } void SplitNode(const AnfNodePtr &node, const FuncGraphManagerPtr &mng) { - auto &users = mng->node_users(); - AnfNodePtrList splitted_nodes; - for (size_t i = 0; i < users[node].size(); ++i) { - splitted_nodes.push_back(CloneCNode(node)); + const auto &index_set = mng->node_users()[node]; + std::map> users_info; + std::for_each(index_set.cbegin(), index_set.cend(), [&users_info](const std::pair &iter) { + users_info[iter.first].push_back(iter.second); + }); + + AnfNodePtrList split_nodes; + for (size_t i = 0; i < users_info.size(); ++i) { + split_nodes.push_back(CloneCNode(node)); } - const auto &index_set = users[node]; int i = 0; - for (auto [user, index] : index_set) { + for (auto [user, indices] : users_info) { auto user_node = user->cast(); MS_EXCEPTION_IF_NULL(user_node); - user_node->set_input(index, splitted_nodes[i]); + for (auto index : indices) { + user_node->set_input(index, split_nodes[i]); + } i++; } } @@ -69,9 +76,11 @@ void SplitNode(const AnfNodePtr &node, const FuncGraphManagerPtr &mng) { bool ShapeOpsSplitter::IsMultiUserShapeOps(const AnfNodePtr &node, const FuncGraphManagerPtr &mng) { auto &users = mng->node_users(); - return users[node].size() > 1 && std::any_of(shape_ops_.begin(), shape_ops_.end(), [&node](const PrimitivePtr &prim) { - return IsPrimitiveCNode(node, prim); - }); + std::set user_set; + std::transform(users[node].cbegin(), users[node].cend(), std::inserter(user_set, user_set.end()), + [](const std::pair &iter) { return iter.first; }); + return user_set.size() > 1 && std::any_of(shape_ops_.begin(), shape_ops_.end(), + [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); }); } bool ShapeOpsSplitter::Process(const FuncGraphPtr &func_graph) {