|
|
|
@ -26,8 +26,6 @@ std::unique_ptr<ir::Graph> ConvReLUFusePass::ApplyImpl(
|
|
|
|
|
PADDLE_ENFORCE(graph.get());
|
|
|
|
|
FusePassBase::Init("conv_relu_mkldnn_fuse", graph.get());
|
|
|
|
|
|
|
|
|
|
std::unordered_set<Node*> nodes2delete;
|
|
|
|
|
|
|
|
|
|
GraphPatternDetector gpd;
|
|
|
|
|
auto* conv_input = gpd.mutable_pattern()
|
|
|
|
|
->NewNode("conv_relu_mkldnn_fuse/conv_input")
|
|
|
|
@ -42,36 +40,20 @@ std::unique_ptr<ir::Graph> ConvReLUFusePass::ApplyImpl(
|
|
|
|
|
Graph* g) {
|
|
|
|
|
VLOG(4) << "handle ConvReLU fuse";
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight,
|
|
|
|
|
conv_relu_pattern); // Filter
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(conv_bias, conv_bias, conv_relu_pattern); // Bias
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_relu_pattern); // tmp
|
|
|
|
|
conv_relu_pattern); // Filter
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_relu_pattern); // tmp
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_relu_pattern); // CONV op
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(relu_out, relu_out, conv_relu_pattern); // Out
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(relu, relu, conv_relu_pattern); // ReLU op
|
|
|
|
|
|
|
|
|
|
// Create an ConvReLU Node.
|
|
|
|
|
OpDesc desc;
|
|
|
|
|
std::string conv_relu_i_in = subgraph.at(conv_input)->Name();
|
|
|
|
|
std::string conv_relu_w_in = conv_weight->Name();
|
|
|
|
|
std::string conv_relu_b_in = conv_bias->Name();
|
|
|
|
|
std::string conv_relu_out = relu_out->Name();
|
|
|
|
|
desc.SetInput("Input", std::vector<std::string>({conv_relu_i_in}));
|
|
|
|
|
desc.SetInput("Filter", std::vector<std::string>({conv_relu_w_in}));
|
|
|
|
|
desc.SetInput("Bias", std::vector<std::string>({conv_relu_b_in}));
|
|
|
|
|
desc.SetOutput("Output", std::vector<std::string>({conv_relu_out}));
|
|
|
|
|
desc.SetType("conv2d");
|
|
|
|
|
for (auto& attr : conv->Op()->GetAttrMap()) {
|
|
|
|
|
desc.SetAttr(attr.first, attr.second);
|
|
|
|
|
}
|
|
|
|
|
desc.SetAttr("fuse_relu", true);
|
|
|
|
|
auto conv_relu_node = g->CreateOpNode(&desc); // OpDesc will be copied.
|
|
|
|
|
GraphSafeRemoveNodes(graph.get(), {conv, relu, conv_out});
|
|
|
|
|
// Transform Conv node into ConvReLU node.
|
|
|
|
|
OpDesc* desc = conv->Op();
|
|
|
|
|
desc->SetOutput("Output", std::vector<std::string>({relu_out->Name()}));
|
|
|
|
|
desc->SetAttr("fuse_relu", true);
|
|
|
|
|
GraphSafeRemoveNodes(graph.get(), {relu, conv_out});
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(subgraph.count(conv_input));
|
|
|
|
|
IR_NODE_LINK_TO(subgraph.at(conv_input), conv_relu_node);
|
|
|
|
|
IR_NODE_LINK_TO(conv_weight, conv_relu_node);
|
|
|
|
|
IR_NODE_LINK_TO(conv_bias, conv_relu_node);
|
|
|
|
|
IR_NODE_LINK_TO(conv_relu_node, relu_out);
|
|
|
|
|
IR_NODE_LINK_TO(conv, relu_out);
|
|
|
|
|
|
|
|
|
|
found_conv_relu_count++;
|
|
|
|
|
};
|
|
|
|
|