diff --git a/paddle/operators/dropout_op.cc b/paddle/operators/dropout_op.cc
index a9950a48e0..60ad2efbe9 100644
--- a/paddle/operators/dropout_op.cc
+++ b/paddle/operators/dropout_op.cc
@@ -37,6 +37,8 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
   DropoutOpMaker(framework::OpProto *proto,
                  framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddAttr<float>("dropout_prob", "Dropout probability.").SetDefault(.5f);
+    AddAttr<int>("seed", "Dropout random seed.").SetDefault(0);
     AddInput("X", "The input of dropout op.");
     AddOutput("Out", "The output of dropout op.");
     AddOutput("Mask", "The dropout mask.").AsIntermediate();
@@ -75,7 +77,7 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
 namespace ops = paddle::operators;
 REGISTER_OP(dropout, ops::DropoutOp, ops::DropoutOpMaker, dropout_grad,
             ops::DropoutOpGrad);
-REGISTER_OP_CPU_KERNEL(dropout,
-                       ops::DropoutKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(
+    dropout, ops::CPUDropoutKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(
     dropout_grad, ops::DropoutGradKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/dropout_op.cu b/paddle/operators/dropout_op.cu
index 9e9efaa3b1..c869ddf3e5 100644
--- a/paddle/operators/dropout_op.cu
+++ b/paddle/operators/dropout_op.cu
@@ -16,7 +16,7 @@
 #include "paddle/operators/dropout_op.h"
 
 namespace ops = paddle::operators;
-REGISTER_OP_GPU_KERNEL(dropout,
-                       ops::DropoutKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(
+    dropout, ops::GPUDropoutKernel<paddle::platform::GPUPlace, float>);
 REGISTER_OP_GPU_KERNEL(
     dropout_grad, ops::DropoutGradKernel<paddle::platform::GPUPlace, float>);
diff --git a/paddle/operators/dropout_op.h b/paddle/operators/dropout_op.h
index d5d32df74b..becf89aca3 100644
--- a/paddle/operators/dropout_op.h
+++ b/paddle/operators/dropout_op.h
@@ -13,6 +13,11 @@
    limitations under the License. */
 
 #pragma once
+#include <thrust/device_ptr.h>
+#include <thrust/iterator/counting_iterator.h>
+#include <thrust/random.h>
+#include <thrust/transform.h>
+#include <random>
 #include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
 
@@ -25,25 +30,85 @@ template <typename T, int MajorType = Eigen::RowMajor,
 using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
 
 template <typename Place, typename T>
-class DropoutKernel : public framework::OpKernel {
+class CPUDropoutKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* x = context.Input<Tensor>("X");
+    auto* y = context.Output<Tensor>("Out");
+    auto* mask = context.Output<Tensor>("Mask");
+    T* mask_data = mask->mutable_data<T>(context.GetPlace());
+    T* y_data = y->mutable_data<T>(context.GetPlace());
+    const T* x_data = x->data<T>();
+
+    float dropout_prob = context.op_.GetAttr<float>("dropout_prob");
+    int seed = context.op_.GetAttr<int>("seed");
+
+    std::minstd_rand engine;
+    engine.seed(seed);
+    std::uniform_real_distribution<T> dist(0, 1);
+    size_t size = framework::product(mask->dims());
+    for (size_t i = 0; i < size; ++i) {
+      if (dist(engine) < dropout_prob) {
+        mask_data[i] = 0;
+        y_data[i] = 0;
+      } else {
+        mask_data[i] = 1;
+        y_data[i] = (1 - dropout_prob) * x_data[i];
+      }
+    }
+  }
+};
+
+template <typename T>
+struct MaskGenerator {
+  float dropout_prob_;
+  int seed_;
+
+  __host__ __device__ MaskGenerator(float dropout_prob, int seed)
+      : dropout_prob_(dropout_prob), seed_(seed) {}
+
+  __host__ __device__ T operator()(const unsigned int n) const {
+    thrust::minstd_rand rng;
+    rng.seed(seed_);
+    thrust::uniform_real_distribution<T> dist(0, 1);
+    rng.discard(n);
+    if (dist(rng) < dropout_prob_) {
+      return static_cast<T>(0);
+    } else {
+      return static_cast<T>(1);
+    }
+  }
+};
+
+// It seems that Eigen::Tensor::setRandom in GPU will SEGFAULT.
+// Use std::random and thrust::random(thrust is a std library in CUDA) to
+// implement uniform random.
+template <typename Place, typename T>
+class GPUDropoutKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto* x = context.Input<Tensor>("X");
     auto* y = context.Output<Tensor>("Out");
     auto* mask = context.Output<Tensor>("Mask");
-    mask->mutable_data<T>(context.GetPlace());
     y->mutable_data<T>(context.GetPlace());
 
+    float dropout_prob = context.op_.GetAttr<float>("dropout_prob");
+    int seed = context.op_.GetAttr<int>("seed");
+    thrust::counting_iterator<unsigned int> index_sequence_begin(0);
+    int size = framework::product(mask->dims());
+    T* mask_data = mask->mutable_data<T>(context.GetPlace());
+    thrust::transform(index_sequence_begin, index_sequence_begin + size,
+                      thrust::device_ptr<T>(mask_data),
+                      MaskGenerator<T>(dropout_prob, seed));
+
     auto dims = x->dims();
-    auto X = EigenMatrix<T>::From(*x);
-    auto Y = EigenMatrix<T>::From(*y);
-    auto M = EigenMatrix<T>::From(*mask);
+    auto new_dims = framework::make_ddim({dims[0], size / dims[0]});
+    auto X = EigenMatrix<T>::From(*x, new_dims);
+    auto Y = EigenMatrix<T>::From(*y, new_dims);
+    auto M = EigenMatrix<T>::From(*mask, new_dims);
 
     auto place = context.GetEigenDevice<Place>();
-    M.device(place).setRandom<UniformRandomGenerator>();
-    float dropout_prob = context.op_.GetAttr<float>("dropout_prob");
-    M.device(place) = (M > dropout_prob).cast<float>();
-    Y.device(place) = X * Y;
+    Y.device(place) = X * M * (1 - dropout_prob);
   }
 };
 
@@ -57,12 +122,15 @@ class DropoutGradKernel : public framework::OpKernel {
     grad_x->mutable_data<T>(context.GetPlace());
 
     auto dims = grad_x->dims();
-    auto M = EigenMatrix<T>::From(*mask);
-    auto dX = EigenMatrix<T>::From(*grad_x);
-    auto dY = EigenMatrix<T>::From(*grad_y);
+    int size = static_cast<int>(framework::product(dims));
+    auto new_dims = framework::make_ddim({dims[0], size / dims[0]});
+    auto M = EigenMatrix<T>::From(*mask, new_dims);
+    auto dX = EigenMatrix<T>::From(*grad_x, new_dims);
+    auto dY = EigenMatrix<T>::From(*grad_y, new_dims);
 
     auto place = context.GetEigenDevice<Place>();
-    dX.device(place) = dY * M;
+    float dropout_prob = context.op_.GetAttr<float>("dropout_prob");
+    dX.device(place) = dY * M * (1 - dropout_prob);
   }
 };
 
diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt
index 661ebd8964..850910363d 100644
--- a/python/paddle/v2/framework/tests/CMakeLists.txt
+++ b/python/paddle/v2/framework/tests/CMakeLists.txt
@@ -4,6 +4,7 @@ py_test(test_scope SRCS test_scope.py)
 
 py_test(test_tensor SRCS test_tensor.py)
 py_test(test_mul_op SRCS test_mul_op.py)
+py_test(test_dropout_op SRCS test_dropout_op.py)
 
 py_test(test_mean_op SRCS test_mean_op.py)
 
diff --git a/python/paddle/v2/framework/tests/op_test_util.py b/python/paddle/v2/framework/tests/op_test_util.py
index 3bc05a0fec..a4899355b5 100644
--- a/python/paddle/v2/framework/tests/op_test_util.py
+++ b/python/paddle/v2/framework/tests/op_test_util.py
@@ -6,13 +6,13 @@ from paddle.v2.framework.op import Operator
 class OpTestMeta(type):
     """
     Operator Test ClassMeta.
-    
-    It injects `test_all` method into user's OperatorTest class, to make Python 
+
+    It injects `test_all` method into user's OperatorTest class, to make Python
     unittest module run that method.
-    
+
     The `test_all` read what value is stored in `self`. It use self's values to
     create and run a operator, and check whether that op is OK or not.
-    
+
     See `test_add_two_op` for example usage.
     """
 
diff --git a/python/paddle/v2/framework/tests/test_dropout_op.py b/python/paddle/v2/framework/tests/test_dropout_op.py
new file mode 100644
index 0000000000..3f4738f614
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_dropout_op.py
@@ -0,0 +1,42 @@
+import unittest
+import numpy as np
+from gradient_checker import GradientChecker, create_op
+from op_test_util import OpTestMeta
+
+
+class TestDropoutOpProbZero(unittest.TestCase):
+    __metaclass__ = OpTestMeta
+
+    def setUp(self):
+        self.type = "dropout"
+        self.inputs = {'X': np.random.random((32, 64)).astype("float32")}
+        self.attrs = {'dropout_prob': 0.0}
+        self.outputs = {'Out': self.inputs['X'], 'Mask': np.ones((32, 64))}
+
+
+class TestDropoutOpAllProbOne(unittest.TestCase):
+    __metaclass__ = OpTestMeta
+
+    def setUp(self):
+        self.type = "dropout"
+        self.inputs = {'X': np.random.random((32, 64)).astype("float32")}
+        self.attrs = {'dropout_prob': 1.0}
+        self.outputs = {'Out': np.zeros((32, 64)), 'Mask': np.zeros((32, 64))}
+
+
+class DropoutGradOpTest(GradientChecker):
+    def test_dropout_2d(self):
+        op = create_op("dropout")
+        inputs = {'X': np.random.random((10, 5)).astype("float32")}
+        self.compare_grad(op, inputs)
+        self.check_grad(op, inputs, set(["X"]), "Out")
+
+    def test_dropout_3d(self):
+        op = create_op("dropout")
+        inputs = {'X': np.random.random((10, 5, 4)).astype("float32")}
+        self.compare_grad(op, inputs)
+        self.check_grad(op, inputs, set(["X"]), "Out")
+
+
+if __name__ == '__main__':
+    unittest.main()