From 595a2c83aee55cc7483abc9e82b5853c58931a99 Mon Sep 17 00:00:00 2001
From: dzhwinter <dongzhihong01@baidu.com>
Date: Wed, 1 Aug 2018 14:48:37 +0800
Subject: [PATCH] explicit gradient of elementwise_add/elementwise_sub (#11970)

* "add gradient register"

* "make some enhance"

* "better format"

* "fix typo"

* "fix reuse"

* "fix get expected kernel"

* "change the mkldnn code"

* "fix mkldnn"

* "fix mkldnn failed test"

* "add comment"
---
 paddle/fluid/framework/op_proto_maker.cc      |  34 ++++
 paddle/fluid/framework/op_proto_maker.h       |   2 +
 paddle/fluid/framework/op_proto_maker_test.cc | 107 +++++++++-
 .../operators/elementwise_add_mkldnn_op.cc    |  47 ++---
 paddle/fluid/operators/elementwise_add_op.cc  |   4 +-
 paddle/fluid/operators/elementwise_add_op.h   |  16 +-
 paddle/fluid/operators/elementwise_div_op.cc  |   2 +
 paddle/fluid/operators/elementwise_op.h       |  74 ++++++-
 .../fluid/operators/elementwise_op_function.h | 186 ++++++++++++------
 paddle/fluid/operators/elementwise_sub_op.cc  |   5 +-
 paddle/fluid/operators/elementwise_sub_op.h   |  11 +-
 paddle/fluid/operators/softmax_op.cc          |  27 ++-
 .../unittests/test_elementwise_sub_op.py      |   4 +-
 13 files changed, 405 insertions(+), 114 deletions(-)

diff --git a/paddle/fluid/framework/op_proto_maker.cc b/paddle/fluid/framework/op_proto_maker.cc
index 001b5cb5a8..2288c7fe66 100644
--- a/paddle/fluid/framework/op_proto_maker.cc
+++ b/paddle/fluid/framework/op_proto_maker.cc
@@ -40,6 +40,40 @@ OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddOutput(
   return OpProtoAndCheckerMaker::VariableBuilder{output};
 }
 
+void OpProtoAndCheckerMaker::Reuse(const std::string& name,
+                                   const std::string& reused_name) {
+  bool found = false;
+  proto::OpProto::Var* var;
+
+  for (auto& var : proto_->inputs()) {
+    if (var.name() == reused_name) {
+      found = true;
+      break;
+    }
+  }
+  PADDLE_ENFORCE(found == true,
+                 "Input/Output name: %s reused_name: %s, one of them is not "
+                 "exists or not matched.",
+                 name, reused_name);
+
+  found = false;
+  for (int i = 0; i < proto_->outputs().size(); ++i) {
+    var = proto_->mutable_outputs()->Mutable(i);
+    if (var->name() == name) {
+      PADDLE_ENFORCE(!var->has_reuse(),
+                     "Output(%s) has been set reused var of %s", name,
+                     var->reuse());
+      found = true;
+      var->set_reuse(reused_name);
+      break;
+    }
+  }
+  PADDLE_ENFORCE(found == true,
+                 "Input/Output name: %s reused_name: %s, one of them is not "
+                 "exists or not matched.",
+                 name, reused_name);
+}
+
 void OpProtoAndCheckerMaker::CheckNoDuplicatedInOutAttrs() {
   std::unordered_set<std::string> names;
   auto checker = [&](const std::string& name) {
diff --git a/paddle/fluid/framework/op_proto_maker.h b/paddle/fluid/framework/op_proto_maker.h
index 92f86bb5de..80970291c9 100644
--- a/paddle/fluid/framework/op_proto_maker.h
+++ b/paddle/fluid/framework/op_proto_maker.h
@@ -78,6 +78,8 @@ class OpProtoAndCheckerMaker {
   VariableBuilder AddOutput(const std::string &name,
                             const std::string &comment);
 
+  void Reuse(const std::string &name, const std::string &reused_name);
+
   template <typename T>
   TypedAttrChecker<T> &AddAttr(const std::string &name,
                                const std::string &comment,
diff --git a/paddle/fluid/framework/op_proto_maker_test.cc b/paddle/fluid/framework/op_proto_maker_test.cc
index 58f70cb39c..b71c7b6468 100644
--- a/paddle/fluid/framework/op_proto_maker_test.cc
+++ b/paddle/fluid/framework/op_proto_maker_test.cc
@@ -49,6 +49,15 @@ TEST(ProtoMaker, DuplicatedInOut) {
 }
 
 class TestInplaceProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() {
+    AddInput("X", "input of test op");
+    AddOutput("XOut", "output of test op").Reuse("X");
+  }
+};
+
+class TestInplaceProtoMaker2
+    : public paddle::framework::OpProtoAndCheckerMaker {
  public:
   void Make() {
     AddInput("X", "input of test op");
@@ -58,12 +67,100 @@ class TestInplaceProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
 };
 
 TEST(ProtoMaker, InplaceOutput) {
-  paddle::framework::proto::OpProto op_proto;
+  paddle::framework::proto::OpProto op_proto, op_proto2;
   paddle::framework::OpAttrChecker op_checker;
   TestInplaceProtoMaker proto_maker;
-  ASSERT_THROW(proto_maker(&op_proto, &op_checker),
+  TestInplaceProtoMaker2 proto_maker2;
+
+  proto_maker(&op_proto, &op_checker);
+
+  ASSERT_THROW(proto_maker2(&op_proto2, &op_checker),
                paddle::platform::EnforceNotMet);
-  // proto_maker(&op_proto, &op_checker);
-  // proto_maker.Make();
-  // ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
 }
+
+// normal reuse
+class TestReuseProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() {
+    AddInput("X", "input of test op");
+    AddInput("Y", "input of test op");
+    AddOutput("Out", "output of test op");
+    AddOutput("XOut", "output of test op");
+    // avoid destructor exception.
+    // Validate();
+    TestReuse();
+  }
+
+  virtual void TestReuse() {}
+};
+
+// test duplicate reuse error
+class TestReuseProtoMaker2 : public TestReuseProtoMaker {
+ public:
+  void TestReuse() {
+    Reuse("Out", "X");
+    Reuse("Out", "Y");
+  }
+};
+
+// NotExists Input
+class TestReuseProtoMaker3 : public TestReuseProtoMaker {
+ public:
+  void TestReuse() {
+    Reuse("Out", "NotExists");
+    Reuse("XOut", "X");
+  }
+};
+
+// NotExists Output
+class TestReuseProtoMaker4 : public TestReuseProtoMaker {
+ public:
+  void TestReuse() { Reuse("NotExists", "X"); }
+};
+
+TEST(ProtoMaker, Reuse) {
+  paddle::framework::proto::OpProto op_proto;
+  paddle::framework::OpAttrChecker op_checker;
+  TestReuseProtoMaker proto_maker;
+  proto_maker(&op_proto, &op_checker);
+}
+
+// NOTE(dzhwinter):
+// There is a Fatal CHECK on base class destructor, which will call abort inside
+// instead of
+// throw an exception. If we throw an exception in Make(), we will trigger the
+// CHECK and terminate the tests.
+//
+// I had tried to replace the default CHECK with a exception, however, it's
+// still not supported by glog.
+// the details:
+// https://github.com/google/glog/issues/249
+// https://github.com/facebookresearch/TensorComprehensions/issues/351
+/*
+TEST(ProtoMaker, ReuseWithException) {
+  paddle::framework::proto::OpProto op_proto2, op_proto3, op_proto4;
+  paddle::framework::OpAttrChecker op_checker;
+  TestReuseProtoMaker2 proto_maker2;
+  TestReuseProtoMaker3 proto_maker3;
+  TestReuseProtoMaker4 proto_maker4;
+  EXPECT_THROW(proto_maker2(&op_proto2, &op_checker),
+               paddle::platform::EnforceNotMet);
+
+  EXPECT_THROW(proto_maker3(&op_proto3, &op_checker),
+               paddle::platform::EnforceNotMet);
+
+  EXPECT_THROW(proto_maker4(&op_proto4, &op_checker),
+               paddle::platform::EnforceNotMet);
+}
+
+void FailureFunction() {
+  throw std::runtime_error("Check failed in destructor.");
+  // return 0;
+}
+
+int main(int argc, char** argv) {
+  testing::InitGoogleTest(&argc, argv);
+  google::InstallFailureFunction(&FailureFunction);
+  return RUN_ALL_TESTS();
+}
+*/
diff --git a/paddle/fluid/operators/elementwise_add_mkldnn_op.cc b/paddle/fluid/operators/elementwise_add_mkldnn_op.cc
index 1a5427b392..c86cd57316 100644
--- a/paddle/fluid/operators/elementwise_add_mkldnn_op.cc
+++ b/paddle/fluid/operators/elementwise_add_mkldnn_op.cc
@@ -47,12 +47,12 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
     int axis = ctx.Attr<int>("axis");
 
     auto x_dims = x->dims();
-    auto y_dims = y->dims();
+    auto y_dims_untrimed = y->dims();
     auto z_dims = z->dims();
 
     // Execute default elementwise_add operator when
     // broadcast operations need to performed.
-    if (x_dims != y_dims) {
+    if (x_dims != y_dims_untrimed) {
       auto sum_func = [](T a, T b) -> T { return a + b; };
 
       TransformFunctor<decltype(sum_func), T,
@@ -62,11 +62,11 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
               ctx.template device_context<paddle::platform::CPUDeviceContext>(),
               sum_func);
 
-      axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
+      axis = (axis == -1 ? x_dims.size() - y_dims_untrimed.size() : axis);
       PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
                      "Axis should be in range [0, x_dims)");
 
-      trim_trailing_singular_dims(&y_dims);
+      auto y_dims = trim_trailing_singular_dims(y_dims_untrimed);
       axis = (y_dims.size() == 0) ? x_dims.size() : axis;
 
       int pre, n, post;
@@ -88,7 +88,7 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
                      "Wrong layout/format set for Y tensor");
 
       std::vector<int> src_x_tz = framework::vectorize2int(x_dims);
-      std::vector<int> src_y_tz = framework::vectorize2int(y_dims);
+      std::vector<int> src_y_tz = framework::vectorize2int(y_dims_untrimed);
       std::vector<int> dst_tz = framework::vectorize2int(z_dims);
 
       std::vector<memory::primitive_desc> srcs_pd;
@@ -142,36 +142,39 @@ class EltwiseAddMKLDNNGradKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     using Tensor = framework::Tensor;
 
-    auto* x = ctx.Input<Tensor>("X");
-    auto* y = ctx.Input<Tensor>("Y");
-    auto* out = ctx.Input<Tensor>("Out");
     auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
     auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
     int axis = ctx.Attr<int>("axis");
+    // skip out, x, y,
+    // dout length is larger or equal than dx, dy.
+    auto* out = dout;
+    auto *x = dout, *y = dout;
 
     auto set_mkldnn_format = [](Tensor* in, const Tensor* out) {
       in->set_layout(DataLayout::kMKLDNN);
       in->set_format(out->format());
     };
 
-    if (x->dims() == y->dims()) {
-      auto blas = math::GetBlas<paddle::platform::CPUDeviceContext, T>(ctx);
-      if (dx) {
-        blas.VCOPY(dout->numel(), dout->data<T>(),
-                   dx->mutable_data<T>(ctx.GetPlace()));
-        set_mkldnn_format(dx, dout);
-      }
-
-      if (dy) {
-        blas.VCOPY(dout->numel(), dout->data<T>(),
-                   dy->mutable_data<T>(ctx.GetPlace()));
-        set_mkldnn_format(dy, dout);
+    if (dx != nullptr && dy != nullptr && dx->dims() == dy->dims()) {
+      if (dx->dims() == dy->dims()) {
+        auto blas = math::GetBlas<paddle::platform::CPUDeviceContext, T>(ctx);
+        if (dx) {
+          blas.VCOPY(dout->numel(), dout->data<T>(),
+                     dx->mutable_data<T>(ctx.GetPlace()));
+          set_mkldnn_format(dx, dout);
+        }
+
+        if (dy) {
+          blas.VCOPY(dout->numel(), dout->data<T>(),
+                     dy->mutable_data<T>(ctx.GetPlace()));
+          set_mkldnn_format(dy, dout);
+        }
       }
     } else {
       // Execute default kernel when broadcast is needed
-      ElemwiseGradCompute<paddle::platform::CPUDeviceContext, T,
-                          IdentityGrad<T>, IdentityGrad<T>>(
+      ElemwiseExplicitGradCompute<paddle::platform::CPUDeviceContext, T,
+                                  IdentityGrad<T>, IdentityGrad<T>>(
           ctx, *x, *y, *out, *dout, axis, dx, dy, IdentityGrad<T>(),
           IdentityGrad<T>());
     }
diff --git a/paddle/fluid/operators/elementwise_add_op.cc b/paddle/fluid/operators/elementwise_add_op.cc
index d2c2053713..3c97ac995c 100644
--- a/paddle/fluid/operators/elementwise_add_op.cc
+++ b/paddle/fluid/operators/elementwise_add_op.cc
@@ -15,7 +15,9 @@ limitations under the License. */
 #include "paddle/fluid/operators/elementwise_add_op.h"
 #include "paddle/fluid/operators/elementwise_op.h"
 namespace ops = paddle::operators;
-REGISTER_ELEMWISE_OP(elementwise_add, "Add", "Out = X + Y");
+REGISTER_ELEMWISE_GRAD_MAKER(elementwise_add, Add);
+REGISTER_ELEMWISE_EXPLICIT_OP(elementwise_add, "Add", "Out = X + Y", "Out",
+                              "X");
 REGISTER_OP_CPU_KERNEL(
     elementwise_add,
     ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, float>,
diff --git a/paddle/fluid/operators/elementwise_add_op.h b/paddle/fluid/operators/elementwise_add_op.h
index baf04c30b1..5356105e2e 100644
--- a/paddle/fluid/operators/elementwise_add_op.h
+++ b/paddle/fluid/operators/elementwise_add_op.h
@@ -95,9 +95,10 @@ void default_elementwise_add_grad(const framework::ExecutionContext& ctx,
                                   framework::Tensor* dy) {
   int axis = ctx.Attr<int>("axis");
 
-  ElemwiseGradCompute<DeviceContext, T, IdentityGrad<T>, IdentityGrad<T>>(
-      ctx, *x, *y, *out, *dout, axis, dx, dy, IdentityGrad<T>(),
-      IdentityGrad<T>());
+  ElemwiseExplicitGradCompute<DeviceContext, T, IdentityGrad<T>,
+                              IdentityGrad<T>>(ctx, *x, *y, *out, *dout, axis,
+                                               dx, dy, IdentityGrad<T>(),
+                                               IdentityGrad<T>());
 }
 
 template <typename DeviceContext, typename T>
@@ -140,14 +141,15 @@ class ElementwiseAddGradKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     using Tensor = framework::Tensor;
 
-    auto* x = ctx.Input<Tensor>("X");
-    auto* y = ctx.Input<Tensor>("Y");
-    auto* out = ctx.Input<Tensor>("Out");
     auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
     auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
+    // skip out, x, y
+    auto* out = dout;
+    auto *x = dout, *y = dout;
 
-    if (platform::is_cpu_place(ctx.GetPlace()) && (x->dims() == y->dims())) {
+    if (platform::is_cpu_place(ctx.GetPlace()) && dx != nullptr &&
+        dy != nullptr && (dx->dims() == dy->dims())) {
       elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
     } else {
       default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx,
diff --git a/paddle/fluid/operators/elementwise_div_op.cc b/paddle/fluid/operators/elementwise_div_op.cc
index 824b1221e5..84c8a65e5f 100644
--- a/paddle/fluid/operators/elementwise_div_op.cc
+++ b/paddle/fluid/operators/elementwise_div_op.cc
@@ -15,7 +15,9 @@ limitations under the License. */
 #include "paddle/fluid/operators/elementwise_div_op.h"
 #include "paddle/fluid/operators/elementwise_op.h"
 namespace ops = paddle::operators;
+
 REGISTER_ELEMWISE_OP(elementwise_div, "Div", "Out = X / Y");
+
 REGISTER_OP_CPU_KERNEL(
     elementwise_div,
     ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext, float>,
diff --git a/paddle/fluid/operators/elementwise_op.h b/paddle/fluid/operators/elementwise_op.h
index bb88970e42..d8a12e800a 100644
--- a/paddle/fluid/operators/elementwise_op.h
+++ b/paddle/fluid/operators/elementwise_op.h
@@ -78,7 +78,9 @@ class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker {
   void Make() final {
     AddInput("X", "(Tensor), The first input tensor of elementwise op.");
     AddInput("Y", "(Tensor), The second input tensor of elementwise op.");
-    AddOutput("Out", "The output of elementwise op.").Reuse("X");
+    // AddOutput("SavedShape", "(Tensor), save X, Y shape for grad to save
+    // memory.").AsIntermediate();
+    AddOutput("Out", "The output of elementwise op.");
     AddAttr<int>("axis",
                  "(int, default -1). The start dimension index "
                  "for broadcasting Y onto X.")
@@ -125,11 +127,13 @@ But the output only shares the LoD information with the input $X$.
 
 )DOC",
                                GetName(), GetEquation()));
+    SetReuse();
   }
 
  protected:
   virtual std::string GetName() const = 0;
   virtual std::string GetEquation() const = 0;
+  virtual void SetReuse() {}
 };
 
 class ElementwiseOpGrad : public framework::OperatorWithKernel {
@@ -162,8 +166,8 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
 
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    auto input_data_type =
-        framework::ToDataType(ctx.Input<Tensor>("X")->type());
+    auto input_data_type = framework::ToDataType(
+        ctx.Input<Tensor>(framework::GradVarName("Out"))->type());
 
 #ifdef PADDLE_WITH_MKLDNN
     if (platform::CanMKLDNNBeUsed(ctx)) {
@@ -175,9 +179,58 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
     return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
 };
+
+// For Add, Sub op, the X, Out is not needed.
+class ElementwiseOpExplicitGrad : public ElementwiseOpGrad {
+ public:
+  using operators::ElementwiseOpGrad::ElementwiseOpGrad;
+  using operators::ElementwiseOpGrad::GetExpectedKernelType;
+  using Tensor = framework::Tensor;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+                   "Input(Out@GRAD) should not be null");
+
+    auto x_grad_name = framework::GradVarName("X");
+    if (ctx->HasOutput(x_grad_name)) {
+      auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
+      ctx->SetOutputDim(x_grad_name, out_dims);
+    }
+    auto y_grad_name = framework::GradVarName("Y");
+    if (ctx->HasOutput(y_grad_name)) {
+      PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
+      auto y_dims = ctx->GetInputDim("Y");
+      ctx->SetOutputDim(y_grad_name, y_dims);
+    }
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
+/*
+*/
+
+#define REGISTER_ELEMWISE_GRAD_MAKER(kernel_type, op_name)                   \
+  class kernel_type##GradMaker                                               \
+      : public paddle::framework::SingleGradOpDescMaker {                    \
+   public:                                                                   \
+    using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker; \
+                                                                             \
+   protected:                                                                \
+    std::unique_ptr<paddle::framework::OpDesc> Apply() const override {      \
+      auto* op = new paddle::framework::OpDesc();                            \
+      op->SetType(#kernel_type "_grad");                                     \
+      op->SetInput("Y", Input("Y"));                                         \
+      op->SetInput(::paddle::framework::GradVarName("Out"),                  \
+                   OutputGrad("Out"));                                       \
+      op->SetAttrMap(Attrs());                                               \
+      op->SetOutput(::paddle::framework::GradVarName("X"), InputGrad("X"));  \
+      op->SetOutput(::paddle::framework::GradVarName("Y"), InputGrad("Y"));  \
+      return std::unique_ptr<::paddle::framework::OpDesc>(op);               \
+    }                                                                        \
+  }
+
 #define REGISTER_ELEMWISE_OP(op_type, op_name, equation)                \
   class __ElemwiseOp##op_type##Maker__                                  \
       : public ::paddle::operators::ElementwiseOpMaker {                \
@@ -190,3 +243,18 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
                     ::paddle::operators::ElementwiseOpInferVarType,     \
                     ::paddle::framework::DefaultGradOpDescMaker<true>); \
   REGISTER_OPERATOR(op_type##_grad, ::paddle::operators::ElementwiseOpGrad)
+
+#define REGISTER_ELEMWISE_EXPLICIT_OP(op_type, op_name, equation, ...) \
+  class __ElemwiseOp##op_type##Maker__                                 \
+      : public ::paddle::operators::ElementwiseOpMaker {               \
+   protected:                                                          \
+    virtual std::string GetName() const { return op_name; }            \
+    virtual std::string GetEquation() const { return equation; }       \
+    virtual void SetReuse() { Reuse(__VA_ARGS__); }                    \
+  };                                                                   \
+  REGISTER_OPERATOR(op_type, ::paddle::operators::ElementwiseOp,       \
+                    __ElemwiseOp##op_type##Maker__,                    \
+                    ::paddle::operators::ElementwiseOpInferVarType,    \
+                    op_type##GradMaker);                               \
+  REGISTER_OPERATOR(op_type##_grad,                                    \
+                    ::paddle::operators::ElementwiseOpExplicitGrad)
diff --git a/paddle/fluid/operators/elementwise_op_function.h b/paddle/fluid/operators/elementwise_op_function.h
index 8b052611f8..eb8272e90c 100644
--- a/paddle/fluid/operators/elementwise_op_function.h
+++ b/paddle/fluid/operators/elementwise_op_function.h
@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+#include <glog/logging.h>
 #include <algorithm>
+#include <vector>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
@@ -65,17 +67,21 @@ inline void get_mid_dims(const framework::DDim& x_dims,
   }
 }
 
-inline void trim_trailing_singular_dims(framework::DDim* dims) {
+inline framework::DDim trim_trailing_singular_dims(
+    const framework::DDim& dims) {
   // Remove trailing dimensions of size 1 for y
-  auto actual_dims_size = dims->size();
+  auto actual_dims_size = dims.size();
   for (; actual_dims_size != 0; --actual_dims_size) {
-    if ((*dims)[actual_dims_size - 1] != 1) break;
+    if (dims[actual_dims_size - 1] != 1) break;
   }
-  if (actual_dims_size != dims->size()) {
-    auto actual_dims = framework::vectorize(*dims);
-    actual_dims.resize(actual_dims_size);
-    *dims = framework::make_ddim(actual_dims);
+
+  std::vector<int> trim_dims;
+  trim_dims.resize(actual_dims_size);
+  for (int i = 0; i < actual_dims_size; ++i) {
+    trim_dims[i] = dims[i];
   }
+  framework::DDim actual_dims = framework::make_ddim(trim_dims);
+  return actual_dims;
 }
 
 template <typename T, typename DeviceContext>
@@ -456,6 +462,71 @@ static void ElemwiseGradBroadcast2CUDA(cudaStream_t stream, const T* x,
 
 #endif
 
+template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
+void ElemwiseGradComputeNoBroadcast(
+    const framework::ExecutionContext& ctx, const framework::DDim& x_dim,
+    const framework::DDim& y_dim, const framework::Tensor& x,
+    const framework::Tensor& y, const framework::Tensor& out,
+    const framework::Tensor& dout, int axis, framework::Tensor* dx,
+    framework::Tensor* dy, DX_OP dx_op, DY_OP dy_op) {
+  size_t N = static_cast<size_t>(framework::product(x_dim));
+  platform::ForRange<DeviceContext> for_range(
+      ctx.template device_context<DeviceContext>(), N);
+  for_range(ElemwiseGradNoBroadcast<T, DX_OP, DY_OP>{
+      x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), dx_op, dy_op,
+      dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
+      dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace())});
+}
+
+template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
+void ElemwiseGradComputeWithBroadcast(
+    const framework::ExecutionContext& ctx, const framework::DDim& x_dim,
+    const framework::DDim& y_dim_untrimed, const framework::Tensor& x,
+    const framework::Tensor& y, const framework::Tensor& out,
+    const framework::Tensor& dout, int axis, framework::Tensor* dx,
+    framework::Tensor* dy, DX_OP dx_op, DY_OP dy_op) {
+  axis = (axis == -1 ? x_dim.size() - y_dim_untrimed.size() : axis);
+  auto y_dim = trim_trailing_singular_dims(y_dim_untrimed);
+  axis = (y_dim.size() == 0) ? x_dim.size() : axis;
+
+  int pre, n, post;
+  get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post);
+  if (post == 1) {
+    int h = pre;
+    int w = n;
+    if (platform::is_gpu_place(ctx.GetPlace())) {
+#ifdef __NVCC__
+      ElemwiseGradBroadcast1CUDA(
+          ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
+          y.data<T>(), out.data<T>(), dout.data<T>(), h, w, dx_op, dy_op,
+          dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
+          dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
+#endif
+    } else {
+      ElemwiseGradBroadcast1CPU(
+          x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), h, w, dx_op,
+          dy_op, dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
+          dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
+    }
+  } else {
+    if (platform::is_gpu_place(ctx.GetPlace())) {
+#ifdef __NVCC__
+      ElemwiseGradBroadcast2CUDA(
+          ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
+          y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, post, dx_op,
+          dy_op, dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
+          dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
+#endif
+    } else {
+      ElemwiseGradBroadcast2CPU(
+          x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, post,
+          dx_op, dy_op,
+          dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
+          dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
+    }
+  }
+}
+
 template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
 void ElemwiseGradCompute(const framework::ExecutionContext& ctx,
                          const framework::Tensor& x, const framework::Tensor& y,
@@ -463,63 +534,50 @@ void ElemwiseGradCompute(const framework::ExecutionContext& ctx,
                          const framework::Tensor& dout, int axis,
                          framework::Tensor* dx, framework::Tensor* dy,
                          DX_OP dx_op, DY_OP dy_op) {
+  const framework::DDim x_dim = x.dims();
+  const framework::DDim y_dim = y.dims();
   if (x.dims() == y.dims()) {
-    size_t N = static_cast<size_t>(framework::product(x.dims()));
-    platform::ForRange<DeviceContext> for_range(
-        ctx.template device_context<DeviceContext>(), N);
-    for_range(ElemwiseGradNoBroadcast<T, DX_OP, DY_OP>{
-        x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), dx_op, dy_op,
-        dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
-        dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace())});
+    ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
+        ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
   } else {  // Y is a scalar
-    auto x_dim = x.dims();
-    auto y_dim = y.dims();
-
-    axis = (axis == -1 ? x_dim.size() - y_dim.size() : axis);
-    trim_trailing_singular_dims(&y_dim);
-    axis = (y_dim.size() == 0) ? x_dim.size() : axis;
-
-    int pre, n, post;
-    get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post);
-    if (post == 1) {
-      int h = pre;
-      int w = n;
-      if (platform::is_gpu_place(ctx.GetPlace())) {
-#ifdef __NVCC__
-        ElemwiseGradBroadcast1CUDA(
-            ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
-            y.data<T>(), out.data<T>(), dout.data<T>(), h, w, dx_op, dy_op,
-            dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
-            dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
-#endif
-      } else {
-        ElemwiseGradBroadcast1CPU(
-            x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), h, w,
-            dx_op, dy_op,
-            dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
-            dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
-      }
-    } else {
-      if (platform::is_gpu_place(ctx.GetPlace())) {
-#ifdef __NVCC__
-        ElemwiseGradBroadcast2CUDA(
-            ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
-            y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, post, dx_op,
-            dy_op,
-            dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
-            dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
-#endif
-      } else {
-        ElemwiseGradBroadcast2CPU(
-            x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), pre, n,
-            post, dx_op, dy_op,
-            dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
-            dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
-      }
+    ElemwiseGradComputeWithBroadcast<DeviceContext, T, DX_OP, DY_OP>(
+        ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
+  }
+}
+
+// NOTE(dzhwinter): Only used in elementwise_add, elementwise_sub.
+// explicit gradient can cut off X, Y, Out from gradient op
+// In elementwise_add, elementwise_sub, we use dout as fake X, Y, Out to reuse
+// elementwise code.
+template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
+void ElemwiseExplicitGradCompute(const framework::ExecutionContext& ctx,
+                                 const framework::Tensor& x,
+                                 const framework::Tensor& y,
+                                 const framework::Tensor& out,
+                                 const framework::Tensor& dout, int axis,
+                                 framework::Tensor* dx, framework::Tensor* dy,
+                                 DX_OP dx_op, DY_OP dy_op) {
+  if (dy == nullptr) {
+    const framework::DDim dx_dims = dout.dims();
+    auto dy_dims = dx_dims;
+    ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
+        ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
+  } else {
+    if (dout.dims() == dy->dims()) {
+      const framework::DDim dx_dims = dout.dims();
+      const framework::DDim dy_dims = dy->dims();
+      ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
+          ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
+    } else {  // Y is a scalar
+      auto dx_dims = dout.dims();
+      const framework::DDim dy_dims = dy->dims();
+      ElemwiseGradComputeWithBroadcast<DeviceContext, T, DX_OP, DY_OP>(
+          ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
     }
   }
 }
 
+// Deprecated
 template <typename DeviceContext, typename T, typename functor,
           typename broadcastfunctor, typename broadcast2functor>
 void ElementwiseGradCompute(const framework::ExecutionContext& ctx,
@@ -547,7 +605,7 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx,
   }
 
   axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
-  trim_trailing_singular_dims(&y_dims);
+  trim_trailing_singular_dims(y_dims);
   axis = (y_dims.size() == 0) ? x_dims.size() : axis;
 
   int pre, n, post;
@@ -574,19 +632,19 @@ void ElementwiseComputeEx(const framework::ExecutionContext& ctx,
       x, y, z, ctx.template device_context<DeviceContext>(), func);
 
   auto x_dims = x->dims();
-  auto y_dims = y->dims();
-  PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
+  auto y_dims_untrimed = y->dims();
+  PADDLE_ENFORCE_GE(x_dims.size(), y_dims_untrimed.size(),
                     "Rank of first input must >= rank of second input.");
 
-  if (x_dims == y_dims) {
+  if (x_dims == y_dims_untrimed) {
     functor.Run();
     return;
   }
 
-  axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
+  axis = (axis == -1 ? x_dims.size() - y_dims_untrimed.size() : axis);
   PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
                  "Axis should be in range [0, x_dims)");
-  trim_trailing_singular_dims(&y_dims);
+  auto y_dims = trim_trailing_singular_dims(y_dims_untrimed);
   axis = (y_dims.size() == 0) ? x_dims.size() : axis;
 
   int pre, n, post;
diff --git a/paddle/fluid/operators/elementwise_sub_op.cc b/paddle/fluid/operators/elementwise_sub_op.cc
index a7562b166b..b7224261e6 100644
--- a/paddle/fluid/operators/elementwise_sub_op.cc
+++ b/paddle/fluid/operators/elementwise_sub_op.cc
@@ -15,7 +15,10 @@ limitations under the License. */
 #include "paddle/fluid/operators/elementwise_sub_op.h"
 #include "paddle/fluid/operators/elementwise_op.h"
 namespace ops = paddle::operators;
-REGISTER_ELEMWISE_OP(elementwise_sub, "Sub", "Out = X - Y");
+REGISTER_ELEMWISE_GRAD_MAKER(elementwise_sub, Sub);
+REGISTER_ELEMWISE_EXPLICIT_OP(elementwise_sub, "Sub", "Out = X - Y", "Out",
+                              "X");
+
 REGISTER_OP_CPU_KERNEL(
     elementwise_sub,
     ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, float>,
diff --git a/paddle/fluid/operators/elementwise_sub_op.h b/paddle/fluid/operators/elementwise_sub_op.h
index fe088b8203..11c7e3fe62 100644
--- a/paddle/fluid/operators/elementwise_sub_op.h
+++ b/paddle/fluid/operators/elementwise_sub_op.h
@@ -4,7 +4,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
 
-    http://www.apache.org/licenses/LICENSE-2.0
+   http://www.apache.org/licenses/LICENSE-2.0
 
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
@@ -55,14 +55,15 @@ class ElementwiseSubGradKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     using Tensor = framework::Tensor;
 
-    auto* x = ctx.Input<Tensor>("X");
-    auto* y = ctx.Input<Tensor>("Y");
-    auto* out = ctx.Input<Tensor>("Out");
     auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
     auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
     int axis = ctx.Attr<int>("axis");
-    ElemwiseGradCompute<DeviceContext, T, SubGradDX<T>, SubGradDY<T>>(
+    // skip out, x, y
+    auto* out = dout;
+    auto *x = dout, *y = dout;
+
+    ElemwiseExplicitGradCompute<DeviceContext, T, SubGradDX<T>, SubGradDY<T>>(
         ctx, *x, *y, *out, *dout, axis, dx, dy, SubGradDX<T>(), SubGradDY<T>());
   }
 };
diff --git a/paddle/fluid/operators/softmax_op.cc b/paddle/fluid/operators/softmax_op.cc
index 31a7458f63..fefc7125b4 100644
--- a/paddle/fluid/operators/softmax_op.cc
+++ b/paddle/fluid/operators/softmax_op.cc
@@ -137,7 +137,8 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
                       ctx->GetInputDim(framework::GradVarName("Out")),
                       "Input(Out) and its gradients should have a same shape.");
 
-    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
+    ctx->SetOutputDim(framework::GradVarName("X"),
+                      ctx->GetInputDim(framework::GradVarName("Out")));
   }
 
  protected:
@@ -160,8 +161,8 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
       layout_ = framework::DataLayout::kMKLDNN;
     }
 #endif
-    auto input_data_type =
-        framework::ToDataType(ctx.Input<Tensor>("X")->type());
+    auto input_data_type = framework::ToDataType(
+        ctx.Input<Tensor>(framework::GradVarName("Out"))->type());
     if (input_data_type == framework::proto::VarType::FP16) {
       PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                      "float16 can only be used on GPU place");
@@ -172,13 +173,31 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
   }
 };
 
+class SoftmaxOpGradMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    auto* op = new framework::OpDesc();
+    op->SetType("softmax_grad");
+
+    op->SetInput("Out", Output("Out"));
+    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
+
+    op->SetAttrMap(Attrs());
+
+    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    return std::unique_ptr<framework::OpDesc>(op);
+  }
+};
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 
 REGISTER_OPERATOR(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker,
-                  paddle::framework::DefaultGradOpDescMaker<true>);
+                  ops::SoftmaxOpGradMaker);
 REGISTER_OPERATOR(softmax_grad, ops::SoftmaxOpGrad);
 REGISTER_OP_CPU_KERNEL(
     softmax, ops::SoftmaxKernel<paddle::platform::CPUDeviceContext, float>,
diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_sub_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_sub_op.py
index acf652d3fb..1854232194 100644
--- a/python/paddle/fluid/tests/unittests/test_elementwise_sub_op.py
+++ b/python/paddle/fluid/tests/unittests/test_elementwise_sub_op.py
@@ -20,8 +20,8 @@ class TestElementwiseOp(OpTest):
     def setUp(self):
         self.op_type = "elementwise_sub"
         self.inputs = {
-            'X': np.random.uniform(0.1, 1, [13, 17]).astype("float32"),
-            'Y': np.random.uniform(0.1, 1, [13, 17]).astype("float32")
+            'X': np.random.uniform(0.1, 1, [2, 3]).astype("float32"),
+            'Y': np.random.uniform(0.1, 1, [2, 3]).astype("float32")
         }
         self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']}