From 07a62ddc08aaaa80f4fe934d9dc8b40870970018 Mon Sep 17 00:00:00 2001
From: Tomasz Patejko <tomasz.patejko@intel.com>
Date: Mon, 17 Sep 2018 04:41:26 +0200
Subject: [PATCH] MKLDNN conv + elementwise_add fusion: inputs in pass
 modified. Support for new conv parameter. UTs corrected

---
 .../conv_elementwise_add_mkldnn_fuse_pass.cc  | 25 ++++++++++---------
 ...elementwise_add_mkldnn_fuse_pass_tester.cc | 15 ++++++-----
 2 files changed, 22 insertions(+), 18 deletions(-)

diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc
index 0e37bf9634..f2ff0bf13b 100644
--- a/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.cc
@@ -73,19 +73,19 @@ struct ElementwiseAdd {
       auto elementwise_add_op = pattern->new_node(op_name())
                                        ->assert_is_op("elementwise_add");
 
-      auto y_var = pattern->new_node(y_name())
+      auto x_var = pattern->new_node(x_name())
                           ->assert_is_op_input(op_name(),
-                                               y_name());
+                                               x_name());
   
       conv_output->assert_is_op_input(op_name(),
-                                      x_name());
+                                      y_name());
 
       auto out_var = pattern->new_node(out_name())
                             ->AsOutput()
                             ->assert_is_op_output(op_name(),
                                                   out_name());
 
-      elementwise_add_op->LinksFrom({y_var, conv_output});
+      elementwise_add_op->LinksFrom({x_var, conv_output});
       elementwise_add_op->LinksTo({out_var});
 
       return out_var;
@@ -139,13 +139,14 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
 
   conv_output->AsIntermediate();
 
-  auto fuse_conv = [](Graph* g, Node* conv_input, Node* conv_filter, Node* y) {
+  auto fuse_conv = [](Graph* g, Node* conv_input, Node* conv_filter, Node* conv_output, Node* elementwise_add_x) {
     OpDesc op_desc;
     op_desc.SetType("conv2d");
 
     op_desc.SetInput("Input", {conv_input->Name()});
     op_desc.SetInput("Filter", {conv_filter->Name()});
-    op_desc.SetOutput("Output", {y->Name()});
+    op_desc.SetInput("ElementwiseParameter", {elementwise_add_x->Name()});
+    op_desc.SetOutput("Output", {conv_output->Name()});
 
     op_desc.SetAttr("use_mkldnn", true);
     op_desc.SetAttr("fuse_eltwise", true);
@@ -154,7 +155,7 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
 
     patterns::LinkNodes(conv_input, fused_conv_op);
     patterns::LinkNodes(conv_filter, fused_conv_op);
-    patterns::LinkNodes(fused_conv_op, y);
+    patterns::LinkNodes(fused_conv_op, conv_output);
   };
 
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) {
@@ -169,14 +170,14 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
 
     auto elementwise_add_op = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
                                                   elementwise_add_pattern.op_name());
-    auto elementwise_add_y = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
-                                                 elementwise_add_pattern.y_name());
+    auto elementwise_add_x = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
+                                                 elementwise_add_pattern.x_name());
     auto elementwise_add_out = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
                                                    elementwise_add_pattern.out_name());
 
-    fuse_conv(g, conv_input, conv_filter, elementwise_add_y);
-    patterns::CorrectGraphEdges(g, elementwise_add_out, elementwise_add_y);
-    GraphSafeRemoveNodes(g, {conv_output, elementwise_add_out, conv_op, elementwise_add_op});
+    fuse_conv(g, conv_input, conv_filter, conv_output, elementwise_add_x);
+    patterns::CorrectGraphEdges(g, elementwise_add_out, conv_output);
+    GraphSafeRemoveNodes(g, {elementwise_add_out, conv_op, elementwise_add_op});
   };
 
   gpd(graph.get(), handler);
diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass_tester.cc
index e60a916b1d..17de916c63 100644
--- a/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass_tester.cc
+++ b/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass_tester.cc
@@ -8,6 +8,9 @@ namespace paddle {
 namespace framework {
 namespace ir {
 
+constexpr int nodes_removed = 3;
+constexpr int nodes_added = 1;
+
 void SetOp(ProgramDesc* prog, const std::string& type,
            const std::vector<std::string>& inputs,
            const std::vector<std::string>& outputs) {
@@ -93,7 +96,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
     }
   
     SetOp(&prog, "conv2d", {"a", "weights"}, {"b"});
-    SetOp(&prog, "elementwise_add", {"b", "c"}, {"d"});
+    SetOp(&prog, "elementwise_add", {"c", "b"}, {"d"});
     SetOp(&prog, "relu", {"d"}, {"e"});
 
     return prog;
@@ -113,7 +116,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
 
   EXPECT_TRUE(is_reachable(graph)("a", "relu"));
 
-  EXPECT_EQ(original_nodes_num - 4 + 1, current_nodes_num);
+  EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added, current_nodes_num);
   // Assert conv_relu op in newly generated graph
   int conv_count = 0;
   int elementwise_add_count = 0;
@@ -143,7 +146,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) {
     }
   
     SetOp(&prog, "conv2d", {"a", "weights"}, {"b"});
-    SetOp(&prog, "elementwise_add", {"b", "c"}, {"d"});
+    SetOp(&prog, "elementwise_add", {"c", "b"}, {"d"});
 
     return prog;
   };
@@ -161,7 +164,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) {
 
   EXPECT_FALSE(is_reachable(graph)("a", "d"));
  
-  EXPECT_EQ(original_nodes_num - 4 + 1, current_nodes_num);
+  EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added, current_nodes_num);
   // Assert conv_relu op in newly generated graph
   int conv_count = 0;
   int elementwise_add_count = 0;
@@ -192,7 +195,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) {
   
     SetOp(&prog, "sigmoid", {"a"}, {"b"});
     SetOp(&prog, "conv2d", {"b", "weights"}, {"c"});
-    SetOp(&prog, "elementwise_add", {"c", "d"}, {"e"});
+    SetOp(&prog, "elementwise_add", {"d", "c"}, {"e"});
     SetOp(&prog, "relu", {"e"}, {"f"});
 
     return prog;
@@ -212,7 +215,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) {
 
   EXPECT_TRUE(is_reachable(graph)("a", "f"));
 
-  EXPECT_EQ(original_nodes_num - 4 + 1, current_nodes_num);
+  EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added, current_nodes_num);
   // Assert conv_relu op in newly generated graph
   int conv_count = 0;
   int elementwise_add_count = 0;