|
|
|
@ -259,6 +259,15 @@ GraphPatternDetector::DetectPatterns() {
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool GraphItemCMP(const std::pair<PDNode *, Node *> &a,
|
|
|
|
|
const std::pair<PDNode *, Node *> &b) {
|
|
|
|
|
if (a.first != b.first) {
|
|
|
|
|
return a.first < b.first;
|
|
|
|
|
} else {
|
|
|
|
|
return a.second < b.second;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TODO(Superjomn) enhance the function as it marks unique unique as duplicates
|
|
|
|
|
// see https://github.com/PaddlePaddle/Paddle/issues/13550
|
|
|
|
|
void GraphPatternDetector::UniquePatterns(
|
|
|
|
@ -267,12 +276,16 @@ void GraphPatternDetector::UniquePatterns(
|
|
|
|
|
std::vector<GraphPatternDetector::subgraph_t> result;
|
|
|
|
|
|
|
|
|
|
std::unordered_set<size_t> set;
|
|
|
|
|
std::hash<std::string> hasher;
|
|
|
|
|
for (auto &g : *subgraphs) {
|
|
|
|
|
size_t key = 0;
|
|
|
|
|
for (auto &item : g) {
|
|
|
|
|
key ^= std::hash<void *>{}(item.first);
|
|
|
|
|
key ^= std::hash<void *>{}(item.second);
|
|
|
|
|
}
|
|
|
|
|
// Sort the items in the sub-graph, and transform to a string key.
|
|
|
|
|
std::vector<std::pair<PDNode *, Node *>> sorted_keys(g.begin(), g.end());
|
|
|
|
|
std::sort(sorted_keys.begin(), sorted_keys.end(), GraphItemCMP);
|
|
|
|
|
std::stringstream ss;
|
|
|
|
|
for (auto &item : sorted_keys) {
|
|
|
|
|
ss << item.first << ":" << item.second;
|
|
|
|
|
}
|
|
|
|
|
auto key = hasher(ss.str());
|
|
|
|
|
if (!set.count(key)) {
|
|
|
|
|
result.emplace_back(g);
|
|
|
|
|
set.insert(key);
|
|
|
|
|