|
|
|
@ -18,41 +18,41 @@ struct Pattern : public PatternBase {
|
|
|
|
|
// Accessor: returns the `pattern` member as a PDPattern*.
PDPattern* node_pattern() { return pattern; }
|
|
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
// Builds the name for the node identified by `op_name` by forwarding
// name_scope(), repr(), id() and `op_name` to PDNodeName.
// NOTE(review): the signature appears twice below (brace on its own line,
// then brace on the same line) — looks like old/new diff lines interleaved;
// confirm which one is live before compiling.
std::string node_name(std::string op_name)
|
|
|
|
|
{
|
|
|
|
|
std::string node_name(std::string op_name) {
|
|
|
|
|
return PDNodeName(name_scope(), repr(), id(), op_name);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Looks up a node through node_pattern()->RetrieveNode, using the name that
// node_name(op_name) produces as the key.
// NOTE(review): the signature appears twice below (two brace placements) —
// presumably old/new diff lines interleaved; confirm which one is live.
PDNode* retrieve_node(std::string op_name)
|
|
|
|
|
{
|
|
|
|
|
PDNode* retrieve_node(std::string op_name) {
|
|
|
|
|
return node_pattern()->RetrieveNode(node_name(op_name));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Creates a node through node_pattern()->NewNode, named with the result of
// node_name(op_name), and returns the PDNode* that NewNode yields.
// NOTE(review): the signature appears twice below (two brace placements) —
// presumably old/new diff lines interleaved; confirm which one is live.
PDNode* new_node(std::string op_name)
|
|
|
|
|
{
|
|
|
|
|
PDNode* new_node(std::string op_name) {
|
|
|
|
|
return node_pattern()->NewNode(node_name(op_name));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct Conv {
|
|
|
|
|
std::string conv_name() { return "conv2d"; }
|
|
|
|
|
std::string op_name() { return "conv2d"; }
|
|
|
|
|
std::string input_name() { return "Input"; }
|
|
|
|
|
std::string filter_name() { return "Filter"; }
|
|
|
|
|
std::string output_name() { return "Output"; }
|
|
|
|
|
|
|
|
|
|
std::function<PDNode* ()> operator()(std::shared_ptr<Pattern> pattern) {
|
|
|
|
|
return [&]() -> PDNode* {
|
|
|
|
|
auto conv_op = pattern->new_node(conv_name())
|
|
|
|
|
auto conv_op = pattern->new_node(op_name())
|
|
|
|
|
->assert_is_op("conv2d");
|
|
|
|
|
|
|
|
|
|
auto input_var = pattern->new_node(input_name())
|
|
|
|
|
->assert_is_op_input(conv_name(), input_name());
|
|
|
|
|
->assert_is_op_input(op_name(),
|
|
|
|
|
input_name());
|
|
|
|
|
|
|
|
|
|
auto filter_var = pattern->new_node(filter_name())
|
|
|
|
|
->assert_is_op_input(conv_name(), filter_name());
|
|
|
|
|
->assert_is_op_input(op_name(),
|
|
|
|
|
filter_name());
|
|
|
|
|
|
|
|
|
|
auto output_var = pattern->new_node(output_name())
|
|
|
|
|
->assert_is_op_output(conv_name(), output_name());
|
|
|
|
|
->assert_is_op_output(op_name(),
|
|
|
|
|
output_name());
|
|
|
|
|
|
|
|
|
|
conv_op->LinksFrom({input_var, filter_var});
|
|
|
|
|
conv_op->LinksTo({output_var});
|
|
|
|
@ -63,24 +63,27 @@ struct Conv {
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct ElementwiseAdd {
|
|
|
|
|
std::string elementwise_add_name() { return "elementwise_add"; }
|
|
|
|
|
std::string op_name() { return "elementwise_add"; }
|
|
|
|
|
std::string x_name() { return "X"; }
|
|
|
|
|
std::string y_name() { return "Y"; }
|
|
|
|
|
std::string out_name() { return "Out"; }
|
|
|
|
|
|
|
|
|
|
std::function<PDNode* (PDNode*)> operator()(std::shared_ptr<Pattern> pattern) {
|
|
|
|
|
return [&](PDNode* conv_output) -> PDNode* {
|
|
|
|
|
auto elementwise_add_op = pattern->new_node(elementwise_add_name())
|
|
|
|
|
auto elementwise_add_op = pattern->new_node(op_name())
|
|
|
|
|
->assert_is_op("elementwise_add");
|
|
|
|
|
|
|
|
|
|
auto y_var = pattern->new_node(y_name())
|
|
|
|
|
->assert_is_op_input(elementwise_add_name(), y_name());
|
|
|
|
|
->assert_is_op_input(op_name(),
|
|
|
|
|
y_name());
|
|
|
|
|
|
|
|
|
|
conv_output->assert_is_op_input(elementwise_add_name(), x_name());
|
|
|
|
|
conv_output->assert_is_op_input(op_name(),
|
|
|
|
|
x_name());
|
|
|
|
|
|
|
|
|
|
auto out_var = pattern->new_node(out_name())
|
|
|
|
|
->AsOutput()
|
|
|
|
|
->assert_is_op_output(elementwise_add_name(), out_name());
|
|
|
|
|
->assert_is_op_output(op_name(),
|
|
|
|
|
out_name());
|
|
|
|
|
|
|
|
|
|
elementwise_add_op->LinksFrom({y_var, conv_output});
|
|
|
|
|
elementwise_add_op->LinksTo({out_var});
|
|
|
|
@ -89,11 +92,10 @@ struct ElementwiseAdd {
|
|
|
|
|
};
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace patterns
|
|
|
|
|
|
|
|
|
|
Node* GetNodeFromSubgraph(const GraphPatternDetector::subgraph_t& subgraph,
|
|
|
|
|
std::shared_ptr<patterns::Pattern> pattern, const std::string& op_name)
|
|
|
|
|
{
|
|
|
|
|
std::shared_ptr<patterns::Pattern> pattern,
|
|
|
|
|
const std::string& op_name) {
|
|
|
|
|
PADDLE_ENFORCE(subgraph.count(pattern->retrieve_node(op_name)),
|
|
|
|
|
"Node not found for PDNode %s", pattern->node_name(op_name));
|
|
|
|
|
Node* var = subgraph.at(pattern->retrieve_node(op_name));
|
|
|
|
@ -102,7 +104,10 @@ Node* GetNodeFromSubgraph(const GraphPatternDetector::subgraph_t& subgraph,
|
|
|
|
|
return var;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
using graph_ptr = std::unique_ptr<ir::Graph>;
|
|
|
|
|
// Records a directed edge from `from` to `to` on both endpoints:
// `to` is appended to from->outputs and `from` is appended to to->inputs.
// Neither pointer is checked for null, and no duplicate-edge check is made.
void LinkNodes(Node* from, Node* to) {
|
|
|
|
|
from->outputs.push_back(to);
|
|
|
|
|
to->inputs.push_back(from);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void CorrectGraphEdges(Graph* graph, Node* from, Node* to) {
|
|
|
|
|
for (auto& node : GraphTraits::DFS(*graph)) {
|
|
|
|
@ -112,11 +117,12 @@ void CorrectGraphEdges(Graph* graph, Node* from, Node* to) {
|
|
|
|
|
[from](Node* n) { return n == from; });
|
|
|
|
|
|
|
|
|
|
if (same != std::end(node.inputs)) {
|
|
|
|
|
node.inputs.push_back(to);
|
|
|
|
|
to->outputs.push_back(&node);
|
|
|
|
|
LinkNodes(to, &node);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} // namespace patterns
|
|
|
|
|
using graph_ptr = std::unique_ptr<ir::Graph>;
|
|
|
|
|
|
|
|
|
|
graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
|
|
|
|
|
FusePassBase::Init("conv_elementwise_add_mkldnn_fuse_pass", graph.get());
|
|
|
|
@ -133,11 +139,6 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
|
|
|
|
|
|
|
|
|
|
conv_output->AsIntermediate();
|
|
|
|
|
|
|
|
|
|
auto link_nodes_to = [](Node* a, Node* b) {
|
|
|
|
|
a->outputs.push_back(b);
|
|
|
|
|
b->inputs.push_back(a);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
auto fuse_conv = [&](Graph* g, Node* conv_input, Node* conv_filter, Node* y) {
|
|
|
|
|
OpDesc op_desc;
|
|
|
|
|
op_desc.SetType("conv2d");
|
|
|
|
@ -150,29 +151,31 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
|
|
|
|
|
|
|
|
|
|
auto fused_conv_op = g->CreateOpNode(&op_desc);
|
|
|
|
|
|
|
|
|
|
link_nodes_to(conv_input, fused_conv_op);
|
|
|
|
|
link_nodes_to(conv_filter, fused_conv_op);
|
|
|
|
|
link_nodes_to(fused_conv_op, y);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
auto remove_unused_nodes = [](Graph* g, const std::unordered_set<const Node*>& removed_nodes) {
|
|
|
|
|
GraphSafeRemoveNodes(g, removed_nodes);
|
|
|
|
|
patterns::LinkNodes(conv_input, fused_conv_op);
|
|
|
|
|
patterns::LinkNodes(conv_filter, fused_conv_op);
|
|
|
|
|
patterns::LinkNodes(fused_conv_op, y);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) {
|
|
|
|
|
auto conv_op = GetNodeFromSubgraph(subgraph, pattern_ptr, conv_pattern.conv_name());
|
|
|
|
|
auto conv_input = GetNodeFromSubgraph(subgraph, pattern_ptr, conv_pattern.input_name());
|
|
|
|
|
auto conv_filter = GetNodeFromSubgraph(subgraph, pattern_ptr, conv_pattern.filter_name());
|
|
|
|
|
auto conv_output = GetNodeFromSubgraph(subgraph, pattern_ptr, conv_pattern.output_name());
|
|
|
|
|
|
|
|
|
|
auto elementwise_add_op = GetNodeFromSubgraph(subgraph, pattern_ptr, elementwise_add_pattern.elementwise_add_name());
|
|
|
|
|
auto elementwise_add_y = GetNodeFromSubgraph(subgraph, pattern_ptr, elementwise_add_pattern.y_name());
|
|
|
|
|
auto elementwise_add_out = GetNodeFromSubgraph(subgraph, pattern_ptr, elementwise_add_pattern.out_name());
|
|
|
|
|
auto conv_op = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
|
|
|
|
|
conv_pattern.op_name());
|
|
|
|
|
auto conv_input = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
|
|
|
|
|
conv_pattern.input_name());
|
|
|
|
|
auto conv_filter = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
|
|
|
|
|
conv_pattern.filter_name());
|
|
|
|
|
auto conv_output = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
|
|
|
|
|
conv_pattern.output_name());
|
|
|
|
|
|
|
|
|
|
auto elementwise_add_op = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
|
|
|
|
|
elementwise_add_pattern.op_name());
|
|
|
|
|
auto elementwise_add_y = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
|
|
|
|
|
elementwise_add_pattern.y_name());
|
|
|
|
|
auto elementwise_add_out = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
|
|
|
|
|
elementwise_add_pattern.out_name());
|
|
|
|
|
|
|
|
|
|
fuse_conv(g, conv_input, conv_filter, elementwise_add_y);
|
|
|
|
|
CorrectGraphEdges(g, elementwise_add_out, elementwise_add_y);
|
|
|
|
|
|
|
|
|
|
remove_unused_nodes(g, {conv_output, elementwise_add_out, conv_op, elementwise_add_op});
|
|
|
|
|
patterns::CorrectGraphEdges(g, elementwise_add_out, elementwise_add_y);
|
|
|
|
|
patterns::GraphSafeRemoveNodes(g, {conv_output, elementwise_add_out, conv_op, elementwise_add_op});
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
gpd(graph.get(), handler);
|
|
|
|
|