|
|
|
@ -50,28 +50,13 @@ std::unique_ptr<ir::Graph> ConvReLUFusePass::ApplyImpl(
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(relu, relu, conv_relu_pattern); // ReLU op
|
|
|
|
|
|
|
|
|
|
// Create an ConvReLU Node.
|
|
|
|
|
OpDesc desc;
|
|
|
|
|
std::string conv_relu_i_in = subgraph.at(conv_input)->Name();
|
|
|
|
|
std::string conv_relu_w_in = conv_weight->Name();
|
|
|
|
|
std::string conv_relu_b_in = conv_bias->Name();
|
|
|
|
|
std::string conv_relu_out = relu_out->Name();
|
|
|
|
|
desc.SetInput("Input", std::vector<std::string>({conv_relu_i_in}));
|
|
|
|
|
desc.SetInput("Filter", std::vector<std::string>({conv_relu_w_in}));
|
|
|
|
|
desc.SetInput("Bias", std::vector<std::string>({conv_relu_b_in}));
|
|
|
|
|
desc.SetOutput("Output", std::vector<std::string>({conv_relu_out}));
|
|
|
|
|
desc.SetType("conv2d");
|
|
|
|
|
for (auto& attr : conv->Op()->GetAttrMap()) {
|
|
|
|
|
desc.SetAttr(attr.first, attr.second);
|
|
|
|
|
}
|
|
|
|
|
desc.SetAttr("fuse_relu", true);
|
|
|
|
|
auto conv_relu_node = g->CreateOpNode(&desc); // OpDesc will be copied.
|
|
|
|
|
GraphSafeRemoveNodes(graph.get(), {conv, relu, conv_out});
|
|
|
|
|
OpDesc* desc = conv->Op();
|
|
|
|
|
desc->SetOutput("Output", std::vector<std::string>({relu_out->Name()}));
|
|
|
|
|
desc->SetAttr("fuse_relu", true);
|
|
|
|
|
GraphSafeRemoveNodes(graph.get(), {relu, conv_out});
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(subgraph.count(conv_input));
|
|
|
|
|
IR_NODE_LINK_TO(subgraph.at(conv_input), conv_relu_node);
|
|
|
|
|
IR_NODE_LINK_TO(conv_weight, conv_relu_node);
|
|
|
|
|
IR_NODE_LINK_TO(conv_bias, conv_relu_node);
|
|
|
|
|
IR_NODE_LINK_TO(conv_relu_node, relu_out);
|
|
|
|
|
IR_NODE_LINK_TO(conv, relu_out);
|
|
|
|
|
|
|
|
|
|
found_conv_relu_count++;
|
|
|
|
|
};
|
|
|
|
|