eliminate redundant split ops

pull/12301/head
tronzhang 4 years ago
parent ca37351927
commit e953705521

@ -16,6 +16,7 @@
#include "backend/optimizer/graph_kernel/shape_ops_splitter.h"
#include <algorithm>
#include <vector>
#include <set>
#include <string>
#include <unordered_set>
#include <utility>
@ -50,18 +51,24 @@ AnfNodePtr CloneCNode(const AnfNodePtr &anf_node) {
}
void SplitNode(const AnfNodePtr &node, const FuncGraphManagerPtr &mng) {
auto &users = mng->node_users();
AnfNodePtrList splitted_nodes;
for (size_t i = 0; i < users[node].size(); ++i) {
splitted_nodes.push_back(CloneCNode(node));
const auto &index_set = mng->node_users()[node];
std::map<AnfNodePtr, std::vector<int>> users_info;
std::for_each(index_set.cbegin(), index_set.cend(), [&users_info](const std::pair<AnfNodePtr, int> &iter) {
users_info[iter.first].push_back(iter.second);
});
AnfNodePtrList split_nodes;
for (size_t i = 0; i < users_info.size(); ++i) {
split_nodes.push_back(CloneCNode(node));
}
const auto &index_set = users[node];
int i = 0;
for (auto [user, index] : index_set) {
for (auto [user, indices] : users_info) {
auto user_node = user->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(user_node);
user_node->set_input(index, splitted_nodes[i]);
for (auto index : indices) {
user_node->set_input(index, split_nodes[i]);
}
i++;
}
}
@ -69,9 +76,11 @@ void SplitNode(const AnfNodePtr &node, const FuncGraphManagerPtr &mng) {
bool ShapeOpsSplitter::IsMultiUserShapeOps(const AnfNodePtr &node, const FuncGraphManagerPtr &mng) {
auto &users = mng->node_users();
return users[node].size() > 1 && std::any_of(shape_ops_.begin(), shape_ops_.end(), [&node](const PrimitivePtr &prim) {
return IsPrimitiveCNode(node, prim);
});
std::set<AnfNodePtr> user_set;
std::transform(users[node].cbegin(), users[node].cend(), std::inserter(user_set, user_set.end()),
[](const std::pair<AnfNodePtr, int> &iter) { return iter.first; });
return user_set.size() > 1 && std::any_of(shape_ops_.begin(), shape_ops_.end(),
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
}
bool ShapeOpsSplitter::Process(const FuncGraphPtr &func_graph) {

Loading…
Cancel
Save