diff --git a/paddle/fluid/operators/cross_entropy_op.cc b/paddle/fluid/operators/cross_entropy_op.cc
index 0e0622e290..2b2a9dc831 100644
--- a/paddle/fluid/operators/cross_entropy_op.cc
+++ b/paddle/fluid/operators/cross_entropy_op.cc
@@ -164,11 +164,13 @@ or not. But the output only shares the LoD information with input X.
 }  // namespace paddle
 
 namespace ops = paddle::operators;
+using CPUCtx = paddle::platform::CPUDeviceContext;
+
 REGISTER_OPERATOR(cross_entropy, ops::CrossEntropyOp, ops::CrossEntropyOpMaker,
                   paddle::framework::DefaultGradOpDescMaker<true>);
 REGISTER_OPERATOR(cross_entropy_grad, ops::CrossEntropyGradientOp);
-REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel<float>,
-                       ops::CrossEntropyOpKernel<double>);
+REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel<CPUCtx, float>,
+                       ops::CrossEntropyOpKernel<CPUCtx, double>);
 REGISTER_OP_CPU_KERNEL(cross_entropy_grad,
-                       ops::CrossEntropyGradientOpKernel<float>,
-                       ops::CrossEntropyGradientOpKernel<double>);
+                       ops::CrossEntropyGradientOpKernel<CPUCtx, float>,
+                       ops::CrossEntropyGradientOpKernel<CPUCtx, double>);
diff --git a/paddle/fluid/operators/cross_entropy_op.cu b/paddle/fluid/operators/cross_entropy_op.cu
index 6449149d4b..30dbd5bd3d 100644
--- a/paddle/fluid/operators/cross_entropy_op.cu
+++ b/paddle/fluid/operators/cross_entropy_op.cu
@@ -14,98 +14,11 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/cross_entropy_op.h"
 
-namespace paddle {
-namespace operators {
-
-namespace {
-
-template <typename T>
-__global__ void CrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
-                                           const int64_t* label, const int N,
-                                           const int D) {
-  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
-       i += blockDim.x * gridDim.x) {
-    int idx = i * D + label[i];
-    dX[idx] = -dY[i] / X[idx];
-  }
-}
-
-template <typename T>
-__global__ void SoftCrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
-                                               const T* label, const int N,
-                                               const int D) {
-  int ids = blockIdx.x * blockDim.x + threadIdx.x;
-  if (ids < N * D) {
-    int row_ids = ids / D;
-    dX[ids] = -label[ids] * dY[row_ids] / X[ids];
-  }
-}
-}  // namespace
-
-template <typename T>
-class CrossEntropyOpCUDAKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
-                   "This kernel only runs on GPU device.");
-    const Tensor* x = ctx.Input<Tensor>("X");
-    const Tensor* label = ctx.Input<Tensor>("Label");
-    Tensor* y = ctx.Output<Tensor>("Y");
-    y->mutable_data<T>(ctx.GetPlace());
-
-    math::CrossEntropyFunctor<platform::CUDADeviceContext, T>()(
-        ctx.template device_context<platform::CUDADeviceContext>(), y, x, label,
-        ctx.Attr<bool>("soft_label"));
-  }
-};
-
-template <typename T>
-class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
-                   "This kernel only runs on GPU device.");
-
-    const Tensor* x = ctx.Input<Tensor>("X");
-    const Tensor* label = ctx.Input<Tensor>("Label");
-    Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
-    dx->mutable_data<T>(ctx.GetPlace());
-
-    const T* dy_data =
-        ctx.Input<Tensor>(framework::GradVarName("Y"))->data<T>();
-    T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
-    const T* x_data = x->data<T>();
-
-    int64_t batch_size = x->dims()[0];
-    int64_t class_num = x->dims()[1];
-
-    int block = 512;
-    int grid = (batch_size * class_num + block - 1) / block;
-
-    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
-    auto stream = dev_ctx.stream();
-
-    if (ctx.Attr<bool>("soft_label")) {
-      auto* label_data = label->data<T>();
-      SoftCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
-          dx_data, dy_data, x_data, label_data, batch_size, class_num);
-    } else {
-      math::SetConstant<platform::CUDADeviceContext, T> functor;
-      functor(dev_ctx, dx, 0);
-      auto* label_data = label->data<int64_t>();
-      grid = (batch_size + block - 1) / block;
-      CrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
-          dx_data, dy_data, x_data, label_data, batch_size, class_num);
-    }
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
 namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(cross_entropy, ops::CrossEntropyOpCUDAKernel<float>,
-                        ops::CrossEntropyOpCUDAKernel<double>);
+using CUDACtx = paddle::platform::CUDADeviceContext;
+REGISTER_OP_CUDA_KERNEL(cross_entropy,
+                        ops::CrossEntropyOpKernel<CUDACtx, float>,
+                        ops::CrossEntropyOpKernel<CUDACtx, double>);
 REGISTER_OP_CUDA_KERNEL(cross_entropy_grad,
-                        ops::CrossEntropyGradientOpCUDAKernel<float>,
-                        ops::CrossEntropyGradientOpCUDAKernel<double>);
+                        ops::CrossEntropyGradientOpKernel<CUDACtx, float>,
+                        ops::CrossEntropyGradientOpKernel<CUDACtx, double>);
diff --git a/paddle/fluid/operators/cross_entropy_op.h b/paddle/fluid/operators/cross_entropy_op.h
index 6da3a24dc8..19a2aec92b 100644
--- a/paddle/fluid/operators/cross_entropy_op.h
+++ b/paddle/fluid/operators/cross_entropy_op.h
@@ -17,69 +17,106 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/math/cross_entropy.h"
 #include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/platform/for_range.h"
 
 namespace paddle {
 namespace operators {
 
 using Tensor = framework::Tensor;
-template <typename T, int MajorType = Eigen::RowMajor,
-          typename IndexType = Eigen::DenseIndex>
-using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
 
-template <typename T>
+template <typename DeviceContext, typename T>
 class CrossEntropyOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
-                   "This kernel only runs on CPU.");
-    const Tensor* x = ctx.Input<Tensor>("X");
-    const Tensor* labels = ctx.Input<Tensor>("Label");
-    Tensor* y = ctx.Output<Tensor>("Y");
+    auto* x = ctx.Input<Tensor>("X");
+    auto* labels = ctx.Input<Tensor>("Label");
+    auto* y = ctx.Output<Tensor>("Y");
     y->mutable_data<T>(ctx.GetPlace());
 
-    math::CrossEntropyFunctor<platform::CPUDeviceContext, T>()(
-        ctx.template device_context<platform::CPUDeviceContext>(), y, x, labels,
+    math::CrossEntropyFunctor<DeviceContext, T>()(
+        ctx.template device_context<DeviceContext>(), y, x, labels,
         ctx.Attr<bool>("soft_label"));
   }
 };
 
 template <typename T>
+class XeSoftlabelGradFunctor {
+ public:
+  XeSoftlabelGradFunctor(T* dx,
+                         const T* dy,     // NOLINT
+                         const T* x,      // NOLINT
+                         const T* label,  // NOLINT
+                         size_t num_classes)
+      : dx_(dx), dy_(dy), x_(x), label_(label), num_classes_(num_classes) {}
+
+  HOSTDEVICE void operator()(size_t i) {
+    auto row_ids = i / num_classes_;
+    dx_[i] = -label_[i] * dy_[row_ids] / x_[i];
+  }
+
+ private:
+  T* dx_;
+  const T* dy_;
+  const T* x_;
+  const T* label_;
+  size_t num_classes_;
+};
+
+template <typename T>
+class XeGradFunctor {
+ public:
+  XeGradFunctor(T* dx,
+                const T* dy,           // NOLINT
+                const T* x,            // NOLINT
+                const int64_t* label,  // NOLINT
+                size_t num_classes)
+      : dx_(dx), dy_(dy), x_(x), label_(label), num_classes_(num_classes) {}
+
+  HOSTDEVICE void operator()(size_t sample_id) {
+    auto x_is_true_offset = sample_id * num_classes_ + label_[sample_id];
+    for (size_t x_offset = sample_id * num_classes_;
+         x_offset < (sample_id + 1) * num_classes_; ++x_offset) {
+      dx_[x_offset] = x_offset != x_is_true_offset
+                          ? static_cast<T>(0)
+                          : -dy_[sample_id] / x_[x_offset];
+    }
+  }
+
+ private:
+  T* dx_;
+  const T* dy_;
+  const T* x_;
+  const int64_t* label_;
+  size_t num_classes_;
+};
+
+template <typename DeviceContext, typename T>
 class CrossEntropyGradientOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
-                   "This kernel only runs on CPU.");
-    const Tensor* x = ctx.Input<Tensor>("X");
-    const Tensor* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
-    const Tensor* label = ctx.Input<Tensor>("Label");
-    Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
-    T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
+    auto* x = ctx.Input<Tensor>("X");
+    auto* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
+    auto* label = ctx.Input<Tensor>("Label");
+    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
 
     int64_t class_num = x->dims()[1];
     if (ctx.Attr<bool>("soft_label")) {
-      auto x_mat = EigenMatrix<T>::From(*x);
-      auto dy_mat = EigenMatrix<T>::From(*dy);
-      auto lbl_mat = EigenMatrix<T>::From(*label);
-      auto dx_mat = EigenMatrix<T>::From(*dx);
-
-      dx_mat.device(*ctx.template device_context<platform::CPUDeviceContext>()
-                         .eigen_device()) =
-          -(lbl_mat *
-            dy_mat.broadcast(Eigen::DSizes<int64_t, 2>(1, class_num)) / x_mat);
+      XeSoftlabelGradFunctor<T> functor(dx_data, dy->data<T>(), x->data<T>(),
+                                        label->data<T>(),
+                                        static_cast<size_t>(class_num));
+      platform::ForRange<DeviceContext> for_range(
+          ctx.template device_context<DeviceContext>(),
+          static_cast<size_t>(dx->numel()));
+      for_range(functor);
     } else {
-      int64_t batch_size = x->dims()[0];
-      const T* dy_data = dy->data<T>();
-      const T* x_data = x->data<T>();
-      const int64_t* label_data = label->data<int64_t>();
-
-      math::SetConstant<platform::CPUDeviceContext, T> functor;
-      functor(ctx.template device_context<platform::CPUDeviceContext>(), dx, 0);
-
-      for (int64_t i = 0; i < batch_size; ++i) {
-        PADDLE_ASSERT(label_data[i] >= 0 || label_data[i] < class_num);
-        int64_t index = i * class_num + label_data[i];
-        dx_data[index] = math::TolerableValue<T>()(-dy_data[i] / x_data[index]);
-      }
+      XeGradFunctor<T> functor(dx_data, dy->data<T>(), x->data<T>(),
+                               label->data<int64_t>(),
+                               static_cast<size_t>(class_num));
+      platform::ForRange<DeviceContext> for_range(
+          ctx.template device_context<DeviceContext>(),
+          static_cast<size_t>(dy->numel()));
+      for_range(functor);
     }
   }
 };