|
|
@ -27,6 +27,19 @@ namespace ir {
|
|
|
|
|
|
|
|
|
|
|
|
size_t PDPattern::id_ = 0UL;
|
|
|
|
size_t PDPattern::id_ = 0UL;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
PDNode* PDPattern::NewNode(const std::string& name) {
|
|
|
|
|
|
|
|
if (!name.empty()) {
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(node_map_.count(name), 0,
|
|
|
|
|
|
|
|
"PDNode's name should be unique, get duplicate [%s]",
|
|
|
|
|
|
|
|
name);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
nodes_.emplace_back(new PDNode(this, name));
|
|
|
|
|
|
|
|
auto* cur = nodes_.back().get();
|
|
|
|
|
|
|
|
node_map_[name] = cur;
|
|
|
|
|
|
|
|
return cur;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
PDNode* PDPattern::NewNode(PDNode::teller_t&& teller, const std::string& name) {
|
|
|
|
PDNode* PDPattern::NewNode(PDNode::teller_t&& teller, const std::string& name) {
|
|
|
|
if (!name.empty()) {
|
|
|
|
if (!name.empty()) {
|
|
|
|
PADDLE_ENFORCE_EQ(node_map_.count(name), 0,
|
|
|
|
PADDLE_ENFORCE_EQ(node_map_.count(name), 0,
|
|
|
@ -40,7 +53,7 @@ PDNode* PDPattern::NewNode(PDNode::teller_t&& teller, const std::string& name) {
|
|
|
|
return cur;
|
|
|
|
return cur;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
PDNode* PDPattern::RetriveNode(const std::string& id) const {
|
|
|
|
PDNode* PDPattern::RetrieveNode(const std::string& id) const {
|
|
|
|
auto it = node_map_.find(id);
|
|
|
|
auto it = node_map_.find(id);
|
|
|
|
if (it == node_map_.end()) {
|
|
|
|
if (it == node_map_.end()) {
|
|
|
|
return nullptr;
|
|
|
|
return nullptr;
|
|
|
@ -62,7 +75,9 @@ void GraphPatternDetector::operator()(Graph* graph,
|
|
|
|
auto subgraphs = DetectPatterns();
|
|
|
|
auto subgraphs = DetectPatterns();
|
|
|
|
UniquePatterns(&subgraphs);
|
|
|
|
UniquePatterns(&subgraphs);
|
|
|
|
RemoveOverlappedMatch(&subgraphs);
|
|
|
|
RemoveOverlappedMatch(&subgraphs);
|
|
|
|
|
|
|
|
ValidateByNodeRole(&subgraphs);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (subgraphs.empty()) return;
|
|
|
|
LOG(INFO) << "detect " << subgraphs.size() << " subgraph matches the pattern";
|
|
|
|
LOG(INFO) << "detect " << subgraphs.size() << " subgraph matches the pattern";
|
|
|
|
int id = 0;
|
|
|
|
int id = 0;
|
|
|
|
for (auto& g : subgraphs) {
|
|
|
|
for (auto& g : subgraphs) {
|
|
|
@ -83,10 +98,54 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Check to early stop if some PDNode can't find matched Node.
|
|
|
|
|
|
|
|
for (auto& pdnode : pattern_.nodes()) {
|
|
|
|
|
|
|
|
if (!pdnodes2nodes_.count(pdnode.get())) {
|
|
|
|
|
|
|
|
VLOG(4) << pdnode->name() << " can't find matched Node, early stop";
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
VLOG(3) << pdnodes2nodes_.size() << " nodes marked";
|
|
|
|
VLOG(3) << pdnodes2nodes_.size() << " nodes marked";
|
|
|
|
return !pdnodes2nodes_.empty();
|
|
|
|
return !pdnodes2nodes_.empty();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// The intermediate Nodes can only link to the nodes inside the pattern, or this
|
|
|
|
|
|
|
|
// subgraph will be droped.
|
|
|
|
|
|
|
|
void GraphPatternDetector::ValidateByNodeRole(
|
|
|
|
|
|
|
|
std::vector<GraphPatternDetector::subgraph_t>* subgraphs) {
|
|
|
|
|
|
|
|
std::vector<GraphPatternDetector::subgraph_t> result;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
subgraphs->erase(
|
|
|
|
|
|
|
|
std::remove_if(
|
|
|
|
|
|
|
|
subgraphs->begin(), subgraphs->end(),
|
|
|
|
|
|
|
|
[](const GraphPatternDetector::subgraph_t& subgraph) -> bool {
|
|
|
|
|
|
|
|
// Collect the inputs and outputs.
|
|
|
|
|
|
|
|
std::unordered_set<Node*> ios;
|
|
|
|
|
|
|
|
for (auto& item : subgraph) {
|
|
|
|
|
|
|
|
if (!item.first->IsIntermediate()) {
|
|
|
|
|
|
|
|
ios.insert(item.second);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
for (auto& item : subgraph) {
|
|
|
|
|
|
|
|
if (item.first->IsIntermediate()) {
|
|
|
|
|
|
|
|
for (auto* x : item.second->inputs) {
|
|
|
|
|
|
|
|
if (!ios.count(x)) {
|
|
|
|
|
|
|
|
return true;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
for (auto* x : item.second->outputs) {
|
|
|
|
|
|
|
|
if (!ios.count(x)) {
|
|
|
|
|
|
|
|
return true;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
}),
|
|
|
|
|
|
|
|
subgraphs->end());
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
struct HitGroup {
|
|
|
|
struct HitGroup {
|
|
|
|
std::unordered_map<PDNode*, Node*> roles;
|
|
|
|
std::unordered_map<PDNode*, Node*> roles;
|
|
|
|
|
|
|
|
|
|
|
@ -140,6 +199,7 @@ GraphPatternDetector::DetectPatterns() {
|
|
|
|
// in edges of PDNodes.
|
|
|
|
// in edges of PDNodes.
|
|
|
|
for (const auto& edge : pattern_.edges()) {
|
|
|
|
for (const auto& edge : pattern_.edges()) {
|
|
|
|
VLOG(4) << "check " << edge.first->name() << " -> " << edge.second->name();
|
|
|
|
VLOG(4) << "check " << edge.first->name() << " -> " << edge.second->name();
|
|
|
|
|
|
|
|
// TODO(Superjomn) Fix bug here, the groups might be duplicate here.
|
|
|
|
// Each role has two PDNodes, which indicates two roles.
|
|
|
|
// Each role has two PDNodes, which indicates two roles.
|
|
|
|
// Detect two Nodes that can match these two roles and they are connected.
|
|
|
|
// Detect two Nodes that can match these two roles and they are connected.
|
|
|
|
auto& pre_groups = bi_records[step % 2];
|
|
|
|
auto& pre_groups = bi_records[step % 2];
|
|
|
@ -149,6 +209,7 @@ GraphPatternDetector::DetectPatterns() {
|
|
|
|
// source -> target
|
|
|
|
// source -> target
|
|
|
|
for (Node* source : pdnodes2nodes_[edge.first]) {
|
|
|
|
for (Node* source : pdnodes2nodes_[edge.first]) {
|
|
|
|
for (Node* target : pdnodes2nodes_[edge.second]) {
|
|
|
|
for (Node* target : pdnodes2nodes_[edge.second]) {
|
|
|
|
|
|
|
|
VLOG(8) << "check " << source->id() << " -- " << target->id();
|
|
|
|
// TODO(Superjomn) add some prune strategies.
|
|
|
|
// TODO(Superjomn) add some prune strategies.
|
|
|
|
for (const auto& group : pre_groups) {
|
|
|
|
for (const auto& group : pre_groups) {
|
|
|
|
HitGroup new_group = group;
|
|
|
|
HitGroup new_group = group;
|
|
|
@ -165,6 +226,12 @@ GraphPatternDetector::DetectPatterns() {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
VLOG(3) << "step " << step << " get records: " << cur_groups.size();
|
|
|
|
VLOG(3) << "step " << step << " get records: " << cur_groups.size();
|
|
|
|
|
|
|
|
for (auto& group : cur_groups) {
|
|
|
|
|
|
|
|
for (auto& item : group.roles) {
|
|
|
|
|
|
|
|
VLOG(4) << "node " << item.second->id() << " as " << item.first->name();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
VLOG(4) << "=========================================================";
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
for (auto& group : bi_records[step % 2]) {
|
|
|
|
for (auto& group : bi_records[step % 2]) {
|
|
|
@ -260,6 +327,118 @@ PDNode& PDNode::LinksFrom(const std::vector<PDNode*>& others) {
|
|
|
|
return *this;
|
|
|
|
return *this;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
PDNode* PDNode::assert_is_op() {
|
|
|
|
|
|
|
|
asserts_.emplace_back([this](Node* x) { return x && x->IsOp(); });
|
|
|
|
|
|
|
|
return this;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
PDNode* PDNode::assert_is_op(const std::string& op_type) {
|
|
|
|
|
|
|
|
asserts_.emplace_back([this, op_type](Node* x) {
|
|
|
|
|
|
|
|
return x && x->IsOp() && x->Op()->Type() == op_type;
|
|
|
|
|
|
|
|
});
|
|
|
|
|
|
|
|
return this;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
PDNode* PDNode::assert_is_var() {
|
|
|
|
|
|
|
|
asserts_.emplace_back([this](Node* x) { return x && x->IsVar(); });
|
|
|
|
|
|
|
|
return this;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
PDNode* PDNode::assert_var_not_persistable() {
|
|
|
|
|
|
|
|
assert_is_var();
|
|
|
|
|
|
|
|
asserts_.emplace_back([this](Node* x) { return !x->Var()->Persistable(); });
|
|
|
|
|
|
|
|
return this;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
PDNode* PDNode::assert_is_persistable_var() {
|
|
|
|
|
|
|
|
assert_is_var();
|
|
|
|
|
|
|
|
asserts_.emplace_back([=](Node* x) { return x->Var()->Persistable(); });
|
|
|
|
|
|
|
|
return this;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
PDNode* PDNode::assert_is_op_nth_input(const std::string& op_type,
|
|
|
|
|
|
|
|
const std::string& argument, int nth) {
|
|
|
|
|
|
|
|
assert_is_var();
|
|
|
|
|
|
|
|
assert_is_op_input(op_type);
|
|
|
|
|
|
|
|
asserts_.emplace_back([=](Node* x) {
|
|
|
|
|
|
|
|
for (auto* op : x->outputs) {
|
|
|
|
|
|
|
|
if (IsNthInput(x, op, argument, nth)) return true;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
});
|
|
|
|
|
|
|
|
return this;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
PDNode* PDNode::assert_is_op_nth_output(const std::string& op_type,
|
|
|
|
|
|
|
|
const std::string& argument, int nth) {
|
|
|
|
|
|
|
|
assert_is_var();
|
|
|
|
|
|
|
|
asserts_.emplace_back([=](Node* x) {
|
|
|
|
|
|
|
|
for (auto* op : x->inputs) {
|
|
|
|
|
|
|
|
if (IsNthOutput(x, op, argument, nth)) return true;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
});
|
|
|
|
|
|
|
|
return this;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
PDNode* PDNode::assert_is_only_input_of_op(const std::string& op_type) {
|
|
|
|
|
|
|
|
assert_is_var();
|
|
|
|
|
|
|
|
asserts_.emplace_back([=](Node* x) {
|
|
|
|
|
|
|
|
for (auto* op : x->outputs) {
|
|
|
|
|
|
|
|
if (op && op->IsOp() && op->Op() && op->Op()->Type() == op_type &&
|
|
|
|
|
|
|
|
op->inputs.size() == 1) {
|
|
|
|
|
|
|
|
return true;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
});
|
|
|
|
|
|
|
|
return this;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
PDNode* PDNode::assert_is_only_output_of_op(const std::string& op_type) {
|
|
|
|
|
|
|
|
assert_is_var();
|
|
|
|
|
|
|
|
asserts_.emplace_back([=](Node* x) {
|
|
|
|
|
|
|
|
for (auto* op : x->inputs) {
|
|
|
|
|
|
|
|
if (op && op->IsOp() && op->Op() && op->Op()->Type() == op_type &&
|
|
|
|
|
|
|
|
op->outputs.size() == 1) {
|
|
|
|
|
|
|
|
return true;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
});
|
|
|
|
|
|
|
|
return this;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
PDNode* PDNode::assert_is_op_output(const std::string& op_type) {
|
|
|
|
|
|
|
|
assert_is_var();
|
|
|
|
|
|
|
|
asserts_.emplace_back([=](Node* x) {
|
|
|
|
|
|
|
|
for (auto* op : x->inputs) {
|
|
|
|
|
|
|
|
if (op && op->IsOp() && op->Op() && op->Op()->Type() == op_type) {
|
|
|
|
|
|
|
|
return true;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
});
|
|
|
|
|
|
|
|
return this;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
PDNode* PDNode::assert_is_op_input(const std::string& op_type) {
|
|
|
|
|
|
|
|
assert_is_var();
|
|
|
|
|
|
|
|
asserts_.emplace_back([=](Node* x) {
|
|
|
|
|
|
|
|
for (auto* op : x->outputs) {
|
|
|
|
|
|
|
|
if (op && op->IsOp() && op->Op() && op->Op()->Type() == op_type) {
|
|
|
|
|
|
|
|
return true;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
});
|
|
|
|
|
|
|
|
return this;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
PDNode* PDNode::assert_op_has_n_inputs(const std::string& op_type, size_t n) {
|
|
|
|
|
|
|
|
assert_is_op(op_type);
|
|
|
|
|
|
|
|
asserts_.emplace_back([=](Node* x) { return x->inputs.size() == n; });
|
|
|
|
|
|
|
|
return this;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
PDNode* PDNode::assert_op_has_n_outputs(const std::string& op_type, size_t n) {
|
|
|
|
|
|
|
|
assert_is_op(op_type);
|
|
|
|
|
|
|
|
asserts_.emplace_back([=](Node* x) { return x->outputs.size() == n; });
|
|
|
|
|
|
|
|
return this;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
PDNode* PDNode::assert_more(PDNode::teller_t&& teller) {
|
|
|
|
|
|
|
|
asserts_.emplace_back(std::move(teller));
|
|
|
|
|
|
|
|
return this;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
} // namespace ir
|
|
|
|
} // namespace ir
|
|
|
|
} // namespace framework
|
|
|
|
} // namespace framework
|
|
|
|
} // namespace paddle
|
|
|
|
} // namespace paddle
|
|
|
|