From 6cb66721d2e98d9f8f6b15478ba4796f14eecab0 Mon Sep 17 00:00:00 2001
From: dengkaipeng <dengkaipeng@baidu.com>
Date: Mon, 4 Mar 2019 15:23:35 +0000
Subject: [PATCH] add cudnn support. test=develop

---
 paddle/fluid/operators/softmax_cudnn_op.cu.cc | 70 ++++++++++++----
 paddle/fluid/operators/softmax_op.h           | 83 ++++++++++++-------
 .../fluid/tests/unittests/test_softmax_op.py  | 61 +++++++++++++-
 3 files changed, 164 insertions(+), 50 deletions(-)

diff --git a/paddle/fluid/operators/softmax_cudnn_op.cu.cc b/paddle/fluid/operators/softmax_cudnn_op.cu.cc
index ad3e5543f1..84151d70b9 100644
--- a/paddle/fluid/operators/softmax_cudnn_op.cu.cc
+++ b/paddle/fluid/operators/softmax_cudnn_op.cu.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/math/softmax.h"
+#include "paddle/fluid/operators/softmax_op.h"
 #include "paddle/fluid/framework/op_registry.h"
 
 namespace paddle {
@@ -24,22 +25,40 @@ template <typename T>
 class SoftmaxCUDNNKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
+    auto& dev_ctx = context.template device_context<platform::CUDADeviceContext>();
     auto* X = context.Input<Tensor>("X");
     auto* Out = context.Output<Tensor>("Out");
+    // auto dims = X->dims();
+    const int axis = context.Attr<int>("axis");
+    int rank = X->dims().size();
 
     // allocate memory on device.
     Out->mutable_data<T>(context.GetPlace());
 
-    auto dims = X->dims();
-    auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
-    framework::LoDTensor flattened_x;
-    framework::LoDTensor flattened_out;
-    flattened_x.ShareDataWith(*X).Resize(flattened_dims);
-    flattened_out.ShareDataWith(*Out).Resize(flattened_dims);
+    std::vector<int> perm, shape;
+    CalcTransPermAndShapeByAxis(*X, axis, &perm, &shape);
+
+    Tensor X_2d, Out_2d;
+    Tensor X_trans, Out_trans;
+    if (axis != -1 && axis != rank - 1) {
+      X_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
+      Out_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
+      TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *X, &X_trans, perm);
+      TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *Out, &Out_trans, perm);
+      X_2d = framework::ReshapeToMatrix(X_trans, rank - 1);
+      Out_2d = framework::ReshapeToMatrix(Out_trans, rank - 1);
+    } else {
+      X_2d = framework::ReshapeToMatrix(*X, rank - 1);
+      Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
+    }
 
     math::SoftmaxCUDNNFunctor<T>()(
         context.template device_context<platform::CUDADeviceContext>(),
-        &flattened_x, &flattened_out);
+        &X_2d, &Out_2d);
+
+    if (axis != -1 && axis != rank - 1) {
+      TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, Out_trans, Out, perm);
+    }
   }
 };
 
@@ -47,25 +66,44 @@ template <typename T>
 class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
+    auto& dev_ctx = context.template device_context<platform::CUDADeviceContext>();
     auto* Out = context.Input<Tensor>("Out");
     auto* dOut = context.Input<Tensor>(framework::GradVarName("Out"));
     auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
+    const int axis = context.Attr<int>("axis");
+    int rank = Out->dims().size();
 
     // allocate memory on device.
     dX->mutable_data<T>(context.GetPlace());
 
-    auto dims = Out->dims();
-    auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
-    framework::LoDTensor flattened_out;
-    framework::LoDTensor flattened_d_out;
-    framework::LoDTensor flattened_d_x;
-    flattened_out.ShareDataWith(*Out).Resize(flattened_dims);
-    flattened_d_out.ShareDataWith(*dOut).Resize(flattened_dims);
-    flattened_d_x.ShareDataWith(*dX).Resize(flattened_dims);
+    std::vector<int> perm, shape;
+    CalcTransPermAndShapeByAxis(*dX, axis, &perm, &shape);
+
+    Tensor dX_2d, Out_2d, dOut_2d;
+    Tensor dX_trans, Out_trans, dOut_trans;
+    if (axis != -1 && axis != rank - 1) {
+      dX_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
+      Out_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
+      dOut_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
+      TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *dX, &dX_trans, perm);
+      TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *Out, &Out_trans, perm);
+      TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *dOut, &dOut_trans, perm);
+      dX_2d = framework::ReshapeToMatrix(dX_trans, rank - 1);
+      Out_2d = framework::ReshapeToMatrix(Out_trans, rank - 1);
+      dOut_2d = framework::ReshapeToMatrix(dOut_trans, rank - 1);
+    } else {
+      dX_2d = framework::ReshapeToMatrix(*dX, rank - 1);
+      Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
+      dOut_2d = framework::ReshapeToMatrix(*dOut, rank - 1);
+    }
 
     math::SoftmaxGradCUDNNFunctor<T>()(
         context.template device_context<platform::CUDADeviceContext>(),
-        &flattened_out, &flattened_d_out, &flattened_d_x);
+        &Out_2d, &dOut_2d, &dX_2d);
+
+    if (axis != -1 && axis != rank - 1) {
+      TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, dX_trans, dX, perm);
+    }
   }
 };
 
diff --git a/paddle/fluid/operators/softmax_op.h b/paddle/fluid/operators/softmax_op.h
index ad41e52116..1810b23e0d 100644
--- a/paddle/fluid/operators/softmax_op.h
+++ b/paddle/fluid/operators/softmax_op.h
@@ -23,59 +23,58 @@ namespace operators {
 
 using Tensor = framework::Tensor;
 
-template <typename DeviceContext, typename T>
-static inline void TransposeAxisToEnd(const Tensor& x, const Tensor& out,
-                                      Tensor* x_trans, Tensor* out_trans,
-                                      const int axis, std::vector<int> perm,
-                                      const framework::ExecutionContext& ctx) {
+static inline void CalcTransPermAndShapeByAxis(const Tensor& x, const int axis,
+                                std::vector<int>* perm, std::vector<int>* shape) {
   auto dim_x = x.dims();
   int rank = dim_x.size();
 
   if (axis == -1 || axis == rank - 1) {
-    *x_trans = x;
-    *out_trans = out;
     return;
   }
 
-  auto& dev_ctx = ctx.template device_context<DeviceContext>();
-  std::vector<int> shape;
   for (int i = 0; i < rank - 1; i++) {
     if (i == axis) {
-      perm.push_back(rank - 1);
-      shape.push_back(dim_x[rank - 1]);
+      perm->push_back(rank - 1);
+      shape->push_back(dim_x[rank - 1]);
     } else {
-      perm.push_back(i);
-      shape.push_back(dim_x[i]);
+      perm->push_back(i);
+      shape->push_back(dim_x[i]);
     }
   }
-  perm.push_back(axis);
-  shape.push_back(dim_x[axis]);
-
-  x_trans->mutable_data<T>(framework::make_ddim(shape), ctx.GetPlace());
-  out_trans->mutable_data<T>(framework::make_ddim(shape), ctx.GetPlace());
-  TransCompute<DeviceContext, T>(rank, dev_ctx, x, x_trans, perm);
-  TransCompute<DeviceContext, T>(rank, dev_ctx, out, out_trans, perm);
+  perm->push_back(axis);
+  shape->push_back(dim_x[axis]);
 }
 
 template <typename DeviceContext, typename T>
 class SoftmaxKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
+    auto& dev_ctx = context.template device_context<DeviceContext>();
     auto* X = context.Input<Tensor>("X");
     auto* Out = context.Output<Tensor>("Out");
     const int axis = context.Attr<int>("axis");
+    int rank = X->dims().size();
 
     // allocate memory on device.
     Out->mutable_data<T>(context.GetPlace());
 
+    std::vector<int> perm, shape;
+    CalcTransPermAndShapeByAxis(*X, axis, &perm, &shape);
+
+    Tensor X_2d, Out_2d;
     Tensor X_trans, Out_trans;
-    std::vector<int> perm;
-    TransposeAxisToEnd<DeviceContext, T>(*X, *Out, &X_trans, &Out_trans, axis,
-                                         perm, context);
+    if (axis != -1 && axis != rank - 1) {
+      X_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
+      Out_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
+      TransCompute<DeviceContext, T>(rank, dev_ctx, *X, &X_trans, perm);
+      TransCompute<DeviceContext, T>(rank, dev_ctx, *Out, &Out_trans, perm);
+      X_2d = framework::ReshapeToMatrix(X_trans, rank - 1);
+      Out_2d = framework::ReshapeToMatrix(Out_trans, rank - 1);
+    } else {
+      X_2d = framework::ReshapeToMatrix(*X, rank - 1);
+      Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
+    }
 
-    int rank = X->dims().size();
-    Tensor X_2d = framework::ReshapeToMatrix(X_trans, rank - 1);
-    Tensor Out_2d = framework::ReshapeToMatrix(Out_trans, rank - 1);
 
 #ifdef PADDLE_ON_INFERENCE
     math::SoftmaxFunctor<DeviceContext, T, true>()(
@@ -86,7 +85,6 @@ class SoftmaxKernel : public framework::OpKernel<T> {
 #endif
 
     if (axis != -1 && axis != rank - 1) {
-      auto& dev_ctx = context.template device_context<DeviceContext>();
       TransCompute<DeviceContext, T>(rank, dev_ctx, Out_trans, Out, perm);
     }
   }
@@ -96,21 +94,44 @@ template <typename DeviceContext, typename T>
 class SoftmaxGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
+    auto& dev_ctx = context.template device_context<DeviceContext>();
     auto* Out = context.Input<Tensor>("Out");
     auto* dOut = context.Input<Tensor>(framework::GradVarName("Out"));
     auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
+    const int axis = context.Attr<int>("axis");
+    int rank = Out->dims().size();
 
     // allocate memory on device.
     dX->mutable_data<T>(context.GetPlace());
 
-    int rank = Out->dims().size();
-    Tensor Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
-    Tensor dOut_2d = framework::ReshapeToMatrix(*dOut, rank - 1);
-    Tensor dX_2d = framework::ReshapeToMatrix(*dX, rank - 1);
+    std::vector<int> perm, shape;
+    CalcTransPermAndShapeByAxis(*dX, axis, &perm, &shape);
+
+    Tensor dX_2d, Out_2d, dOut_2d;
+    Tensor dX_trans, Out_trans, dOut_trans;
+    if (axis != -1 && axis != rank - 1) {
+      dX_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
+      Out_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
+      dOut_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
+      TransCompute<DeviceContext, T>(rank, dev_ctx, *dX, &dX_trans, perm);
+      TransCompute<DeviceContext, T>(rank, dev_ctx, *Out, &Out_trans, perm);
+      TransCompute<DeviceContext, T>(rank, dev_ctx, *dOut, &dOut_trans, perm);
+      dX_2d = framework::ReshapeToMatrix(dX_trans, rank - 1);
+      Out_2d = framework::ReshapeToMatrix(Out_trans, rank - 1);
+      dOut_2d = framework::ReshapeToMatrix(dOut_trans, rank - 1);
+    } else {
+      dX_2d = framework::ReshapeToMatrix(*dX, rank - 1);
+      Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
+      dOut_2d = framework::ReshapeToMatrix(*dOut, rank - 1);
+    }
 
     math::SoftmaxGradFunctor<DeviceContext, T>()(
         context.template device_context<DeviceContext>(), &Out_2d, &dOut_2d,
         &dX_2d);
+
+    if (axis != -1 && axis != rank - 1) {
+      TransCompute<DeviceContext, T>(rank, dev_ctx, dX_trans, dX, perm);
+    }
   }
 };
 
diff --git a/python/paddle/fluid/tests/unittests/test_softmax_op.py b/python/paddle/fluid/tests/unittests/test_softmax_op.py
index 5c56de6779..084fa869e3 100644
--- a/python/paddle/fluid/tests/unittests/test_softmax_op.py
+++ b/python/paddle/fluid/tests/unittests/test_softmax_op.py
@@ -31,6 +31,9 @@ class TestSoftmaxOp(OpTest):
     def get_x_shape(self):
         return [10, 10]
 
+    def get_axis(self):
+        return -1
+
     def setUp(self):
         self.op_type = "softmax"
         self.use_cudnn = False
@@ -38,15 +41,15 @@ class TestSoftmaxOp(OpTest):
         self.dtype = np.float32
         self.init_kernel_type()
         self.shape = self.get_x_shape()
+        self.axis = self.get_axis()
 
         x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
-        out = np.apply_along_axis(stable_softmax, 1,
-                                  x.reshape([-1, self.shape[-1]]))
-        out = out.reshape(self.shape)
+        out = np.apply_along_axis(stable_softmax, self.axis, x)
 
         self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
         self.outputs = {'Out': out}
         self.attrs = {
+            'axis': self.axis,
             'use_cudnn': self.use_cudnn,
             'use_mkldnn': self.use_mkldnn
         }
@@ -76,6 +79,38 @@ class TestSoftmaxOp2(TestSoftmaxOp):
         return [2, 3, 4, 5]
 
 
+class TestSoftmaxOp3(TestSoftmaxOp):
+    def get_x_shape(self):
+        return [2, 3, 4, 5]
+
+    def get_axis(self):
+        return 0
+
+
+class TestSoftmaxOp4(TestSoftmaxOp):
+    def get_x_shape(self):
+        return [2, 3, 4, 5]
+
+    def get_axis(self):
+        return 1
+
+
+class TestSoftmaxOp5(TestSoftmaxOp):
+    def get_x_shape(self):
+        return [2, 3, 4, 5]
+
+    def get_axis(self):
+        return 2
+
+
+class TestSoftmaxOp5(TestSoftmaxOp):
+    def get_x_shape(self):
+        return [2, 3, 4, 5]
+
+    def get_axis(self):
+        return 3
+
+
 @unittest.skipIf(not core.is_compiled_with_cuda(),
                  "core is not compiled with CUDA")
 class TestSoftmaxCUDNNOp(TestSoftmaxOp):
@@ -90,6 +125,26 @@ class TestSoftmaxCUDNNOp2(TestSoftmaxCUDNNOp):
         return [2, 3, 4, 5]
 
 
+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestSoftmaxCUDNNOp3(TestSoftmaxCUDNNOp):
+    def get_x_shape(self):
+        return [2, 3, 4, 5]
+
+    def get_axis(self):
+        return 1
+
+
+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestSoftmaxCUDNNOp2(TestSoftmaxCUDNNOp):
+    def get_x_shape(self):
+        return [2, 3, 4, 5]
+
+    def get_axis(self):
+        return 2
+
+
 @unittest.skipIf(not core.is_compiled_with_cuda(),
                  "core is not compiled with CUDA")
 class TestSoftmaxFP16Op(TestSoftmaxOp):