|
|
|
@ -14,6 +14,7 @@
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h"
|
|
|
|
|
#include <functional>
|
|
|
|
|
#include <utility>
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/ir/graph_traits.h"
|
|
|
|
|
|
|
|
|
@ -67,11 +68,32 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
|
|
|
|
|
|
|
|
|
|
conv_output->AsIntermediate();
|
|
|
|
|
|
|
|
|
|
auto conv_op_has_bias = [](const Node& conv_op,
|
|
|
|
|
const Scope& scope) -> std::pair<bool, Node*> {
|
|
|
|
|
auto bias_input_names = conv_op.Op()->Inputs();
|
|
|
|
|
auto bias_it = bias_input_names.find("Bias");
|
|
|
|
|
|
|
|
|
|
if (bias_it != std::end(bias_input_names)) {
|
|
|
|
|
bool has_bias = !bias_it->second.empty();
|
|
|
|
|
|
|
|
|
|
if (has_bias) {
|
|
|
|
|
auto conv_bias_names = bias_it->second;
|
|
|
|
|
auto conv_bias_names_it =
|
|
|
|
|
std::find_if(std::begin(conv_op.inputs), std::end(conv_op.inputs),
|
|
|
|
|
[&conv_bias_names](Node* n) -> bool {
|
|
|
|
|
return n->Name() == conv_bias_names[0];
|
|
|
|
|
});
|
|
|
|
|
return std::make_pair(has_bias, *conv_bias_names_it);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return std::make_pair(false, nullptr);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
|
|
|
|
|
Graph* g) {
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(conv_bias, conv_bias, conv_pattern);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
|
|
|
|
@ -81,17 +103,25 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
|
|
|
|
|
elementwise_add_pattern);
|
|
|
|
|
|
|
|
|
|
if (FindFuseOption(conv_op, elementwise_add_op) != FUSE_MKLDNN) return;
|
|
|
|
|
if (FindFuseOption(*conv_op, *elementwise_add_op) != FUSE_MKLDNN) return;
|
|
|
|
|
|
|
|
|
|
OpDesc op_desc;
|
|
|
|
|
op_desc.SetType("conv2d");
|
|
|
|
|
|
|
|
|
|
op_desc.SetInput("Input", {conv_input->Name()});
|
|
|
|
|
op_desc.SetInput("Bias", {conv_bias->Name()});
|
|
|
|
|
op_desc.SetInput("Filter", {conv_filter->Name()});
|
|
|
|
|
op_desc.SetInput("ResidualData", {elementwise_add_x->Name()});
|
|
|
|
|
op_desc.SetOutput("Output", {conv_output->Name()});
|
|
|
|
|
|
|
|
|
|
bool has_bias;
|
|
|
|
|
Node* conv_bias;
|
|
|
|
|
|
|
|
|
|
std::tie(has_bias, conv_bias) = conv_op_has_bias(*conv_op, *param_scope());
|
|
|
|
|
|
|
|
|
|
if (has_bias) {
|
|
|
|
|
op_desc.SetInput("Bias", {conv_bias->Name()});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (const auto& attr : conv_op->Op()->GetAttrMap()) {
|
|
|
|
|
op_desc.SetAttr(attr.first, attr.second);
|
|
|
|
|
}
|
|
|
|
@ -101,11 +131,14 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
|
|
|
|
|
auto fused_conv_op = g->CreateOpNode(&op_desc);
|
|
|
|
|
|
|
|
|
|
IR_NODE_LINK_TO(conv_input, fused_conv_op);
|
|
|
|
|
IR_NODE_LINK_TO(conv_bias, fused_conv_op);
|
|
|
|
|
IR_NODE_LINK_TO(conv_filter, fused_conv_op);
|
|
|
|
|
IR_NODE_LINK_TO(elementwise_add_x, fused_conv_op);
|
|
|
|
|
IR_NODE_LINK_TO(fused_conv_op, conv_output);
|
|
|
|
|
|
|
|
|
|
if (has_bias) {
|
|
|
|
|
IR_NODE_LINK_TO(conv_bias, fused_conv_op);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
CorrectGraphEdges(g, elementwise_add_out, conv_output);
|
|
|
|
|
GraphSafeRemoveNodes(g, {elementwise_add_out, conv_op, elementwise_add_op});
|
|
|
|
|
};
|
|
|
|
|