diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt
index 2116635493..87efb900cd 100644
--- a/paddle/operators/CMakeLists.txt
+++ b/paddle/operators/CMakeLists.txt
@@ -101,8 +101,8 @@ set(DEPS_OPS
 op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
   DEPS framework_proto tensor net_op)
 op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op)
-op_library(cross_entropy_op DEPS cross_entropy_function)
-op_library(softmax_with_cross_entropy_op DEPS cross_entropy_function softmax_function)
+op_library(cross_entropy_op DEPS cross_entropy)
+op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
 
 list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
 foreach(src ${GENERAL_OPS})
diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt
index 6bea9817f1..b39d4f0ac2 100644
--- a/paddle/operators/math/CMakeLists.txt
+++ b/paddle/operators/math/CMakeLists.txt
@@ -1,17 +1,15 @@
 if(WITH_GPU)
     nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc
-      im2col.cu DEPS cblas device_context operator)
+               im2col.cu DEPS cblas device_context operator)
     nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
-    nv_library(softmax_function SRCS softmax.cc softmax.cu
-      DEPS operator)
-    nv_library(cross_entropy_function SRCS cross_entropy.cc cross_entropy.cu
-      DEPS operator)
+    nv_library(softmax SRCS softmax.cc softmax.cu DEPS operator)
+    nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator)
 else()
     cc_library(math_function SRCS math_function.cc im2col.cc
-      DEPS cblas device_context operator)
+               DEPS cblas device_context operator)
     cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
-    cc_library(softmax_function SRCS softmax.cc DEPS operator)
-    cc_library(cross_entropy_function SRCS cross_entropy.cc DEPS operator)
+    cc_library(softmax SRCS softmax.cc DEPS operator)
+    cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator)
 endif()
 
 cc_test(im2col_test SRCS im2col_test.cc DEPS math_function tensor)
diff --git a/paddle/operators/math/softmax.cc b/paddle/operators/math/softmax.cc
index ac9f3c4bf6..0ba8197ab8 100644
--- a/paddle/operators/math/softmax.cc
+++ b/paddle/operators/math/softmax.cc
@@ -1,16 +1,16 @@
 /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
 
-   Licensed under the Apache License, Version 2.0 (the "License");
-   you may not use this file except in compliance with the License.
-   You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
 
-   http://www.apache.org/licenses/LICENSE-2.0
+    http://www.apache.org/licenses/LICENSE-2.0
 
-   Unless required by applicable law or agreed to in writing, software
-   distributed under the License is distributed on an "AS IS" BASIS,
-   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-   See the License for the specific language governing permissions and
-   limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
 
 #include "paddle/operators/math/softmax.h"
 
@@ -19,6 +19,7 @@ namespace operators {
 namespace math {
 
 template class SoftmaxFunctor<platform::CPUPlace, float>;
+template class SoftmaxGradFunctor<platform::CPUPlace, float>;
 
 }  // namespace math
 }  // namespace operators
diff --git a/paddle/operators/math/softmax.cu b/paddle/operators/math/softmax.cu
index 4c3df0550e..99f988d51e 100644
--- a/paddle/operators/math/softmax.cu
+++ b/paddle/operators/math/softmax.cu
@@ -1,16 +1,16 @@
 /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
 
-   Licensed under the Apache License, Version 2.0 (the "License");
-   you may not use this file except in compliance with the License.
-   You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
 
-   http://www.apache.org/licenses/LICENSE-2.0
+    http://www.apache.org/licenses/LICENSE-2.0
 
-   Unless required by applicable law or agreed to in writing, software
-   distributed under the License is distributed on an "AS IS" BASIS,
-   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-   See the License for the specific language governing permissions and
-   limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
 
 #define EIGEN_USE_GPU
 
@@ -21,6 +21,7 @@ namespace operators {
 namespace math {
 
 template class SoftmaxFunctor<platform::GPUPlace, float>;
+template class SoftmaxGradFunctor<platform::GPUPlace, float>;
 
 }  // namespace math
 }  // namespace operators
diff --git a/paddle/operators/math/softmax.h b/paddle/operators/math/softmax.h
index 225323f05a..b7f627eee7 100644
--- a/paddle/operators/math/softmax.h
+++ b/paddle/operators/math/softmax.h
@@ -1,16 +1,16 @@
 /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
 
-   Licensed under the Apache License, Version 2.0 (the "License");
-   you may not use this file except in compliance with the License.
-   You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
 
-   http://www.apache.org/licenses/LICENSE-2.0
+    http://www.apache.org/licenses/LICENSE-2.0
 
-   Unless required by applicable law or agreed to in writing, software
-   distributed under the License is distributed on an "AS IS" BASIS,
-   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-   See the License for the specific language governing permissions and
-   limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
 
 #pragma once
 #include "paddle/framework/eigen.h"
@@ -68,6 +68,37 @@ class SoftmaxFunctor {
              .broadcast(one_by_class));
   }
 };
+
+template <typename Place, typename T>
+class SoftmaxGradFunctor {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor* y, const framework::Tensor* y_grad,
+                  framework::Tensor* x_grad) {
+    auto softmax = EigenMatrix<T>::From(*y);
+    auto softmax_grad = EigenMatrix<T>::From(*y_grad);
+    auto logits_grad = EigenMatrix<T>::From(*x_grad);
+
+    const int kBatchDim = 0;
+    const int kClassDim = 1;
+
+    const int batch_size = softmax.dimension(kBatchDim);
+    const int num_classes = softmax.dimension(kClassDim);
+
+    Eigen::DSizes<int, 1> along_class(kClassDim);
+    Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
+    Eigen::DSizes<int, 2> one_by_class(1, num_classes);
+
+    auto dot = (softmax * softmax_grad)
+                   .sum(along_class)
+                   .eval()
+                   .reshape(batch_by_one)
+                   .broadcast(one_by_class);
+    logits_grad.device(*context.GetEigenDevice<Place>()) =
+        (softmax_grad - dot) * softmax;
+  }
+};
+
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
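
As a quick sanity check of the gradient formula that the new SoftmaxGradFunctor implements (dX = (dY - sum(Y * dY) over the class axis) * Y), here is a minimal NumPy sketch of the same computation; the shapes, data, and helper names below are illustrative assumptions, not part of the patch.

import numpy as np

def softmax(x):
    # row-wise, numerically stable softmax
    shifted = x - x.max(axis=1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=1, keepdims=True)

def softmax_grad(y, dy):
    # dX = (dY - sum(Y * dY, class axis)) * Y, mirroring SoftmaxGradFunctor
    dot = (y * dy).sum(axis=1, keepdims=True)
    return (dy - dot) * y

x = np.random.rand(4, 3)
y = softmax(x)
dy = np.random.rand(4, 3)
dx = softmax_grad(y, dy)
# every row of dx sums to ~0 because each row of y sums to 1
print(np.allclose(dx.sum(axis=1), 0.0, atol=1e-6))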
diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc
index 9858c4d9c2..3c8fe04d2e 100644
--- a/paddle/operators/mul_op.cc
+++ b/paddle/operators/mul_op.cc
@@ -1,16 +1,16 @@
 /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
 
-   Licensed under the Apache License, Version 2.0 (the "License");
-   you may not use this file except in compliance with the License.
-   You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
 
-   http://www.apache.org/licenses/LICENSE-2.0
+    http://www.apache.org/licenses/LICENSE-2.0
 
-   Unless required by applicable law or agreed to in writing, software
-   distributed under the License is distributed on an "AS IS" BASIS,
-   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-   See the License for the specific language governing permissions and
-   limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
 
 #include "paddle/operators/mul_op.h"
 
@@ -35,12 +35,14 @@ class MulOp : public framework::OperatorWithKernel {
     int x_num_col_dims = ctx->Attrs().Get<int>("x_num_col_dims");
     int y_num_col_dims = ctx->Attrs().Get<int>("y_num_col_dims");
 
-    PADDLE_ENFORCE(x_dims.size() > x_num_col_dims,
-                   "The rank of input tensor X should be larger than "
-                   "`mul_op`'s `x_num_col_dims`.");
-    PADDLE_ENFORCE(y_dims.size() > y_num_col_dims,
-                   "The rank of input tensor Y should be larger than "
-                   "`mul_op`'s `y_num_col_dims`.");
+    PADDLE_ENFORCE_GT(
+        x_dims.size(), x_num_col_dims,
+        "The input tensor X's rank of MulOp should be larger than "
+        "x_num_col_dims.");
+    PADDLE_ENFORCE_GT(
+        y_dims.size(), y_num_col_dims,
+        "The input tensor Y's rank of MulOp should be larger than "
+        "y_num_col_dims.");
 
     auto x_mat_dims = framework::flatten_to_2d(x_dims, x_num_col_dims);
     auto y_mat_dims = framework::flatten_to_2d(y_dims, y_num_col_dims);
diff --git a/paddle/operators/sequence_pool_op.cc b/paddle/operators/sequence_pool_op.cc
index 17685ea654..bc4af2f704 100644
--- a/paddle/operators/sequence_pool_op.cc
+++ b/paddle/operators/sequence_pool_op.cc
@@ -24,9 +24,9 @@ class SequencePoolOp : public framework::OperatorWithKernel {
  protected:
   void InferShape(framework::InferShapeContextBase* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("X"),
-                   "Input(X) of SequenceAvgPoolOp should not be null.");
+                   "Input(X) of SequencePoolOp should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
-                   "Output(Out) of SequenceAvgPoolOp should not be null.");
+                   "Output(Out) of SequencePoolOp should not be null.");
     ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
   }
 };
diff --git a/paddle/operators/sequence_softmax_op.cc b/paddle/operators/sequence_softmax_op.cc
new file mode 100644
index 0000000000..621779ab61
--- /dev/null
+++ b/paddle/operators/sequence_softmax_op.cc
@@ -0,0 +1,103 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/sequence_softmax_op.h"
+
+namespace paddle {
+namespace operators {
+
+class SequenceSoftmaxOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContextBase* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of SequenceSoftmaxOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of SequenceSoftmaxOp should not be null.");
+    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
+    ctx->ShareLoD("X", /*->*/ "Out");
+  }
+};
+
+class SequenceSoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  SequenceSoftmaxOpMaker(framework::OpProto* proto,
+                         framework::OpAttrChecker* op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X",
+             "(LoDTensor) 1-D or 2-D input LoDTensor with the 2-nd dimension "
+             "of length 1.");
+    AddOutput("Out",
+              "(LoDTensor) 1-D or 2-D output LoDTensor with the 2-nd dimension "
+              "of length 1.");
+    AddComment(R"DOC(
+SequenceSoftmaxOp computes the softmax activation among all time-steps for
+each sequence. The dimension of each time-step should be 1, so the shape of
+the input Tensor can be either [N, 1] or [N], where N is the sum of the
+lengths of all sequences.
+
+Equation:
+    for i-th sequence in a mini-batch:
+        Out(X[lod[i]:lod[i+1]], :) =
+            exp(X[lod[i]:lod[i+1], :]) / sum(exp(X[lod[i]:lod[i+1], :]))
+
+For example, for a mini-batch of 3 variable-length sequences containing
+2, 3, and 2 time-steps respectively, the LoD is [0, 2, 5, 7]. Softmax is
+then computed over X[0:2, :], X[2:5, :], and X[5:7, :], and N turns out
+to be 7.
+)DOC");
+  }
+};
+
+class SequenceSoftmaxGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContextBase* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("Out"),
+                   "Input(Out) of SequenceSoftmaxGradOp should not be null.");
+    PADDLE_ENFORCE(
+        ctx->HasInput(framework::GradVarName("Out")),
+        "Input(Out@GRAD) of SequenceSoftmaxGradOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of SequenceSoftmaxOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
+                   "Output(X@GRAD) of SequenceSoftmaxOp should not be null.");
+
+    PADDLE_ENFORCE_EQ(
+        ctx->GetInputDim("Out"),
+        ctx->GetInputDim(framework::GradVarName("Out")),
+        "Input(Out) and Input(Out@GRAD) of SequenceSoftmaxGradOp should be of "
+        "the same shape.");
+
+    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP(sequence_softmax, ops::SequenceSoftmaxOp,
+            ops::SequenceSoftmaxOpMaker, sequence_softmax_grad,
+            ops::SequenceSoftmaxGradOp);
+REGISTER_OP_CPU_KERNEL(
+    sequence_softmax,
+    ops::SequenceSoftmaxKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(
+    sequence_softmax_grad,
+    ops::SequenceSoftmaxGradKernel<paddle::platform::CPUPlace, float>);
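
To make the per-sequence equation in the DOC comment above concrete, here is a minimal NumPy sketch of the forward pass, assuming the usual [N, 1] input and a single-level LoD offset vector like the one in the example; the helper name and data are illustrative assumptions.

import numpy as np

def sequence_softmax(x, lod):
    # x: [N, 1] concatenation of all sequences; lod: offsets, e.g. [0, 2, 5, 7]
    out = np.zeros_like(x)
    for start, end in zip(lod[:-1], lod[1:]):
        seq = x[start:end, 0]
        e = np.exp(seq - seq.max())      # numerically stable softmax per sequence
        out[start:end, 0] = e / e.sum()
    return out

x = np.random.rand(7, 1).astype("float32")
lod = [0, 2, 5, 7]                       # three sequences of lengths 2, 3, and 2
out = sequence_softmax(x, lod)
print(out[0:2].sum(), out[2:5].sum(), out[5:7].sum())  # each close to 1.0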
diff --git a/paddle/operators/sequence_softmax_op.cu b/paddle/operators/sequence_softmax_op.cu
new file mode 100644
index 0000000000..f2a1e3d5e3
--- /dev/null
+++ b/paddle/operators/sequence_softmax_op.cu
@@ -0,0 +1,25 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#define EIGEN_USE_GPU
+
+#include "paddle/operators/sequence_softmax_op.h"
+
+namespace ops = paddle::operators;
+REGISTER_OP_GPU_KERNEL(
+    sequence_softmax,
+    ops::SequenceSoftmaxKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(
+    sequence_softmax_grad,
+    ops::SequenceSoftmaxGradKernel<paddle::platform::GPUPlace, float>);
diff --git a/paddle/operators/sequence_softmax_op.h b/paddle/operators/sequence_softmax_op.h
new file mode 100644
index 0000000000..96d87c404d
--- /dev/null
+++ b/paddle/operators/sequence_softmax_op.h
@@ -0,0 +1,94 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+#include "paddle/operators/math/softmax.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+using LoDTensor = framework::LoDTensor;
+
+template <typename Place, typename T>
+class SequenceSoftmaxKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* x = ctx.Input<LoDTensor>("X");
+    auto* out = ctx.Output<LoDTensor>("Out");
+
+    auto lod = x->lod();
+    auto dims = x->dims();
+
+    const size_t level = lod.size() - 1;
+    PADDLE_ENFORCE_EQ(dims[0], static_cast<int64_t>(lod[level].back()),
+                      "The first dimension of Input(X) should be equal to the "
+                      "sum of all sequences' lengths.");
+    PADDLE_ENFORCE_EQ(dims[0], x->numel(),
+                      "The width of each timestep in Input(X) of "
+                      "SequenceSoftmaxOp should be 1.");
+
+    out->mutable_data<T>(ctx.GetPlace());
+    for (int i = 0; i < static_cast<int>(lod[level].size()) - 1; ++i) {
+      int start_pos = static_cast<int>(lod[level][i]);
+      int end_pos = static_cast<int>(lod[level][i + 1]);
+      Tensor x_i = x->Slice<T>(start_pos, end_pos);
+      Tensor out_i = out->Slice<T>(start_pos, end_pos);
+
+      // Reshape from (end_pos - start_pos) x 1UL to 1UL x (end_pos - start_pos)
+      framework::DDim dims_i = framework::make_ddim({1UL, end_pos - start_pos});
+      x_i.Resize(dims_i);
+      out_i.Resize(dims_i);
+      math::SoftmaxFunctor<Place, T>()(ctx.device_context(), &x_i, &out_i);
+    }
+  }
+};
+
+template <typename Place, typename T>
+class SequenceSoftmaxGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* out = ctx.Input<LoDTensor>("Out");
+    auto* out_grad = ctx.Input<LoDTensor>(framework::GradVarName("Out"));
+    auto* x = ctx.Input<LoDTensor>("X");
+    auto* x_grad = ctx.Output<LoDTensor>(framework::GradVarName("X"));
+
+    auto lod = x->lod();
+    const size_t level = lod.size() - 1;
+
+    x_grad->mutable_data<T>(ctx.GetPlace());
+    for (int i = 0; i < static_cast<int>(lod[level].size()) - 1; ++i) {
+      int start_pos = static_cast<int>(lod[level][i]);
+      int end_pos = static_cast<int>(lod[level][i + 1]);
+
+      Tensor out_i = out->Slice<T>(start_pos, end_pos);
+      Tensor out_grad_i = out_grad->Slice<T>(start_pos, end_pos);
+      Tensor x_grad_i = x_grad->Slice<T>(start_pos, end_pos);
+
+      // Reshape from (end_pos - start_pos) x 1UL to 1UL x (end_pos - start_pos)
+      framework::DDim dims_i = framework::make_ddim({1UL, end_pos - start_pos});
+      out_i.Resize(dims_i);
+      out_grad_i.Resize(dims_i);
+      x_grad_i.Resize(dims_i);
+      math::SoftmaxGradFunctor<Place, T>()(ctx.device_context(), &out_i,
+                                           &out_grad_i, &x_grad_i);
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/softmax_op.h b/paddle/operators/softmax_op.h
index 8fdda8b1df..2c08853f4f 100644
--- a/paddle/operators/softmax_op.h
+++ b/paddle/operators/softmax_op.h
@@ -29,8 +29,8 @@ template <typename Place, typename T>
 class SoftmaxKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto X = context.Input<Tensor>("X");
-    auto Y = context.Output<Tensor>("Y");
+    auto* X = context.Input<Tensor>("X");
+    auto* Y = context.Output<Tensor>("Y");
 
     // allocate memory on device.
     Y->mutable_data<T>(context.GetPlace());
@@ -43,29 +43,14 @@ template <typename Place, typename T>
 class SoftmaxGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto Y = context.Input<Tensor>("Y");
-    auto dY = context.Input<Tensor>(framework::GradVarName("Y"));
-    auto dX = context.Output<Tensor>(framework::GradVarName("X"));
-    dX->mutable_data<T>(context.GetPlace());
-
-    const int batch_size = Y->dims()[0];
-    const int class_num = Y->dims()[1];
-
-    Eigen::DSizes<int, 1> along_class(1);
-    Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
-    Eigen::DSizes<int, 2> one_by_class(1, class_num);
+    auto* Y = context.Input<Tensor>("Y");
+    auto* dY = context.Input<Tensor>(framework::GradVarName("Y"));
+    auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
 
-    auto Y_eigen = EigenMatrix<T>::From(*Y);
-    auto dY_eigen = EigenMatrix<T>::From(*dY);
-    auto dX_eigen = EigenMatrix<T>::From(*dX);
-    auto place = context.GetEigenDevice<Place>();
+    // allocate memory on device.
+    dX->mutable_data<T>(context.GetPlace());
 
-    auto dot = (Y_eigen * dY_eigen)
-                   .sum(along_class)
-                   .eval()
-                   .reshape(batch_by_one)
-                   .broadcast(one_by_class);
-    dX_eigen.device(place) = (dY_eigen - dot) * Y_eigen;
+    math::SoftmaxGradFunctor<Place, T>()(context.device_context(), Y, dY, dX);
   }
 };
 
diff --git a/python/paddle/v2/framework/tests/test_sequence_softmax_op.py b/python/paddle/v2/framework/tests/test_sequence_softmax_op.py
new file mode 100644
index 0000000000..b54a56aa6d
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_sequence_softmax_op.py
@@ -0,0 +1,38 @@
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+def stable_softmax(x):
+    """Compute the softmax of vector x in a numerically stable way."""
+    shiftx = x - np.max(x).clip(-64.)
+    exps = np.exp(shiftx)
+    return exps / np.sum(exps)
+
+
+class TestSequenceSoftmaxOp(OpTest):
+    def setUp(self):
+        self.op_type = "sequence_softmax"
+        x = np.random.uniform(0.1, 1, (11, 1)).astype("float32")
+        lod = [[0, 4, 5, 8, 11]]
+
+        out = np.zeros((11, 1)).astype("float32")
+        for i in range(4):
+            sub_x = x[lod[0][i]:lod[0][i + 1], :]
+            sub_x = sub_x.reshape(1, lod[0][i + 1] - lod[0][i])
+            sub_out = stable_softmax(sub_x)
+            out[lod[0][i]:lod[0][i + 1], :] = sub_out.reshape(
+                lod[0][i + 1] - lod[0][i], 1)
+
+        self.inputs = {"X": (x, lod)}
+        self.outputs = {"Out": out}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(["X"], "Out", max_relative_error=0.01)
+
+
+if __name__ == "__main__":
+    unittest.main()