From 6ebe9877bb2d187b24b31e0ded7c3c63930a57dd Mon Sep 17 00:00:00 2001
From: Michal Gallus <michal.gallus@intel.com>
Date: Mon, 25 Feb 2019 10:23:24 +0100
Subject: [PATCH] Improve code reuse at MKL-DNN sum

test=develop
---
 .../fluid/operators/mkldnn/sum_mkldnn_op.cc   | 112 +-----------------
 1 file changed, 4 insertions(+), 108 deletions(-)

diff --git a/paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc
index fe4131df2c..6f64157b64 100644
--- a/paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc
@@ -79,15 +79,6 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 
       memory::format input_format = input0.format();
 
-      if (src_tz.size() == 1 && (input_format == memory::format::nchw ||
-                                 input_format == memory::format::nhwc)) {
-        input_format = memory::format::x;
-      }
-      if (src_tz.size() == 2 && (input_format == memory::format::nchw ||
-                                 input_format == memory::format::nhwc)) {
-        input_format = memory::format::nc;
-      }
-
       for (int i = 0; i < N; i++) {
         PADDLE_ENFORCE(in_vars[i]->IsType<LoDTensor>(),
                        "all inputs must be all LoDTensors");
@@ -147,105 +138,10 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 
       output->set_layout(DataLayout::kMKLDNN);
       output->set_format(output_format);
-    } else if (out_var->IsType<framework::SelectedRows>()) {
-      // TODO(@mozga-intel) Add MKLDNN SelectedRows support
-      std::unique_ptr<framework::SelectedRows> in0;
-      if (in_place) {
-        // If is in_place, we store the input[0] to in0
-        auto& in_sel0 = in_vars[0]->Get<SelectedRows>();
-        auto& rows = in_sel0.rows();
-        in0.reset(new framework::SelectedRows(rows, in_sel0.height()));
-        in0->mutable_value()->ShareDataWith(in_sel0.value());
-      }
-
-      auto get_selected_row = [&](size_t i) -> const SelectedRows& {
-        if (i == 0 && in0) {
-          return *in0;
-        } else {
-          return in_vars[i]->Get<SelectedRows>();
-        }
-      };
-      auto* out = ctx.Output<SelectedRows>("Out");
-      out->mutable_rows()->clear();
-      auto* out_value = out->mutable_value();
-
-      // Runtime InferShape
-      size_t first_dim = 0;
-      for (int i = 0; i < N; i++) {
-        auto& sel_row = get_selected_row(i);
-        first_dim += sel_row.rows().size();
-      }
-
-      std::vector<int64_t> in_dim;
-      for (int i = 0; i < N; i++) {
-        auto& sel_row = get_selected_row(i);
-        if (sel_row.rows().size() > 0) {
-          in_dim = framework::vectorize(sel_row.value().dims());
-          break;
-        }
-      }
-
-      if (in_dim.empty()) {
-        VLOG(3) << "WARNING: all the inputs are empty";
-        in_dim = framework::vectorize(get_selected_row(N - 1).value().dims());
-      } else {
-        in_dim[0] = static_cast<int64_t>(first_dim);
-      }
-
-      in_dim[0] = static_cast<int64_t>(first_dim);
-
-      out_value->Resize(framework::make_ddim(in_dim));
-
-      out_value->mutable_data<T>(ctx.GetPlace());
-
-      // if all the input sparse vars are empty, no need to
-      // merge these vars.
-      if (first_dim == 0UL) {
-        return;
-      }
-
-      math::SelectedRowsAddTo<CPUDeviceContext, T> functor;
-      int64_t offset = 0;
-      for (int i = 0; i < N; i++) {
-        auto& sel_row = get_selected_row(i);
-        if (sel_row.rows().size() == 0) {
-          continue;
-        }
-        PADDLE_ENFORCE_EQ(out->height(), sel_row.height());
-        functor(ctx.template device_context<CPUDeviceContext>(), sel_row,
-                offset, out);
-        offset += sel_row.value().numel();
-      }
-    } else if (out_var->IsType<framework::LoDTensorArray>()) {
-      // TODO(@mozga-intel) Add MKLDNN LoDTensorArray support
-      auto& out_array = *out_var->GetMutable<framework::LoDTensorArray>();
-      for (size_t i = in_place ? 1 : 0; i < in_vars.size(); ++i) {
-        PADDLE_ENFORCE(in_vars[i]->IsType<framework::LoDTensorArray>(),
-                       "Only support all inputs are TensorArray");
-        auto& in_array = in_vars[i]->Get<framework::LoDTensorArray>();
-
-        for (size_t i = 0; i < in_array.size(); ++i) {
-          if (in_array[i].numel() != 0) {
-            if (i >= out_array.size()) {
-              out_array.resize(i + 1);
-            }
-            if (out_array[i].numel() == 0) {
-              framework::TensorCopy(in_array[i], in_array[i].place(),
-                                    ctx.device_context(), &out_array[i]);
-              out_array[i].set_lod(in_array[i].lod());
-            } else {
-              PADDLE_ENFORCE(out_array[i].lod() == in_array[i].lod());
-              auto in = EigenVector<T>::Flatten(in_array[i]);
-              auto result = EigenVector<T>::Flatten(out_array[i]);
-              result.device(*ctx.template device_context<MKLDNNDeviceContext>()
-                                 .eigen_device()) = result + in;
-            }
-          }
-        }
-      }
-    } else {
-      PADDLE_THROW("Unexpected branch, output variable type is %s",
-                   framework::ToTypeName(out_var->Type()));
+    } else {  // Fallback to naive version
+      // TODO(@mozga-intel) Add MKLDNN SelectedRows & LoDTensorArray support
+      SumKernel<CPUDeviceContext, T> reference_kernel;
+      reference_kernel.Compute(ctx);
     }
   }
 };