| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				 | 
				
					@ -50,28 +50,13 @@ std::unique_ptr<ir::Graph> ConvReLUFusePass::ApplyImpl(
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    GET_IR_NODE_FROM_SUBGRAPH(relu, relu, conv_relu_pattern);  // ReLU op
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    // Create an ConvReLU Node.
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    OpDesc desc;
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    std::string conv_relu_i_in = subgraph.at(conv_input)->Name();
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    std::string conv_relu_w_in = conv_weight->Name();
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    std::string conv_relu_b_in = conv_bias->Name();
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    std::string conv_relu_out = relu_out->Name();
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    desc.SetInput("Input", std::vector<std::string>({conv_relu_i_in}));
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    desc.SetInput("Filter", std::vector<std::string>({conv_relu_w_in}));
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    desc.SetInput("Bias", std::vector<std::string>({conv_relu_b_in}));
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    desc.SetOutput("Output", std::vector<std::string>({conv_relu_out}));
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    desc.SetType("conv2d");
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    for (auto& attr : conv->Op()->GetAttrMap()) {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					      desc.SetAttr(attr.first, attr.second);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    }
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    desc.SetAttr("fuse_relu", true);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    auto conv_relu_node = g->CreateOpNode(&desc);  // OpDesc will be copied.
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    GraphSafeRemoveNodes(graph.get(), {conv, relu, conv_out});
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    OpDesc* desc = conv->Op();
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    desc->SetOutput("Output", std::vector<std::string>({relu_out->Name()}));
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    desc->SetAttr("fuse_relu", true);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    GraphSafeRemoveNodes(graph.get(), {relu, conv_out});
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    PADDLE_ENFORCE(subgraph.count(conv_input));
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    IR_NODE_LINK_TO(subgraph.at(conv_input), conv_relu_node);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    IR_NODE_LINK_TO(conv_weight, conv_relu_node);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    IR_NODE_LINK_TO(conv_bias, conv_relu_node);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    IR_NODE_LINK_TO(conv_relu_node, relu_out);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    IR_NODE_LINK_TO(conv, relu_out);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    found_conv_relu_count++;
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  };
 | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
				
				 | 
				 | 
				
					
 
 |