|
|
|
@ -19,6 +19,7 @@
|
|
|
|
|
#include "paddle/fluid/framework/ir/graph_helper.h"
|
|
|
|
|
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
|
|
|
|
|
#include "paddle/fluid/framework/ir/graph_traits.h"
|
|
|
|
|
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
|
|
|
|
|
#include "paddle/fluid/platform/enforce.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
@ -71,7 +72,11 @@ void PDPattern::AddEdge(PDNode* a, PDNode* b) {
|
|
|
|
|
|
|
|
|
|
void GraphPatternDetector::operator()(Graph* graph,
|
|
|
|
|
GraphPatternDetector::handle_t handler) {
|
|
|
|
|
if (!MarkPDNodesInGraph(*graph)) return;
|
|
|
|
|
if (!MarkPDNodesInGraph(*graph)) {
|
|
|
|
|
LOG(INFO) << "Mark failed";
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto subgraphs = DetectPatterns();
|
|
|
|
|
UniquePatterns(&subgraphs);
|
|
|
|
|
RemoveOverlappedMatch(&subgraphs);
|
|
|
|
@ -87,7 +92,7 @@ void GraphPatternDetector::operator()(Graph* graph,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) {
|
|
|
|
|
VLOG(4) << "mark pdnodes in graph";
|
|
|
|
|
VLOG(3) << "mark pdnodes in graph";
|
|
|
|
|
if (graph.Nodes().empty()) return false;
|
|
|
|
|
|
|
|
|
|
for (auto& node : GraphTraits::DFS(graph)) {
|
|
|
|
@ -107,6 +112,7 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
VLOG(3) << pdnodes2nodes_.size() << " nodes marked";
|
|
|
|
|
|
|
|
|
|
return !pdnodes2nodes_.empty();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -357,7 +363,9 @@ PDNode* PDNode::assert_is_op_nth_input(const std::string& op_type,
|
|
|
|
|
assert_is_op_input(op_type);
|
|
|
|
|
asserts_.emplace_back([=](Node* x) {
|
|
|
|
|
for (auto* op : x->outputs) {
|
|
|
|
|
if (IsNthInput(x, op, argument, nth)) return true;
|
|
|
|
|
if (op->IsOp() && op->Op()->Type() == op_type &&
|
|
|
|
|
IsNthInput(x, op, argument, nth))
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
return false;
|
|
|
|
|
});
|
|
|
|
@ -368,7 +376,9 @@ PDNode* PDNode::assert_is_op_nth_output(const std::string& op_type,
|
|
|
|
|
assert_is_var();
|
|
|
|
|
asserts_.emplace_back([=](Node* x) {
|
|
|
|
|
for (auto* op : x->inputs) {
|
|
|
|
|
if (IsNthOutput(x, op, argument, nth)) return true;
|
|
|
|
|
if (op->IsOp() && op->Op()->Type() == op_type &&
|
|
|
|
|
IsNthOutput(x, op, argument, nth))
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
return false;
|
|
|
|
|
});
|
|
|
|
@ -412,6 +422,12 @@ PDNode* PDNode::assert_is_op_output(const std::string& op_type) {
|
|
|
|
|
});
|
|
|
|
|
return this;
|
|
|
|
|
}
|
|
|
|
|
// Assert this node is a variable produced as the first (nth == 0) element of
// the `argument` output slot of an `op_type` operator. Returns `this` so
// asserts can be chained fluently.
PDNode* PDNode::assert_is_op_output(const std::string& op_type,
                                    const std::string& argument) {
  assert_is_var();
  assert_is_op_nth_output(op_type, argument, 0);
  return this;
}
|
|
|
|
|
PDNode* PDNode::assert_is_op_input(const std::string& op_type) {
|
|
|
|
|
assert_is_var();
|
|
|
|
|
asserts_.emplace_back([=](Node* x) {
|
|
|
|
@ -424,6 +440,12 @@ PDNode* PDNode::assert_is_op_input(const std::string& op_type) {
|
|
|
|
|
});
|
|
|
|
|
return this;
|
|
|
|
|
}
|
|
|
|
|
// Assert this node is a variable consumed as the first (nth == 0) element of
// the `argument` input slot of an `op_type` operator. Returns `this` so
// asserts can be chained fluently.
PDNode* PDNode::assert_is_op_input(const std::string& op_type,
                                   const std::string& argument) {
  assert_is_var();
  assert_is_op_nth_input(op_type, argument, 0);
  return this;
}
|
|
|
|
|
PDNode* PDNode::assert_op_has_n_inputs(const std::string& op_type, size_t n) {
|
|
|
|
|
assert_is_op(op_type);
|
|
|
|
|
asserts_.emplace_back([=](Node* x) { return x->inputs.size() == n; });
|
|
|
|
@ -439,6 +461,128 @@ PDNode* PDNode::assert_more(PDNode::teller_t&& teller) {
|
|
|
|
|
return this;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// True when `node` (a variable) feeds at least one operator of `op_type`,
// i.e. one of its consumers is an op node with a matching type.
bool VarLinksToOp(Node* node, const std::string& op_type) {
  for (auto* consumer : node->outputs) {
    if (!consumer->IsOp()) continue;
    if (consumer->Op()->Type() == op_type) return true;
  }
  return false;
}
|
|
|
|
|
// True when `var` is exactly the nth name listed in `op`'s `argument` input
// slot. Both nodes must already be of the right kind: `var` a variable node
// and `op` an operator node.
bool IsNthInput(Node* var, Node* op, const std::string& argument, size_t nth) {
  PADDLE_ENFORCE(var->IsVar());
  PADDLE_ENFORCE(op->IsOp());
  const auto& slot = op->Op()->Input(argument);
  return nth < slot.size() && slot[nth] == var->Name();
}
|
|
|
|
|
// True when `var` is exactly the nth name listed in `op`'s `argument` output
// slot. Both nodes must already be of the right kind: `var` a variable node
// and `op` an operator node.
bool IsNthOutput(Node* var, Node* op, const std::string& argument, size_t nth) {
  PADDLE_ENFORCE(var->IsVar());
  PADDLE_ENFORCE(op->IsOp());
  const auto& slot = op->Op()->Output(argument);
  return nth < slot.size() && slot[nth] == var->Name();
}
|
|
|
|
|
// Remove `nodes` from `graph`, then scrub every surviving node's adjacency
// lists so no dangling pointer to a removed node remains.
void GraphSafeRemoveNodes(Graph* graph,
                          const std::unordered_set<const Node*>& nodes) {
  // Step 1: release the doomed nodes from the graph's ownership.
  for (auto* victim : nodes) {
    graph->RemoveNode(const_cast<Node*>(victim));
  }

  // Step 2: erase every link that still points at a removed node.
  auto scrub = [&nodes](std::vector<Node*>* links) {
    for (auto it = links->begin(); it != links->end();) {
      if (nodes.count(*it)) {
        it = links->erase(it);
      } else {
        ++it;
      }
    }
  };
  for (auto* node : graph->Nodes()) {
    scrub(&const_cast<Node*>(node)->inputs);
    scrub(&const_cast<Node*>(node)->outputs);
  }
}
|
|
|
|
|
// True when `node` (a variable) is produced by at least one operator of
// `op_type`, i.e. one of its producers is an op node with a matching type.
bool VarLinksFromOp(Node* node, const std::string& op_type) {
  for (auto* producer : node->inputs) {
    if (!producer->IsOp()) continue;
    if (producer->Op()->Type() == op_type) return true;
  }
  return false;
}
|
|
|
|
|
|
|
|
|
|
// Build the fully-connected detection pattern: mul(x, w) optionally followed
// by elementwise_add(mul_out, bias). Returns the PDNode of the final output
// so callers can chain further patterns onto it.
// NOTE: node-creation order is kept identical to the pattern's registration
// order on purpose.
PDNode* patterns::FC(PDPattern* pattern, const std::string& name_scope,
                     PDNode* x, bool with_bias) {
  // Operator nodes: a mandatory mul, plus an elementwise_add when biased.
  PDNode* add_op{nullptr};
  auto* mul = pattern->NewNode(name_scope, "mul")->assert_is_op("mul");
  if (with_bias) {
    add_op = pattern->NewNode(name_scope, "elementwise_add")
                 ->assert_is_op("elementwise_add");
  }
  // Variable nodes.
  // The weight: a persistable var feeding mul's Y slot.
  auto* weight = pattern->NewNode(name_scope, "w")
                     ->AsInput()
                     ->assert_is_persistable_var()
                     ->assert_is_op_nth_input("mul", "Y", 0);
  PDNode* mul_result{nullptr};
  if (with_bias) {
    // intermediate variable, will be removed in the IR after fuse.
    mul_result = pattern->NewNode(name_scope, "mul_out")
                     ->AsIntermediate()
                     ->assert_is_only_output_of_op("mul")
                     ->assert_is_op_input("elementwise_add");
  }
  PDNode *bias_var{nullptr}, *out{nullptr};
  if (with_bias) {
    // The bias fed into elementwise_add.
    bias_var = pattern->NewNode(name_scope, "fc_bias")
                   ->assert_is_op_input("elementwise_add")
                   ->AsInput();
    // The FC result comes out of the add.
    out = pattern->NewNode(name_scope, "fc_out")
              ->AsOutput()
              ->assert_is_op_output("elementwise_add");
  } else {
    // No bias: the FC result is the mul output itself.
    out = pattern->NewNode(name_scope, "fc_out")
              ->AsOutput()
              ->assert_is_op_output("mul");
  }

  // Wire up the topology.
  if (with_bias) {
    mul->LinksFrom({weight, x}).LinksTo({mul_result});
    add_op->LinksFrom({mul_result, bias_var}).LinksTo({out});
  } else {
    mul->LinksFrom({weight, x}).LinksTo({out});
  }

  return out;
}
|
|
|
|
|
// Build the detection pattern for a single "lstm" op consuming `x` as its
// Input. Registers the op node, its Weight/Bias inputs, and its four outputs
// on `pattern`, links them, and returns the Hidden-output PDNode so callers
// can chain further patterns onto it.
PDNode* patterns::LSTM(PDPattern* pattern, const std::string& name_scope,
                       PDNode* x) {
  x->assert_is_op_input("lstm", "Input");
  auto* lstm_op = pattern->NewNode(name_scope, "lstm")->assert_is_op("lstm");
// Declares a pattern var node bound to the lstm argument of the same name;
// io__ selects assert_is_op_input vs assert_is_op_output.
#define NEW_NODE(arg__, io__)                        \
  auto* arg__ = pattern->NewNode(name_scope, #arg__) \
                    ->assert_is_op_##io__("lstm", #arg__);

  // Currently, the H0 and C0 are optional
  // TODO(Superjomn) upgrade the fuse framework to support optional.
  // NEW_NODE(H0, input);
  // NEW_NODE(C0, input);
  NEW_NODE(Weight, input);
  NEW_NODE(Bias, input);

  NEW_NODE(Hidden, output);
  NEW_NODE(Cell, output);
  NEW_NODE(BatchGate, output);
  NEW_NODE(BatchCellPreAct, output);
// Keep the helper macro local to this function; without the #undef it would
// leak into the rest of the translation unit and collide with any later
// macro of the same name.
#undef NEW_NODE

  lstm_op->LinksFrom({x, Weight, Bias});
  lstm_op->LinksTo({Hidden, Cell, BatchGate, BatchCellPreAct});
  return Hidden;
}
|
|
|
|
|
} // namespace ir
|
|
|
|
|
} // namespace framework
|
|
|
|
|
} // namespace paddle
|
|
|
|
|