|
|
|
@ -16,6 +16,7 @@
|
|
|
|
|
#include "backend/optimizer/graph_kernel/shape_ops_splitter.h"
|
|
|
|
|
#include <algorithm>
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include <set>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <unordered_set>
|
|
|
|
|
#include <utility>
|
|
|
|
@ -50,18 +51,24 @@ AnfNodePtr CloneCNode(const AnfNodePtr &anf_node) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SplitNode(const AnfNodePtr &node, const FuncGraphManagerPtr &mng) {
|
|
|
|
|
auto &users = mng->node_users();
|
|
|
|
|
AnfNodePtrList splitted_nodes;
|
|
|
|
|
for (size_t i = 0; i < users[node].size(); ++i) {
|
|
|
|
|
splitted_nodes.push_back(CloneCNode(node));
|
|
|
|
|
const auto &index_set = mng->node_users()[node];
|
|
|
|
|
std::map<AnfNodePtr, std::vector<int>> users_info;
|
|
|
|
|
std::for_each(index_set.cbegin(), index_set.cend(), [&users_info](const std::pair<AnfNodePtr, int> &iter) {
|
|
|
|
|
users_info[iter.first].push_back(iter.second);
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
AnfNodePtrList split_nodes;
|
|
|
|
|
for (size_t i = 0; i < users_info.size(); ++i) {
|
|
|
|
|
split_nodes.push_back(CloneCNode(node));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const auto &index_set = users[node];
|
|
|
|
|
int i = 0;
|
|
|
|
|
for (auto [user, index] : index_set) {
|
|
|
|
|
for (auto [user, indices] : users_info) {
|
|
|
|
|
auto user_node = user->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(user_node);
|
|
|
|
|
user_node->set_input(index, splitted_nodes[i]);
|
|
|
|
|
for (auto index : indices) {
|
|
|
|
|
user_node->set_input(index, split_nodes[i]);
|
|
|
|
|
}
|
|
|
|
|
i++;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -69,9 +76,11 @@ void SplitNode(const AnfNodePtr &node, const FuncGraphManagerPtr &mng) {
|
|
|
|
|
|
|
|
|
|
bool ShapeOpsSplitter::IsMultiUserShapeOps(const AnfNodePtr &node, const FuncGraphManagerPtr &mng) {
|
|
|
|
|
auto &users = mng->node_users();
|
|
|
|
|
return users[node].size() > 1 && std::any_of(shape_ops_.begin(), shape_ops_.end(), [&node](const PrimitivePtr &prim) {
|
|
|
|
|
return IsPrimitiveCNode(node, prim);
|
|
|
|
|
});
|
|
|
|
|
std::set<AnfNodePtr> user_set;
|
|
|
|
|
std::transform(users[node].cbegin(), users[node].cend(), std::inserter(user_set, user_set.end()),
|
|
|
|
|
[](const std::pair<AnfNodePtr, int> &iter) { return iter.first; });
|
|
|
|
|
return user_set.size() > 1 && std::any_of(shape_ops_.begin(), shape_ops_.end(),
|
|
|
|
|
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool ShapeOpsSplitter::Process(const FuncGraphPtr &func_graph) {
|
|
|
|
|