|
|
|
@ -42,14 +42,13 @@ std::unique_ptr<ir::Graph> ConvReLUFusePass::ApplyImpl(
|
|
|
|
|
Graph* g) {
|
|
|
|
|
VLOG(4) << "handle ConvReLU fuse";
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight,
|
|
|
|
|
conv_relu_pattern); // Filter
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(conv_bias, conv_bias, conv_relu_pattern); // Bias
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_relu_pattern); // tmp
|
|
|
|
|
conv_relu_pattern); // Filter
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_relu_pattern); // tmp
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_relu_pattern); // CONV op
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(relu_out, relu_out, conv_relu_pattern); // Out
|
|
|
|
|
GET_IR_NODE_FROM_SUBGRAPH(relu, relu, conv_relu_pattern); // ReLU op
|
|
|
|
|
|
|
|
|
|
// Create an ConvReLU Node.
|
|
|
|
|
// Transform Conv node into ConvReLU node.
|
|
|
|
|
OpDesc* desc = conv->Op();
|
|
|
|
|
desc->SetOutput("Output", std::vector<std::string>({relu_out->Name()}));
|
|
|
|
|
desc->SetAttr("fuse_relu", true);
|
|
|
|
|