|
|
|
@ -167,10 +167,12 @@ struct HitGroup {
|
|
|
|
|
|
|
|
|
|
bool Match(Node *node, PDNode *pat) {
|
|
|
|
|
if (nodes_.count(node)) {
|
|
|
|
|
if (!roles.count(pat)) return false;
|
|
|
|
|
return roles[pat] == node;
|
|
|
|
|
if (roles.count(pat) && roles[pat] == node) return true;
|
|
|
|
|
return false;
|
|
|
|
|
} else {
|
|
|
|
|
if (roles.count(pat) && roles[pat] != node) return false;
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
return !roles.count(pat) || roles.at(pat) == node;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Register(Node *node, PDNode *pat) {
|
|
|
|
@ -198,7 +200,6 @@ GraphPatternDetector::DetectPatterns() {
|
|
|
|
|
std::vector<GraphPatternDetector::subgraph_t> result;
|
|
|
|
|
std::vector<HitGroup> init_groups;
|
|
|
|
|
std::array<std::vector<HitGroup>, 2> bi_records;
|
|
|
|
|
// PADDLE_ENFORCE(!pattern_.edges().empty(), "At least one edge is needed");
|
|
|
|
|
auto *first_pnode = pattern_.edges().empty() ? pattern().nodes().front().get()
|
|
|
|
|
: pattern_.edges().front().first;
|
|
|
|
|
if (!pdnodes2nodes_.count(first_pnode)) return result;
|
|
|
|
@ -228,11 +229,12 @@ GraphPatternDetector::DetectPatterns() {
|
|
|
|
|
VLOG(80) << "check " << source->id() << " -- " << target->id();
|
|
|
|
|
// TODO(Superjomn) add some prune strategies.
|
|
|
|
|
for (const auto &group : pre_groups) {
|
|
|
|
|
HitGroup new_group = group;
|
|
|
|
|
if (IsNodesLink(source, target) &&
|
|
|
|
|
new_group.Match(source, edge.first)) {
|
|
|
|
|
new_group.Register(source, edge.first);
|
|
|
|
|
if (new_group.Match(target, edge.second)) {
|
|
|
|
|
if (IsNodesLink(source, target)) {
|
|
|
|
|
HitGroup new_group = group;
|
|
|
|
|
bool flag = new_group.Match(source, edge.first) &&
|
|
|
|
|
new_group.Match(target, edge.second);
|
|
|
|
|
if (flag) {
|
|
|
|
|
new_group.Register(source, edge.first);
|
|
|
|
|
new_group.Register(target, edge.second);
|
|
|
|
|
cur_groups.push_back(new_group);
|
|
|
|
|
// TODO(Superjomn) need to unique
|
|
|
|
|