diff --git a/doc/howto/dev/new_op_cn.md b/doc/howto/dev/new_op_cn.md
index 3e71a0a592..e3bee32f8e 100644
--- a/doc/howto/dev/new_op_cn.md
+++ b/doc/howto/dev/new_op_cn.md
@@ -169,6 +169,8 @@ class MulKernel : public framework::OpKernel {
 `MulKernel` must override the `Compute` interface. Its parameter is `const framework::ExecutionContext& context`; compared with `InferShapeContext`, `ExecutionContext` additionally carries the device type, and the inputs, outputs and attributes can likewise be obtained from it. The concrete computation is written inside `Compute`.
    
 Note that different devices (CPU, GPU) share one Op definition; whether they share the same `OpKernel` depends on whether the functions called inside `Compute` support different devices. The CPU and GPU implementations of `MulOp` share the same `Kernel`; for an example where the `OpKernel` is not shared, see [`OnehotCrossEntropyOpKernel`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/cross_entropy_op.h#L43).
+
+To keep the computation in an `OpKernel` simple to write and to let the CPU and GPU share the same code, we usually implement it with the Eigen unsupported Tensor module. For how to use the Eigen library in Paddle, please refer to the corresponding [documentation](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/howto/dev/use_eigen_cn.md).
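+
+As a minimal sketch (a hypothetical `ReluKernel` with assumed input/output names, not taken from this document), an element-wise kernel written with Eigen can share one `Compute` body across CPU and GPU:
+
+```c++
+template <typename Place, typename T>
+class ReluKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* in = context.Input<framework::Tensor>("X");
+    auto* out = context.Output<framework::Tensor>("Out");
+    out->mutable_data<T>(context.GetPlace());
+    // Flatten both tensors into 1-D Eigen vectors and evaluate the
+    // element-wise expression on whichever device the kernel runs on.
+    auto x = framework::EigenVector<T>::Flatten(*in);
+    auto y = framework::EigenVector<T>::Flatten(*out);
+    auto place = context.GetEigenDevice<Place>();
+    y.device(place) = x.cwiseMax(static_cast<T>(0));  // relu: max(x, 0)
+  }
+};
+```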
    
 At this point the forward Op implementation is complete, and the op and kernel must be registered in the `.cc` file. The backward Op class and its Kernel are defined in the same way as the forward Op and are not repeated here. Note, however, that the backward Op has no `ProtoMaker`.
    
@@ -188,9 +190,12 @@ REGISTER_OP_CPU_KERNEL(mul_grad,
  - `REGISTER_OP_WITHOUT_GRADIENT`: registers an Op that has no backward Op.
  - `REGISTER_OP_CPU_KERNEL`: registers the `ops::MulKernel` class, specializing its template parameters to `paddle::platform::CPUPlace` and the `float` type; in the same way, `ops::MulGradKernel` is registered for the backward Op.
 
-在 `.cu`文件中注册GPU Kernel。
+Register the GPU Kernel in the `.cu` file. Note that if the GPU Kernel is implemented with the Eigen unsupported module, the macro definition `#define EIGEN_USE_GPU` must be placed at the very beginning of the `.cu` file.
    
 ```c++
+// If the Eigen unsupported module is used, define EIGEN_USE_GPU before including any headers.
+#define EIGEN_USE_GPU
+
 namespace ops = paddle::operators;
 REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel<paddle::platform::GPUPlace, float>);
 REGISTER_OP_GPU_KERNEL(mul_grad,
@@ -286,28 +291,50 @@ class TestMulOp(unittest.TestCase):
 
 The backward Op unit test inherits from `GradientChecker`, which in turn inherits from `unittest.TestCase`, so the backward test methods must start with `test_`.
 
- ```
- class MulGradOpTest(GradientChecker):
-    def test_mul(self):
-        op = create_op("mul")
-        inputs = {
+```
+class TestMulGradOp(GradientChecker):
+    def setUp(self):
+        self.op = create_op("mul")
+        self.inputs = {
             'X': np.random.random((32, 84)).astype("float32"),
             'Y': np.random.random((84, 100)).astype("float32")
         }
-        self.compare_grad(op, inputs)      
+
+    def test_cpu_gpu_compare(self):
+        self.compare_grad(self.op, self.inputs)
+
+    def test_normal(self):
         # mul op will enlarge the relative error
         self.check_grad(
-            op, inputs, set(["X", "Y"]), "Out", max_relative_error=0.5)
- ```
+            self.op, self.inputs, ["X", "Y"], "Out", max_relative_error=0.5)
+
+    def test_ignore_x(self):
+        self.check_grad(
+            self.op,
+            self.inputs, ["Y"],
+            "Out",
+            max_relative_error=0.5,
+            no_grad_set={"X"})
+
+    def test_ignore_y(self):
+        self.check_grad(
+            self.op,
+            self.inputs, ["X"],
+            "Out",
+            max_relative_error=0.5,
+            no_grad_set={"Y"})
+```
+
+Some key points are explained below:
 
   - Call `create_op("mul")` to create the forward Op corresponding to the backward Op.
-   - 定义输入`inputs`。
   - Call `compare_grad` to compare the CPU and GPU results.
-   - 调用`check_grad`检查梯度稳定性,这里采用数值法检测梯度正确性。
-      - 第一个参数`op` : 前向op。
-      - 第二个参数`inputs` : 输入词典,词典的Key和`ProtoMaker`定义保持一致。
-      - 第三个参数`set(["X", "Y"])` : 指定对输入变量`X`、`Y`做梯度检测。
+   - `test_normal` calls `check_grad` to check gradient stability; numerical differentiation is used here to verify the correctness of the gradients.
+      - The first argument `self.op`: the forward Op.
+      - The second argument `self.inputs`: the input dict; its keys must match the names defined in the `ProtoMaker`.
+      - The third argument `["X", "Y"]`: gradients are checked for the input variables `X` and `Y`.
       - The fourth argument `"Out"`: the final output target variable `Out` of the forward network.
+   - `test_ignore_x` and `test_ignore_y` respectively test the cases where only one of the input gradients needs to be computed.
 
 
 ### Compile and Run
diff --git a/doc/howto/dev/use_eigen_cn.md b/doc/howto/dev/use_eigen_cn.md
new file mode 100644
index 0000000000..1367323b71
--- /dev/null
+++ b/doc/howto/dev/use_eigen_cn.md
@@ -0,0 +1,146 @@
+## How to use Eigen in Paddle
+
+A neural network is essentially a compute graph. The data required by the computation is stored in `Tensor`s, while the computation itself is described by `Operator`s. At execution time, an `Operator` calls the `Compute` interface of its corresponding `OpKernel` to perform the operation on the `Tensor`s.
+
+
+### The Eigen Tensor module
+
+The Eigen Tensor module provides powerful support for element-wise computation: one copy of the code can run on both CPU and GPU. However, Eigen Tensor is still under active development, so its test coverage may be incomplete and its documentation is sparse.
+
+For a detailed introduction to the Eigen Tensor module, please refer to [document 1](https://github.com/RLovelett/eigen/blob/master/unsupported/Eigen/CXX11/src/Tensor/README.md) and [document 2](https://bitbucket.org/eigen/eigen/src/default/unsupported/Eigen/CXX11/src/Tensor/README.md).
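+
+As a standalone sketch (plain Eigen, independent of Paddle), element-wise computation with the Eigen Tensor module looks like this:
+
+```cpp
+#include <unsupported/Eigen/CXX11/Tensor>
+
+int main() {
+  // Two rank-2 tensors of shape 2 x 3.
+  Eigen::Tensor<float, 2> a(2, 3);
+  Eigen::Tensor<float, 2> b(2, 3);
+  a.setConstant(1.0f);
+  b.setConstant(2.0f);
+  // Element-wise expressions; the same expression can also be evaluated on a
+  // GPU device through the .device(...) interface shown later in this doc.
+  Eigen::Tensor<float, 2> c = a + b;           // every element becomes 3
+  Eigen::Tensor<float, 2> d = (a * b).sqrt();  // element-wise sqrt(a * b)
+  return 0;
+}
+```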
+
+
+### paddle::framework::Tensor
+
+The Paddle Tensor is defined under the framework directory; its main interface is as follows:
+
+```cpp
+class Tensor {
+ public:
+  /*! Return a pointer to mutable memory block. */
+  template <typename T>
+  inline T* data();
+  
+  /**
+   * @brief   Return a pointer to mutable memory block.
+   * @note    If not exist, then allocation.
+   */
+  template <typename T>
+  inline T* mutable_data(platform::Place place);
+  
+  /**
+   * @brief     Return a pointer to mutable memory block.
+   *
+   * @param[in] dims    The dimensions of the memory block.
+   * @param[in] place   The place of the memory block.
+   *
+   * @note      If not exist, then allocation.
+   */
+  template <typename T>
+  inline T* mutable_data(DDim dims, platform::Place place);
+  
+  /*! Resize the dimensions of the memory block. */
+  inline Tensor& Resize(const DDim& dims);
+  
+  /*! Return the dimensions of the memory block. */
+  inline const DDim& dims() const;
+
+ private:  
+  /*! holds the memory block if allocated. */
+  std::shared_ptr<Placeholder> holder_;
+  
+  /*! points to dimensions of memory block. */
+  DDim dim_;
+};
+```
+
+`Placeholder` defers memory allocation: we can first define a Tensor, then set its size with the `Resize` interface, and only allocate the actual memory later by calling `mutable_data`.
+
+```cpp
+paddle::framework::Tensor t;
+paddle::platform::CPUPlace place;
+// set size first
+t.Resize({2, 3});
+// allocate memory on CPU later
+t.mutable_data(place);
+```
+
+### A usage example of paddle::framework::Tensor
+The following uses AddOp as an example to show how a Tensor is used:
+
+- InferShape
+
+When running the compute graph of a neural network, the `InferShape` interface of each `Operator` is called first to set the size of the output Tensors according to the sizes of the input Tensors; this is where `Resize` is called.
+
+```cpp
+void InferShape(const framework::InferShapeContext &ctx) const override {
+  PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("X")->dims(),
+                    ctx.Input<Tensor>("Y")->dims(),
+                    "Two input of Add Op's dimension must be same.");
+  ctx.Output<Tensor>("Out")->Resize(ctx.Input<Tensor>("X")->dims());
+}
+```
+
+
+- Run
+
+The `Run` interface of an `Operator` eventually calls the `Compute` interface of the corresponding `OpKernel`. This is when memory is actually allocated, via the `mutable_data` interface.
+
+```cpp
+void Compute(const framework::ExecutionContext& context) const override {
+  auto* input0 = context.Input<Tensor>("X");
+  auto* input1 = context.Input<Tensor>("Y");
+  auto* output = context.Output<Tensor>("Out");
+
+  output->mutable_data<T>(context.GetPlace());
+
+  auto x = EigenVector<T>::Flatten(*input0);
+  auto y = EigenVector<T>::Flatten(*input1);
+  auto z = EigenVector<T>::Flatten(*output);
+
+  auto place = context.GetEigenDevice<Place>();
+
+  z.device(place) = x + y;
+}
+```
+
+
+### Converting paddle::framework::Tensor to EigenTensor
+
+As shown in the previous section, the input and output Tensors must first be converted into a format that Eigen supports before the actual computation. [eigen.h](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/eigen.h) provides a set of global functions that convert a paddle::framework::Tensor into an EigenTensor/EigenMatrix/EigenVector/EigenScalar.
+
+Taking EigenTensor as an example:
+
+```cpp
+Tensor t;
+float* p = t.mutable_data<float>(make_ddim({1, 2, 3}), platform::CPUPlace());
+for (int i = 0; i < 1 * 2 * 3; i++) {
+  p[i] = static_cast<float>(i);
+}
+
+EigenTensor<float, 3>::Type et = EigenTensor<float, 3>::From(t);
+```
+
+`From` is an interface provided by the EigenTensor template that converts a paddle::framework::Tensor into an EigenTensor. Since the rank of the Tensor is a template parameter, it must be specified explicitly during the conversion.
+
+In Eigen, Tensors of different ranks are different types, and a Vector is a Tensor of rank 1. Note that `EigenVector<T>::From` converts a one-dimensional paddle Tensor into a one-dimensional Eigen Tensor (represented here as an EigenVector), whereas `EigenVector<T>::Flatten` reshapes a paddle Tensor of any rank, flattening it into a one-dimensional Eigen Tensor; the resulting type is still EigenVector.
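+
+A short sketch contrasting the two (assuming the same namespace context and CPU place as the snippet above):
+
+```cpp
+// A 1-D paddle Tensor converted with From: the ranks must match.
+Tensor t1;
+t1.mutable_data<float>(make_ddim({6}), platform::CPUPlace());
+auto v1 = EigenVector<float>::From(t1);     // rank-1 Eigen tensor of size 6
+
+// A 2-D paddle Tensor flattened with Flatten: any rank is accepted.
+Tensor t2;
+t2.mutable_data<float>(make_ddim({2, 3}), platform::CPUPlace());
+auto v2 = EigenVector<float>::Flatten(t2);  // rank-1 Eigen tensor of size 6
+```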
+
+For more conversion methods, please refer to the [unit tests](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/eigen_test.cc) in eigen_test.cc.
+
+
+
+### Implementing the computation
+
+To carry out the computation, the EigenTensor on the left-hand side of the assignment must call the `device` interface. Note that operations between EigenTensors only change the data held by the underlying Tensors; they do not change the shape information of the original Tensors.
+
+```cpp
+auto x = EigenVector<T>::Flatten(*input0);
+auto y = EigenVector<T>::Flatten(*input1);
+auto z = EigenVector<T>::Flatten(*output);
+auto place = context.GetEigenDevice<Place>();
+z.device(place) = x + y;
+```
+
+In this snippet, input0/input1/output can be Tensors of arbitrary rank. The Flatten interface of EigenVector turns a Tensor of any rank into a one-dimensional EigenVector. After the computation finishes, the original shape information of input0/input1/output is unchanged. To change the shape of the original Tensor, call the `Resize` interface.
+
+Since documentation on the Eigen Tensor module is scarce, the computation code of the related `OpKernel`s under TensorFlow's [kernels](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core/kernels) module can also serve as a reference.
diff --git a/paddle/framework/attribute.cc b/paddle/framework/attribute.cc
index 9eb07acdff..27132eaa0b 100644
--- a/paddle/framework/attribute.cc
+++ b/paddle/framework/attribute.cc
@@ -43,6 +43,10 @@ template <>
 AttrType AttrTypeID<std::vector<std::string>>() {
   return STRINGS;
 }
+template <>
+AttrType AttrTypeID<std::vector<std::pair<int, int>>>() {
+  return INT_PAIRS;
+}
 
 Attribute GetAttrValue(const OpDesc::Attr& attr_desc) {
   switch (attr_desc.type()) {
@@ -76,6 +80,14 @@ Attribute GetAttrValue(const OpDesc::Attr& attr_desc) {
       }
       return val;
     }
+    case paddle::framework::AttrType::INT_PAIRS: {
+      std::vector<std::pair<int, int>> val(attr_desc.int_pairs_size());
+      for (int i = 0; i < attr_desc.int_pairs_size(); ++i) {
+        val[i].first = attr_desc.int_pairs(i).first();
+        val[i].second = attr_desc.int_pairs(i).second();
+      }
+      return val;
+    }
   }
   PADDLE_ENFORCE(false, "Unknown OpDesc::AttrDesc::type !");
   return boost::blank();
diff --git a/paddle/framework/attribute.h b/paddle/framework/attribute.h
index 08b47cabd4..071879a9d4 100644
--- a/paddle/framework/attribute.h
+++ b/paddle/framework/attribute.h
@@ -28,7 +28,8 @@ namespace paddle {
 namespace framework {
 
 typedef boost::variant<boost::blank, int, float, std::string, std::vector<int>,
-                       std::vector<float>, std::vector<std::string>>
+                       std::vector<float>, std::vector<std::string>,
+                       std::vector<std::pair<int, int>>>
     Attribute;
 
 typedef std::unordered_map<std::string, Attribute> AttributeMap;
diff --git a/paddle/framework/ddim.cc b/paddle/framework/ddim.cc
index cfd3e8dfde..85b7de7974 100644
--- a/paddle/framework/ddim.cc
+++ b/paddle/framework/ddim.cc
@@ -21,16 +21,16 @@ namespace framework {
 /// @cond HIDDEN
 
 template <int i>
-Dim<i> make_dim(const int* d) {
+Dim<i> make_dim(const int64_t* d) {
   return Dim<i>(*d, make_dim<i - 1>(d + 1));
 }
 
 template <>
-Dim<1> make_dim<1>(const int* d) {
+Dim<1> make_dim<1>(const int64_t* d) {
   return Dim<1>(*d);
 }
 
-void make_ddim(DDim& ddim, const int* dims, int n) {
+void make_ddim(DDim& ddim, const int64_t* dims, int n) {
   switch (n) {
     case 1:
       ddim = make_dim<1>(dims);
@@ -67,13 +67,13 @@ void make_ddim(DDim& ddim, const int* dims, int n) {
 
 /// @endcond
 
-DDim make_ddim(std::initializer_list<int> dims) {
+DDim make_ddim(std::initializer_list<int64_t> dims) {
   DDim result(make_dim(0));
   make_ddim(result, dims.begin(), dims.size());
   return result;
 }
 
-DDim make_ddim(const std::vector<int>& dims) {
+DDim make_ddim(const std::vector<int64_t>& dims) {
   DDim result(make_dim(0));
   make_ddim(result, &dims[0], dims.size());
   return result;
@@ -81,12 +81,12 @@ DDim make_ddim(const std::vector<int>& dims) {
 
 /// @cond HIDDEN
 // XXX For some reason, putting this in an anonymous namespace causes errors
-class DynamicMutableIndexer : public boost::static_visitor<int&> {
+class DynamicMutableIndexer : public boost::static_visitor<int64_t&> {
  public:
   explicit DynamicMutableIndexer(int idx) : idx_(idx) {}
 
   template <int D>
-  int& operator()(Dim<D>& dim) const {
+  int64_t& operator()(Dim<D>& dim) const {
     return dim[idx_];
   }
 
@@ -94,12 +94,12 @@ class DynamicMutableIndexer : public boost::static_visitor<int&> {
   int idx_;
 };
 
-class DynamicConstIndexer : public boost::static_visitor<int> {
+class DynamicConstIndexer : public boost::static_visitor<int64_t> {
  public:
   explicit DynamicConstIndexer(int idx) : idx_(idx) {}
 
   template <int D>
-  int operator()(const Dim<D>& dim) const {
+  int64_t operator()(const Dim<D>& dim) const {
     return dim[idx_];
   }
 
@@ -109,22 +109,22 @@ class DynamicConstIndexer : public boost::static_visitor<int> {
 
 /// @endcond
 
-int& DDim::operator[](int idx) {
+int64_t& DDim::operator[](int idx) {
   return boost::apply_visitor(DynamicMutableIndexer(idx), var);
 }
 
-int DDim::operator[](int idx) const {
+int64_t DDim::operator[](int idx) const {
   return boost::apply_visitor(DynamicConstIndexer(idx), var);
 }
 
-ssize_t DDim::size() const { return arity(*this); }
+int64_t DDim::size() const { return arity(*this); }
 
 bool DDim::operator==(DDim d) const {
   if (var.which() != d.getVar().which()) {
     return false;
   } else {
-    std::vector<int> v1 = vectorize(*this);
-    std::vector<int> v2 = vectorize(d);
+    std::vector<int64_t> v1 = vectorize(*this);
+    std::vector<int64_t> v2 = vectorize(d);
 
     for (unsigned int i = 0; i < v1.size(); i++) {
       if (v1[i] != v2[i]) {
@@ -139,10 +139,10 @@ bool DDim::operator==(DDim d) const {
 bool DDim::operator!=(DDim d) const { return !(*this == d); }
 
 DDim DDim::operator+(DDim d) const {
-  std::vector<int> v1 = vectorize(*this);
-  std::vector<int> v2 = vectorize(d);
+  std::vector<int64_t> v1 = vectorize(*this);
+  std::vector<int64_t> v2 = vectorize(d);
 
-  std::vector<int> v3;
+  std::vector<int64_t> v3;
 
   assert(v1.size() == v2.size());
 
@@ -154,10 +154,10 @@ DDim DDim::operator+(DDim d) const {
 }
 
 DDim DDim::operator*(DDim d) const {
-  std::vector<int> v1 = vectorize(*this);
-  std::vector<int> v2 = vectorize(d);
+  std::vector<int64_t> v1 = vectorize(*this);
+  std::vector<int64_t> v2 = vectorize(d);
 
-  std::vector<int> v3;
+  std::vector<int64_t> v3;
 
   assert(v1.size() == v2.size());
 
@@ -168,15 +168,15 @@ DDim DDim::operator*(DDim d) const {
   return make_ddim(v3);
 }
 
-int get(const DDim& ddim, int idx) { return ddim[idx]; }
+int64_t get(const DDim& ddim, int idx) { return ddim[idx]; }
 
 void set(DDim& ddim, int idx, int value) { ddim[idx] = value; }
 
 /// @cond HIDDEN
 struct VectorizeVisitor : public boost::static_visitor<> {
-  std::vector<int>& vector;
+  std::vector<int64_t>& vector;
 
-  explicit VectorizeVisitor(std::vector<int>& v) : vector(v) {}
+  explicit VectorizeVisitor(std::vector<int64_t>& v) : vector(v) {}
 
   template <typename T>
   void operator()(const T& t) {
@@ -188,31 +188,31 @@ struct VectorizeVisitor : public boost::static_visitor<> {
 };
 /// @endcond
 
-std::vector<int> vectorize(const DDim& ddim) {
-  std::vector<int> result;
+std::vector<int64_t> vectorize(const DDim& ddim) {
+  std::vector<int64_t> result;
   VectorizeVisitor visitor(result);
   boost::apply_visitor(visitor, ddim);
   return result;
 }
 
-struct ProductVisitor : public boost::static_visitor<ssize_t> {
+struct ProductVisitor : public boost::static_visitor<int64_t> {
   template <int D>
-  ssize_t operator()(const Dim<D>& dim) {
+  int64_t operator()(const Dim<D>& dim) {
     return product(dim);
   }
 };
 
-ssize_t product(const DDim& ddim) {
+int64_t product(const DDim& ddim) {
   ProductVisitor visitor;
   return boost::apply_visitor(visitor, ddim);
 }
 
 struct SliceVectorizeVisitor : public boost::static_visitor<> {
-  std::vector<int>& vector;
+  std::vector<int64_t>& vector;
   int begin;
   int end;
 
-  SliceVectorizeVisitor(std::vector<int>& v, int b, int e)
+  SliceVectorizeVisitor(std::vector<int64_t>& v, int b, int e)
       : vector(v), begin(b), end(e) {
     PADDLE_ENFORCE(begin < end,
                    "Begin index must be less than end index in ddim slice.");
@@ -240,7 +240,7 @@ struct SliceVectorizeVisitor : public boost::static_visitor<> {
 };
 
 DDim slice_ddim(const DDim& dim, int begin, int end) {
-  std::vector<int> vec;
+  std::vector<int64_t> vec;
   vec.reserve(end - begin);
   SliceVectorizeVisitor visitor(vec, begin, end);
   boost::apply_visitor(visitor, dim);
@@ -280,7 +280,7 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) {
   return os;
 }
 
-DDim::DDim(std::initializer_list<int> init_list) {
+DDim::DDim(std::initializer_list<int64_t> init_list) {
   *this = make_ddim(init_list);
 }
 }  // namespace framework
diff --git a/paddle/framework/ddim.h b/paddle/framework/ddim.h
index 95f294b627..db30c52394 100644
--- a/paddle/framework/ddim.h
+++ b/paddle/framework/ddim.h
@@ -40,7 +40,7 @@ struct DDim {
   template <int D>
   explicit DDim(const Dim<D>& in) : var(in) {}
 
-  /*implicit*/ DDim(std::initializer_list<int> init_list);
+  /*implicit*/ DDim(std::initializer_list<int64_t> init_list);
 
   template <int D>
   DDim& operator=(const Dim<D>& in) {
@@ -48,8 +48,8 @@ struct DDim {
     return *this;
   }
 
-  int& operator[](int idx);
-  int operator[](int idx) const;
+  int64_t& operator[](int idx);
+  int64_t operator[](int idx) const;
 
   template <typename Visitor>
   typename Visitor::result_type apply_visitor(Visitor& visitor) {
@@ -71,15 +71,15 @@ struct DDim {
 
   DDim operator*(DDim d) const;
 
-  ssize_t size() const;
+  int64_t size() const;
 };
 
 /**
- * \brief Make a DDim from std::vector<int>
+ * \brief Make a DDim from std::vector<int64_t>
  *
  * \param dims An vector of ints. Must be sized between [1, 9]
  */
-DDim make_ddim(const std::vector<int>& dims);
+DDim make_ddim(const std::vector<int64_t>& dims);
 
 /**
  * \brief Make a DDim from an initializer list
@@ -87,14 +87,14 @@ DDim make_ddim(const std::vector<int>& dims);
  * \param dims An initializer list of ints. Must be sized between [1, 9]
  *
  */
-DDim make_ddim(std::initializer_list<int> dims);
+DDim make_ddim(std::initializer_list<int64_t> dims);
 
-int get(const DDim& dim, int idx);
+int64_t get(const DDim& dim, int idx);
 void set(DDim& dim, int idx, int val);
 
-std::vector<int> vectorize(const DDim& ddim);
+std::vector<int64_t> vectorize(const DDim& ddim);
 
-ssize_t product(const DDim& ddim);
+int64_t product(const DDim& ddim);
 
 /**
  * \brief Slice a ddim
diff --git a/paddle/framework/ddim_test.cc b/paddle/framework/ddim_test.cc
index 9d18a2972c..756232b1b5 100644
--- a/paddle/framework/ddim_test.cc
+++ b/paddle/framework/ddim_test.cc
@@ -12,7 +12,7 @@ TEST(DDim, Equality) {
   EXPECT_EQ(ddim[2], 5);
 
   // construct a DDim from a vector
-  std::vector<int> vec({9, 1, 5});
+  std::vector<int64_t> vec({9, 1, 5});
   paddle::framework::DDim vddim = paddle::framework::make_ddim(vec);
   EXPECT_EQ(ddim[0], 9);
   EXPECT_EQ(ddim[1], 1);
@@ -25,7 +25,7 @@ TEST(DDim, Equality) {
   EXPECT_EQ(paddle::framework::get(ddim, 0), 6);
 
   // vectorize a DDim
-  std::vector<int> res_vec = paddle::framework::vectorize(vddim);
+  std::vector<int64_t> res_vec = paddle::framework::vectorize(vddim);
   EXPECT_EQ(res_vec[0], 9);
   EXPECT_EQ(res_vec[1], 1);
   EXPECT_EQ(res_vec[2], 5);
diff --git a/paddle/framework/dim.h b/paddle/framework/dim.h
index 883fdc55eb..04d4b0e604 100644
--- a/paddle/framework/dim.h
+++ b/paddle/framework/dim.h
@@ -17,13 +17,13 @@ struct Dim {
   static constexpr int dimensions = i;
 
   template <typename... Args>
-  HOSTDEVICE Dim(int _head, Args... _tail) : head(_head), tail(_tail...) {
+  HOSTDEVICE Dim(int64_t _head, Args... _tail) : head(_head), tail(_tail...) {
     static_assert(sizeof...(_tail) == i - 1,
                   "Dim initialized with the wrong number of parameters");
   }
 
   HOSTDEVICE
-  Dim(int _head, const Dim<i - 1>& _tail) : head(_head), tail(_tail) {}
+  Dim(int64_t _head, const Dim<i - 1>& _tail) : head(_head), tail(_tail) {}
 
   HOSTDEVICE
   Dim() : head(0), tail() {}
@@ -31,12 +31,12 @@ struct Dim {
   /** Construct a Dim from a linear index and size.  Uses Fortran order
    * indexing. */
   HOSTDEVICE
-  Dim(int idx, const Dim<i>& size)
+  Dim(int64_t idx, const Dim<i>& size)
       : head(idx % size.head), tail(idx / size.head, size.tail) {}
 
   /** Construct a Dim with each dimension set to the given index */
   HOSTDEVICE
-  Dim(int idx) : head(idx), tail(idx) {}
+  Dim(int64_t idx) : head(idx), tail(idx) {}
 
   HOSTDEVICE
   bool operator==(const Dim<i>& o) const {
@@ -47,13 +47,13 @@ struct Dim {
   bool operator!=(const Dim<i>& o) const { return !(*this == o); }
 
   HOSTDEVICE
-  int& operator[](int idx);
+  int64_t& operator[](int idx);
   HOSTDEVICE
-  int operator[](int idx) const;
+  int64_t operator[](int idx) const;
 
   HOST std::string to_string() const;
 
-  int head;
+  int64_t head;
   Dim<i - 1> tail;
 };
 
@@ -63,7 +63,7 @@ struct Dim<1> {
   static constexpr int dimensions = 1;
 
   HOSTDEVICE
-  Dim(int _head) : head(_head) {}
+  Dim(int64_t _head) : head(_head) {}
 
   HOSTDEVICE
   Dim() : head(0) {}
@@ -86,11 +86,11 @@ struct Dim<1> {
   bool operator!=(const Dim<1>& o) const { return !(*this == o); }
 
   HOSTDEVICE
-  int& operator[](int idx);
+  int64_t& operator[](int idx);
   HOSTDEVICE
-  int operator[](int idx) const;
+  int64_t operator[](int idx) const;
 
-  int head;
+  int64_t head;
 };
 
 namespace {
@@ -100,12 +100,12 @@ template <int i>
 struct DimGetter {
   // Return a copy if Dim is const
   template <typename D>
-  HOSTDEVICE static int impl(const D& d) {
+  HOSTDEVICE static int64_t impl(const D& d) {
     return DimGetter<i - 1>::impl(d.tail);
   }
   // Return a reference if Dim is mutable
   template <typename D>
-  HOSTDEVICE static int& impl(D& d) {
+  HOSTDEVICE static int64_t& impl(D& d) {
     return DimGetter<i - 1>::impl(d.tail);
   }
 };
@@ -115,18 +115,18 @@ template <>
 struct DimGetter<0> {
   // Return a copy if Dim is const
   template <typename D>
-  HOSTDEVICE static int impl(const D& d) {
+  HOSTDEVICE static int64_t impl(const D& d) {
     return d.head;
   }
   // Return a reference if Dim is mutable
   template <typename D>
-  HOSTDEVICE static int& impl(D& d) {
+  HOSTDEVICE static int64_t& impl(D& d) {
     return d.head;
   }
 };
 
 template <int D>
-HOSTDEVICE int& indexer(Dim<D>& dim, int idx) {
+HOSTDEVICE int64_t& indexer(Dim<D>& dim, int idx) {
 #ifndef __CUDA_ARCH__
   if (idx < 0) {
     throw std::invalid_argument("Tried to access a negative dimension");
@@ -141,7 +141,7 @@ HOSTDEVICE int& indexer(Dim<D>& dim, int idx) {
 }
 
 template <>
-HOSTDEVICE int& indexer<1>(Dim<1>& dim, int idx) {
+HOSTDEVICE int64_t& indexer<1>(Dim<1>& dim, int idx) {
 #ifndef __CUDA_ARCH__
   if (idx != 0) {
     throw std::invalid_argument("Invalid index");
@@ -153,7 +153,7 @@ HOSTDEVICE int& indexer<1>(Dim<1>& dim, int idx) {
 }
 
 template <int D>
-HOSTDEVICE int indexer(const Dim<D>& dim, int idx) {
+HOSTDEVICE int64_t indexer(const Dim<D>& dim, int idx) {
 #ifndef __CUDA_ARCH__
   if (idx < 0) {
     throw std::invalid_argument("Tried to access a negative dimension");
@@ -168,7 +168,7 @@ HOSTDEVICE int indexer(const Dim<D>& dim, int idx) {
 }
 
 template <>
-HOSTDEVICE int indexer<1>(const Dim<1>& dim, int idx) {
+HOSTDEVICE int64_t indexer<1>(const Dim<1>& dim, int idx) {
 #ifndef __CUDA_ARCH__
   if (idx != 0) {
     throw std::invalid_argument("Invalid index");
@@ -182,73 +182,76 @@ HOSTDEVICE int indexer<1>(const Dim<1>& dim, int idx) {
 }  // namespace
 // Static access to constant Dim
 template <int i, int l>
-HOSTDEVICE int get(const Dim<l>& d) {
+HOSTDEVICE int64_t get(const Dim<l>& d) {
   return DimGetter<i>::impl(d);
 }
 
 // Static access to mutable Dim
 template <int i, int l>
-HOSTDEVICE int& get(Dim<l>& d) {
+HOSTDEVICE int64_t& get(Dim<l>& d) {
   return DimGetter<i>::impl(d);
 }
 
 // Dynamic access to constant Dim
 template <int l>
-HOSTDEVICE int Dim<l>::operator[](int i) const {
+HOSTDEVICE int64_t Dim<l>::operator[](int i) const {
   return indexer(*this, i);
 }
 
 // Dynamic access to mutable Dim
 template <int l>
-HOSTDEVICE int& Dim<l>::operator[](int i) {
+HOSTDEVICE int64_t& Dim<l>::operator[](int i) {
   return indexer(*this, i);
 }
 
 // Dynamic access to constant Dim
-inline HOSTDEVICE int Dim<1>::operator[](int i) const {
+inline HOSTDEVICE int64_t Dim<1>::operator[](int i) const {
   return indexer(*this, i);
 }
 
 // Dynamic access to mutable Dim
-inline HOSTDEVICE int& Dim<1>::operator[](int i) { return indexer(*this, i); }
+inline HOSTDEVICE int64_t& Dim<1>::operator[](int i) {
+  return indexer(*this, i);
+}
 
 // Dynamic access to constant Dim
 // without std::enable_if will try to instantiate this on get<0>(d)
 template <int l>
-HOSTDEVICE typename std::enable_if<(l > 0), int>::type get(const Dim<l>& d,
-                                                           int i) {
+HOSTDEVICE typename std::enable_if<(l > 0), int64_t>::type get(const Dim<l>& d,
+                                                               int i) {
   return d[i];
 }
 
 // Dynamic access to mutable Dim
 template <int l>
-HOSTDEVICE typename std::enable_if<(l > 0), int&>::type get(Dim<l>& d, int i) {
+HOSTDEVICE typename std::enable_if<(l > 0), int64_t&>::type get(Dim<l>& d,
+                                                                int i) {
   return d[i];
 }
 
 // Dot product of two dims
 template <int i>
-HOSTDEVICE int linearize(const Dim<i>& a, const Dim<i>& b) {
+HOSTDEVICE int64_t linearize(const Dim<i>& a, const Dim<i>& b) {
   return a.head * b.head + linearize(a.tail, b.tail);
 }
 
 // Base case dot product of two Dims
 // Notice it is inline because it is no longer a template
 template <>
-HOSTDEVICE inline int linearize(const Dim<1>& a, const Dim<1>& b) {
+HOSTDEVICE inline int64_t linearize(const Dim<1>& a, const Dim<1>& b) {
   return a.head * b.head;
 }
 
 // Product of a Dim
 template <int i>
-HOSTDEVICE int product(const Dim<i>& a, int prod = 1) {
+HOSTDEVICE int64_t product(const Dim<i>& a, int prod = 1) {
   return prod * a.head * product(a.tail);
 }
 
 // Base case product of a Dim
 // Notice it is inline because it is no longer a template
 template <>
-HOSTDEVICE inline int product(const Dim<1>& a, int prod) {
+HOSTDEVICE inline int64_t product(const Dim<1>& a, int prod) {
   return prod * a.head;
 }
 
diff --git a/paddle/framework/dim_test.cu b/paddle/framework/dim_test.cu
index 3898d0a447..0a6a87669c 100644
--- a/paddle/framework/dim_test.cu
+++ b/paddle/framework/dim_test.cu
@@ -8,7 +8,7 @@ __global__ void test(paddle::framework::Dim<2>* o) {
   o[0] = paddle::framework::make_dim(5, 6);
 }
 
-__global__ void dyn_idx_gpu(int* o) {
+__global__ void dyn_idx_gpu(int64_t* o) {
   auto d = paddle::framework::make_dim(5, 6);
   o[0] = d[1];
 }
@@ -47,9 +47,9 @@ TEST(Dim, Equality) {
   EXPECT_EQ(b[1], 11);
 
   // dynamic access on GPU
-  thrust::device_vector<int> r(1);
+  thrust::device_vector<int64_t> r(1);
   dyn_idx_gpu<<<1, 1>>>(thrust::raw_pointer_cast(r.data()));
-  int res = r[0];
+  int64_t res = r[0];
   EXPECT_EQ(res, 6);
 
   // ex_prefix_mul
diff --git a/paddle/framework/eigen.h b/paddle/framework/eigen.h
index a4667cc51f..2d8d9ae10c 100644
--- a/paddle/framework/eigen.h
+++ b/paddle/framework/eigen.h
@@ -28,7 +28,7 @@ struct EigenDim {
   static Type From(const DDim& dims) {
     PADDLE_ENFORCE(arity(dims) == D, "D must match arity(DDim)");
     Type ret;
-    for (int d = 0; d < arity(dims); d++) {
+    for (int64_t d = 0; d < arity(dims); d++) {
       ret[d] = dims[d];
     }
     return ret;
diff --git a/paddle/framework/framework.proto b/paddle/framework/framework.proto
index ae44a1ffd4..368136a972 100644
--- a/paddle/framework/framework.proto
+++ b/paddle/framework/framework.proto
@@ -22,8 +22,14 @@ enum AttrType {
   INTS = 3;
   FLOATS = 4;
   STRINGS = 5;
+  INT_PAIRS = 6;
 }
 
+message IntPair {
+  required int32 first = 1;
+  required int32 second = 2;
+};
+
 // OpDesc describes an instance of a C++ framework::OperatorBase
 // derived class type.
 message OpDesc {
@@ -37,6 +43,7 @@ message OpDesc {
     repeated int32 ints = 6;
     repeated float floats = 7;
     repeated string strings = 8;
+    repeated IntPair int_pairs = 9;
   };
 
   message Var {
diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc
index 50c45919c5..b43f6a8cc5 100644
--- a/paddle/framework/op_registry_test.cc
+++ b/paddle/framework/op_registry_test.cc
@@ -174,36 +174,4 @@ TEST(OpRegistry, CustomChecker) {
   op->Run(scope, dev_ctx);
   int test_attr = op->GetAttr<int>("test_attr");
   ASSERT_EQ(test_attr, 4);
-}
-
-class TestAttrProtoMaker : public pd::OpProtoAndCheckerMaker {
- public:
-  TestAttrProtoMaker(pd::OpProto* proto, pd::OpAttrChecker* op_checker)
-      : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddAttr<float>("scale", "scale of test op");
-    AddAttr<float>("scale", "scale of test op");
-  }
-};
-
-TEST(ProtoMaker, DuplicatedAttr) {
-  pd::OpProto op_proto;
-  pd::OpAttrChecker op_checker;
-  auto proto_maker = TestAttrProtoMaker(&op_proto, &op_checker);
-  ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
-}
-
-class TestInOutProtoMaker : public pd::OpProtoAndCheckerMaker {
- public:
-  TestInOutProtoMaker(pd::OpProto* proto, pd::OpAttrChecker* op_checker)
-      : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("input", "input of test op");
-    AddInput("input", "input of test op");
-  }
-};
-
-TEST(ProtoMaker, DuplicatedInOut) {
-  pd::OpProto op_proto;
-  pd::OpAttrChecker op_checker;
-  auto proto_maker = TestInOutProtoMaker(&op_proto, &op_checker);
-  ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
-}
+}
\ No newline at end of file
diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc
index f7c9e6b196..8a1970c7a8 100644
--- a/paddle/framework/operator_test.cc
+++ b/paddle/framework/operator_test.cc
@@ -263,4 +263,38 @@ TEST(Operator, Clone) {
   OperatorClone a("ABC", {}, {}, {});
   auto b = a.Clone();
   ASSERT_EQ(a.Type(), b->Type());
+}
+
+class TestAttrProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
+ public:
+  TestAttrProtoMaker(paddle::framework::OpProto* proto,
+                     paddle::framework::OpAttrChecker* op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddAttr<float>("scale", "scale of test op");
+    AddAttr<float>("scale", "scale of test op");
+  }
+};
+
+TEST(ProtoMaker, DuplicatedAttr) {
+  paddle::framework::OpProto op_proto;
+  paddle::framework::OpAttrChecker op_checker;
+  auto proto_maker = TestAttrProtoMaker(&op_proto, &op_checker);
+  ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
+}
+
+class TestInOutProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
+ public:
+  TestInOutProtoMaker(paddle::framework::OpProto* proto,
+                      paddle::framework::OpAttrChecker* op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("input", "input of test op");
+    AddInput("input", "input of test op");
+  }
+};
+
+TEST(ProtoMaker, DuplicatedInOut) {
+  paddle::framework::OpProto op_proto;
+  paddle::framework::OpAttrChecker op_checker;
+  auto proto_maker = TestInOutProtoMaker(&op_proto, &op_checker);
+  ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
 }
\ No newline at end of file
diff --git a/paddle/framework/tensor_impl.h b/paddle/framework/tensor_impl.h
index 7893e233b7..94f436294f 100644
--- a/paddle/framework/tensor_impl.h
+++ b/paddle/framework/tensor_impl.h
@@ -58,7 +58,7 @@ inline T* Tensor::mutable_data(platform::Place place) {
                     "Tensor's numel must be larger than zero to call "
                     "Tensor::mutable_data. Call Tensor::set_dim first.");
   /* some versions of boost::variant don't have operator!= */
-  size_t size = product(dims_) * sizeof(T);
+  int64_t size = product(dims_) * sizeof(T);
   if (holder_ == nullptr || !(holder_->place() == place) ||
       holder_->size() < size + offset_) {
     if (platform::is_cpu_place(place)) {
@@ -131,7 +131,7 @@ inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const {
   PADDLE_ENFORCE_LT(begin_idx, end_idx,
                     "Begin index must be less than end index.");
   PADDLE_ENFORCE_NE(dims_[0], 1, "Can not slice a tensor with dims_[0] = 1.");
-  int base = product(dims_) / dims_[0];
+  size_t base = product(dims_) / dims_[0];
   Tensor dst;
   dst.holder_ = holder_;
   DDim dst_dims = dims_;
diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt
index e5efcccb0e..25dbd236e6 100644
--- a/paddle/operators/CMakeLists.txt
+++ b/paddle/operators/CMakeLists.txt
@@ -56,7 +56,7 @@ list(REMOVE_ITEM GENERAL_OPS
 op_library(net_op SRCS net_op.cc)
 op_library(minus_op SRCS minus_op.cc minus_op.cu DEPS scale_op)
 op_library(mul_op SRCS mul_op.cc mul_op.cu DEPS math_function)
-op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc 
+op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
   DEPS framework_proto tensor operator net_op)
 op_library(scale_op SRCS scale_op.cc scale_op.cu DEPS net_op)
 
diff --git a/paddle/operators/cos_sim_op.cc b/paddle/operators/cos_sim_op.cc
new file mode 100644
index 0000000000..c033af3b74
--- /dev/null
+++ b/paddle/operators/cos_sim_op.cc
@@ -0,0 +1,107 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/operators/cos_sim_op.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::Tensor;
+
+class CosSimOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(const framework::InferShapeContext &ctx) const override {
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) must not be null.");
+    PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("X")->dims(),
+                      ctx.Input<Tensor>("Y")->dims(),
+                      "Dimensions of Input(X) and Input(Y) must be the same.");
+
+    auto dims = ctx.Input<Tensor>("X")->dims();
+    ctx.Output<Tensor>("Out")->Resize({dims[0], 1});
+    ctx.Output<Tensor>("XNorm")->Resize({dims[0], 1});
+    ctx.Output<Tensor>("YNorm")->Resize({dims[0], 1});
+  }
+};
+
+class CosSimOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  CosSimOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "The first input of cos_sim op.");
+    AddInput("Y", "The second input of cos_sim op.");
+    AddOutput("Out", "The output of cos_sim op.");
+    AddOutput("XNorm", "Row norm of the first input.").AsIntermediate();
+    AddOutput("YNorm", "Row norm of the second input.").AsIntermediate();
+
+    AddComment(R"DOC(
+Cosine Similarity Operator.
+
+The equation is: Out = X^T * Y / (sqrt(X^T * X) * sqrt(Y^T * Y))
+)DOC");
+  }
+};
+
+class CosSimOpGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(const framework::InferShapeContext &ctx) const override {
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) must not be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("XNorm"),
+                            "Input(XNorm) must not be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("YNorm"),
+                            "Input(YNorm) must not be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
+                            "Input(Out@GRAD) must not be null.");
+
+    auto x_dims = ctx.Input<Tensor>("X")->dims();
+    auto y_dims = ctx.Input<Tensor>("Y")->dims();
+    auto xnorm_dims = ctx.Input<Tensor>("XNorm")->dims();
+    auto ynorm_dims = ctx.Input<Tensor>("YNorm")->dims();
+    auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
+    PADDLE_ENFORCE_EQ(x_dims, y_dims,
+                      "Dimensions of Input(X) and Input(Y) must be the same.");
+    PADDLE_ENFORCE_EQ(xnorm_dims[0], x_dims[0],
+                      "1st dimension of XNorm must equal that of Input(X).");
+    PADDLE_ENFORCE_EQ(xnorm_dims[1], 1, "2nd dimension of XNorm must be one.");
+    PADDLE_ENFORCE_EQ(ynorm_dims[0], y_dims[0],
+                      "1st dimension of YNorm must equal that of Input(Y).");
+    PADDLE_ENFORCE_EQ(ynorm_dims[1], 1, "2nd dimension of YNorm must be one.");
+    PADDLE_ENFORCE_EQ(out_dims[0], x_dims[0],
+                      "1st dimension of Out@GRAD must equal that of Input(X)");
+    PADDLE_ENFORCE_EQ(out_dims[1], 1, "2nd dimension of Out@GRAD must be one.");
+
+    auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto *y_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));
+    if (x_grad) x_grad->Resize(x_dims);
+    if (y_grad) y_grad->Resize(y_dims);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP(cos_sim, ops::CosSimOp, ops::CosSimOpMaker, cos_sim_grad,
+            ops::CosSimOpGrad);
+REGISTER_OP_CPU_KERNEL(cos_sim,
+                       ops::CosSimKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(
+    cos_sim_grad, ops::CosSimGradKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/cos_sim_op.cu b/paddle/operators/cos_sim_op.cu
new file mode 100644
index 0000000000..0cb8fd26de
--- /dev/null
+++ b/paddle/operators/cos_sim_op.cu
@@ -0,0 +1,22 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#define EIGEN_USE_GPU
+#include "paddle/operators/cos_sim_op.h"
+
+namespace ops = paddle::operators;
+REGISTER_OP_GPU_KERNEL(cos_sim,
+                       ops::CosSimKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(
+    cos_sim_grad, ops::CosSimGradKernel<paddle::platform::GPUPlace, float>);
diff --git a/paddle/operators/cos_sim_op.h b/paddle/operators/cos_sim_op.h
new file mode 100644
index 0000000000..9e3ff26815
--- /dev/null
+++ b/paddle/operators/cos_sim_op.h
@@ -0,0 +1,104 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#pragma once
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
+
+template <typename Place, typename T>
+class CosSimKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* input_x = context.Input<Tensor>("X");
+    auto* input_y = context.Input<Tensor>("Y");
+    auto* output_z = context.Output<Tensor>("Out");
+    auto* output_x_norm = context.Output<Tensor>("XNorm");
+    auto* output_y_norm = context.Output<Tensor>("YNorm");
+
+    output_z->mutable_data<T>(context.GetPlace());
+    output_x_norm->mutable_data<T>(context.GetPlace());
+    output_y_norm->mutable_data<T>(context.GetPlace());
+
+    auto dims = input_x->dims();
+    int size = static_cast<int>(framework::product(dims));
+    auto new_dims = framework::make_ddim({dims[0], size / dims[0]});
+    auto x = EigenMatrix<T>::From(*input_x, new_dims);
+    auto y = EigenMatrix<T>::From(*input_y, new_dims);
+    auto z = EigenMatrix<T>::From(*output_z);
+    auto x_norm = EigenMatrix<T>::From(*output_x_norm);
+    auto y_norm = EigenMatrix<T>::From(*output_y_norm);
+
+    auto place = context.GetEigenDevice<Place>();
+    auto xy = (x * y).sum(Eigen::array<int, 1>({1}));
+    x_norm.device(place) = x.square().sum(Eigen::array<int, 1>({1})).sqrt();
+    y_norm.device(place) = y.square().sum(Eigen::array<int, 1>({1})).sqrt();
+    z.device(place) = xy / x_norm / y_norm;
+  }
+};
+
+template <typename Place, typename T>
+class CosSimGradKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* input_x = context.Input<Tensor>("X");
+    auto* input_y = context.Input<Tensor>("Y");
+    auto* input_z = context.Input<Tensor>("Out");
+    auto* input_x_norm = context.Input<Tensor>("XNorm");
+    auto* input_y_norm = context.Input<Tensor>("YNorm");
+    auto* output_grad_x = context.Output<Tensor>(framework::GradVarName("X"));
+    auto* output_grad_y = context.Output<Tensor>(framework::GradVarName("Y"));
+    auto* input_grad_z = context.Input<Tensor>(framework::GradVarName("Out"));
+
+    auto dims = input_x->dims();
+    int size = static_cast<int>(framework::product(dims));
+    auto new_dims = framework::make_ddim({dims[0], size / dims[0]});
+    auto x = EigenMatrix<T>::From(*input_x, new_dims);
+    auto y = EigenMatrix<T>::From(*input_y, new_dims);
+    auto z = EigenMatrix<T>::From(*input_z);
+    auto x_norm = EigenMatrix<T>::From(*input_x_norm);
+    auto y_norm = EigenMatrix<T>::From(*input_y_norm);
+    auto dz = EigenMatrix<T>::From(*input_grad_z);
+
+    Eigen::DSizes<int, 2> bcast(1, new_dims[1]);
+    auto z_bcast = z.broadcast(bcast);
+    auto dz_bcast = dz.broadcast(bcast);
+    auto place = context.GetEigenDevice<Place>();
+    auto x_snorm_bcast = x_norm.square().eval().broadcast(bcast);
+    auto y_snorm_bcast = y_norm.square().eval().broadcast(bcast);
+    auto norm_prod_bcast = (x_norm * y_norm).eval().broadcast(bcast);
+    if (output_grad_x) {
+      output_grad_x->mutable_data<T>(context.GetPlace());
+      auto dx = EigenMatrix<T>::From(*output_grad_x, new_dims);
+      dx.device(place) =
+          dz_bcast * (y / norm_prod_bcast - z_bcast * x / x_snorm_bcast);
+    }
+    if (output_grad_y) {
+      output_grad_y->mutable_data<T>(context.GetPlace());
+      auto dy = EigenMatrix<T>::From(*output_grad_y, new_dims);
+      dy.device(place) =
+          dz_bcast * (x / norm_prod_bcast - z_bcast * y / y_snorm_bcast);
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/gaussian_random_op.cc b/paddle/operators/gaussian_random_op.cc
index 056447901d..8bb61275ba 100644
--- a/paddle/operators/gaussian_random_op.cc
+++ b/paddle/operators/gaussian_random_op.cc
@@ -31,8 +31,8 @@ class CPUGaussianRandomKernel : public framework::OpKernel {
     }
     engine.seed(seed);
     std::normal_distribution<T> dist(mean, std);
-    ssize_t size = framework::product(tensor->dims());
-    for (ssize_t i = 0; i < size; ++i) {
+    int64_t size = framework::product(tensor->dims());
+    for (int64_t i = 0; i < size; ++i) {
       data[i] = dist(engine);
     }
   }
@@ -46,9 +46,14 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
   void InferShape(const framework::InferShapeContext& context) const override {
     auto* tensor = context.Output<framework::Tensor>("Out");
     auto dims = GetAttr<std::vector<int>>("dims");
+    std::vector<int64_t> temp;
+    temp.reserve(dims.size());
+    for (auto dim : dims) {
+      temp.push_back(static_cast<int64_t>(dim));
+    }
     PADDLE_ENFORCE(dims.size() > 0UL,
                    "dims can be one int or array. dims must be set.");
-    tensor->Resize(framework::make_ddim(dims));
+    tensor->Resize(framework::make_ddim(temp));
   }
 };
 
diff --git a/paddle/operators/lookup_table_op.h b/paddle/operators/lookup_table_op.h
index 4da8079b91..877b36cef4 100644
--- a/paddle/operators/lookup_table_op.h
+++ b/paddle/operators/lookup_table_op.h
@@ -30,12 +30,12 @@ class LookupTableKernel : public framework::OpKernel {
     auto ids_t = context.Input<Tensor>("Ids");      // int tensor
     auto output_t = context.Output<Tensor>("Out");  // float tensor
 
-    size_t N = table_t->dims()[0];
-    size_t D = table_t->dims()[1];
+    int N = table_t->dims()[0];
+    int D = table_t->dims()[1];
     auto ids = ids_t->data<int32_t>();
     auto table = table_t->data<T>();
     auto output = output_t->mutable_data<T>(context.GetPlace());
-    for (size_t i = 0; i < product(ids_t->dims()); ++i) {
+    for (ssize_t i = 0; i < product(ids_t->dims()); ++i) {
       PADDLE_ENFORCE_LT(ids[i], N);
       PADDLE_ENFORCE_GE(ids[i], 0);
       memcpy(output + i * D, table + ids[i] * D, D * sizeof(T));
@@ -51,8 +51,8 @@ class LookupTableGradKernel : public framework::OpKernel {
     auto d_output_t = context.Input<Tensor>(framework::GradVarName("Out"));
     auto d_table_t = context.Output<Tensor>(framework::GradVarName("W"));
 
-    size_t N = d_table_t->dims()[0];
-    size_t D = d_table_t->dims()[1];
+    int N = d_table_t->dims()[0];
+    int D = d_table_t->dims()[1];
     auto ids = ids_t->data<int32_t>();
     const T* d_output = d_output_t->data<T>();
     T* d_table = d_table_t->mutable_data<T>(context.GetPlace());
@@ -61,10 +61,10 @@ class LookupTableGradKernel : public framework::OpKernel {
     t.device(context.GetEigenDevice<platform::CPUPlace>()) =
         t.constant(static_cast<T>(0));
 
-    for (size_t i = 0; i < product(ids_t->dims()); ++i) {
+    for (ssize_t i = 0; i < product(ids_t->dims()); ++i) {
       PADDLE_ENFORCE_LT(ids[i], N);
       PADDLE_ENFORCE_GE(ids[i], 0);
-      for (size_t j = 0; j < D; ++j) {
+      for (int j = 0; j < D; ++j) {
         d_table[ids[i] * D + j] += d_output[i * D + j];
       }
     }
diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc
index 5b8b5f6c11..28a47cdff2 100644
--- a/paddle/operators/mul_op.cc
+++ b/paddle/operators/mul_op.cc
@@ -75,8 +75,8 @@ class MulOpGrad : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(y_dims[1] == out_dims[1],
                    "Out@GRAD M X N must equal to Y dims 1, N ");
 
-    x_grad->Resize(x_dims);
-    y_grad->Resize(y_dims);
+    if (x_grad) x_grad->Resize(x_dims);
+    if (y_grad) y_grad->Resize(y_dims);
   }
 };
 
diff --git a/paddle/operators/mul_op.h b/paddle/operators/mul_op.h
index 8facc02814..05a79e13b3 100644
--- a/paddle/operators/mul_op.h
+++ b/paddle/operators/mul_op.h
@@ -31,13 +31,13 @@ template <typename Place, typename T>
 class MulKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* X = context.Input<Tensor>("X");
-    auto* Y = context.Input<Tensor>("Y");
-    auto* Z = context.Output<Tensor>("Out");
-    Z->mutable_data<T>(context.GetPlace());
+    auto* x = context.Input<Tensor>("X");
+    auto* y = context.Input<Tensor>("Y");
+    auto* z = context.Output<Tensor>("Out");
+    z->mutable_data<T>(context.GetPlace());
     auto* device_context =
         const_cast<platform::DeviceContext*>(context.device_context_);
-    math::matmul<Place, T>(*X, false, *Y, false, 1, Z, 0, device_context);
+    math::matmul<Place, T>(*x, false, *y, false, 1, z, 0, device_context);
   }
 };
 
@@ -45,20 +45,24 @@ template <typename Place, typename T>
 class MulGradKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* X = ctx.Input<Tensor>("X");
-    auto* Y = ctx.Input<Tensor>("Y");
-    auto* dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto* x = ctx.Input<Tensor>("X");
+    auto* y = ctx.Input<Tensor>("Y");
+    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
 
-    auto* dX = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto* dY = ctx.Output<Tensor>(framework::GradVarName("Y"));
-    dX->mutable_data<T>(ctx.GetPlace());
-    dY->mutable_data<T>(ctx.GetPlace());
+    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
     auto* device_context =
         const_cast<platform::DeviceContext*>(ctx.device_context_);
-    // dX = dOut * Y'. dX: M x K, dOut : M x N, Y : K x N
-    math::matmul<Place, T>(*dOut, false, *Y, true, 1, dX, 0, device_context);
-    // dY = X' * dOut. dY: K x N, dOut : M x N, X : M x K
-    math::matmul<Place, T>(*X, true, *dOut, false, 1, dY, 0, device_context);
+    if (dx) {
+      dx->mutable_data<T>(ctx.GetPlace());
+      // dx = dout * y'. dx: M x K, dout : M x N, y : K x N
+      math::matmul<Place, T>(*dout, false, *y, true, 1, dx, 0, device_context);
+    }
+    if (dy) {
+      dy->mutable_data<T>(ctx.GetPlace());
+      // dy = x' * dout. dy K x N, dout : M x N, x : M x K
+      math::matmul<Place, T>(*x, true, *dout, false, 1, dy, 0, device_context);
+    }
   }
 };
 
diff --git a/paddle/operators/rnn/recurrent_op_utils.cc b/paddle/operators/rnn/recurrent_op_utils.cc
index a9b65c30f2..69e723b401 100644
--- a/paddle/operators/rnn/recurrent_op_utils.cc
+++ b/paddle/operators/rnn/recurrent_op_utils.cc
@@ -61,7 +61,7 @@ void ConcatOutputs(const std::vector<Scope*>& step_scopes,
       PADDLE_ENFORCE(step_scope_var != nullptr, "%s not in scope",
                      outlinks[i].internal);
       f::DDim step_dims = step_scope_var->template GetMutable<Tensor>()->dims();
-      std::vector<int> dims_vec = vectorize(step_dims);
+      std::vector<int64_t> dims_vec = vectorize(step_dims);
       dims_vec.insert(dims_vec.begin(), seq_len);
       output->Resize(f::make_ddim(dims_vec));
     } else {
diff --git a/paddle/operators/rowwise_add_op.cc b/paddle/operators/rowwise_add_op.cc
index 6825dce332..30b4b40431 100644
--- a/paddle/operators/rowwise_add_op.cc
+++ b/paddle/operators/rowwise_add_op.cc
@@ -64,8 +64,10 @@ class RowwiseAddGradOp : public framework::OperatorWithKernel {
     auto dims0 = ctx.Input<Tensor>("X")->dims();
     auto dims1 = ctx.Input<Tensor>("b")->dims();
     PADDLE_ENFORCE_EQ(1, dims1.size(), "b dims should be 1")
-    ctx.Output<Tensor>(framework::GradVarName("X"))->Resize(dims0);
-    ctx.Output<Tensor>(framework::GradVarName("b"))->Resize(dims1);
+    auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto *db = ctx.Output<Tensor>(framework::GradVarName("b"));
+    if (dx) dx->Resize(dims0);
+    if (db) db->Resize(dims1);
   }
 };
 
diff --git a/paddle/operators/rowwise_add_op.h b/paddle/operators/rowwise_add_op.h
index 1cbd8bb31a..4e926d9f29 100644
--- a/paddle/operators/rowwise_add_op.h
+++ b/paddle/operators/rowwise_add_op.h
@@ -51,20 +51,24 @@ template <typename Place, typename T>
 class RowwiseAddGradKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* dOut = context.Input<Tensor>(framework::GradVarName("Out"));
-    auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
+    auto* dout = context.Input<Tensor>(framework::GradVarName("Out"));
+    auto* dx = context.Output<Tensor>(framework::GradVarName("X"));
     auto* db = context.Output<Tensor>(framework::GradVarName("b"));
-    dX->mutable_data<T>(context.GetPlace());
-    db->mutable_data<T>(context.GetPlace());
 
-    auto OutGrad = EigenMatrix<T>::From(*dOut);
+    auto out_grad = EigenMatrix<T>::From(*dout);
     auto place = context.GetEigenDevice<Place>();
-    EigenMatrix<T>::From(*dX).device(place) = OutGrad;
+    if (dx) {
+      dx->mutable_data<T>(context.GetPlace());
+      EigenMatrix<T>::From(*dx).device(place) = out_grad;
+    }
 
-    // https://eigen.tuxfamily.org/dox/unsupported/TensorBase_8h_source.html
-    // colwise add
-    Eigen::array<int, 1> dims{{0}}; /* dimension to reduce */
-    EigenVector<T>::Flatten(*db).device(place) = OutGrad.sum(dims);
+    if (db) {
+      db->mutable_data<T>(context.GetPlace());
+      // https://eigen.tuxfamily.org/dox/unsupported/TensorBase_8h_source.html
+      // colwise add
+      Eigen::array<int, 1> dims{{0}}; /* dimension to reduce */
+      EigenVector<T>::Flatten(*db).device(place) = out_grad.sum(dims);
+    }
   }
 };
 }  // namespace operators
diff --git a/paddle/operators/softmax_op.cc b/paddle/operators/softmax_op.cc
index 40c51a64c4..7d062ad67c 100644
--- a/paddle/operators/softmax_op.cc
+++ b/paddle/operators/softmax_op.cc
@@ -24,7 +24,7 @@ class SoftmaxOp : public framework::OperatorWithKernel {
  protected:
   void InferShape(const framework::InferShapeContext &ctx) const override {
     PADDLE_ENFORCE(ctx.Input<Tensor>("X")->dims().size() == 2UL,
-                   "The input of softmax op must be matrix");
+                   "The input of softmax op must be a matrix.");
     ctx.Output<Tensor>("Y")->Resize(ctx.Input<Tensor>("X")->dims());
   }
 };
@@ -34,9 +34,27 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
   SoftmaxOpMaker(framework::OpProto *proto,
                  framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "input of softmax");
-    AddOutput("Y", "output of softmax");
-    AddComment("Softmax Op");
+    AddInput("X",
+             "The input tensor of softmax. "
+             "2-D with shape [batch_size, input_feature_dimensions].");
+    AddOutput("Y", "The normalized values with the same shape as X.");
+    AddComment(R"DOC(
+The input of softmax operator is a 2-D tensor with shape N x K (N is the
+batch_size, K is the dimension of input feature). The output tensor has the
+same shape as the input tensor.
+
+For each row of the input tensor, the softmax operator squashes the
+K-dimensional vector of arbitrary real values to a K-dimensional vector of real
+values in the range [0, 1] that add up to 1. Specifically, it computes the
+exponential of each element and the sum of the exponentials of all elements
+in the K-dimensional input vector. The ratio of each element's exponential to
+that sum is the output of the softmax operator.
+
+For each row `i` and each column `j` in X, we have:
+    Y[i, j] = exp(X[i, j]) / sum_j(exp(X[i, j]))
+
+)DOC");
   }
 };
 
diff --git a/paddle/operators/uniform_random_op.cc b/paddle/operators/uniform_random_op.cc
index 2d943c4508..40cef8942a 100644
--- a/paddle/operators/uniform_random_op.cc
+++ b/paddle/operators/uniform_random_op.cc
@@ -35,8 +35,8 @@ class CPUUniformRandomKernel : public framework::OpKernel {
     std::uniform_real_distribution<T> dist(
         static_cast<T>(context.GetAttr<float>("min")),
         static_cast<T>(context.GetAttr<float>("max")));
-    ssize_t size = framework::product(tensor->dims());
-    for (ssize_t i = 0; i < size; ++i) {
+    int64_t size = framework::product(tensor->dims());
+    for (int64_t i = 0; i < size; ++i) {
       data[i] = dist(engine);
     }
   }
@@ -52,7 +52,12 @@ class UniformRandomOp : public framework::OperatorWithKernel {
                    "uniform_random's min must less then max");
     auto* tensor = ctx.Output<framework::Tensor>("Out");
     auto dims = GetAttr<std::vector<int>>("dims");
-    tensor->Resize(framework::make_ddim(dims));
+    std::vector<int64_t> temp;
+    temp.reserve(dims.size());
+    for (auto dim : dims) {
+      temp.push_back(static_cast<int64_t>(dim));
+    }
+    tensor->Resize(framework::make_ddim(temp));
   }
 };
 
diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc
index e2ea5c92af..1b76dc0c17 100644
--- a/paddle/pybind/pybind.cc
+++ b/paddle/pybind/pybind.cc
@@ -46,6 +46,7 @@ USE_OP(lookup_table);
 USE_OP(scale);
 USE_NO_KERNEL_OP(identity);
 USE_OP(minus);
+USE_OP(cos_sim);
 USE_CPU_ONLY_OP(gather);
 USE_CPU_ONLY_OP(scatter);
 USE_OP(crop);
@@ -77,7 +78,7 @@ PYBIND11_PLUGIN(core) {
       .def("get_dims",
            [](const Tensor &self) { return vectorize(self.dims()); })
       .def("set_dims",
-           [](Tensor &self, const std::vector<int> &dim) {
+           [](Tensor &self, const std::vector<int64_t> &dim) {
              self.Resize(make_ddim(dim));
            })
       .def("alloc_float",
diff --git a/paddle/pybind/tensor_py.h b/paddle/pybind/tensor_py.h
index 39ba60b4dc..95171acf72 100644
--- a/paddle/pybind/tensor_py.h
+++ b/paddle/pybind/tensor_py.h
@@ -85,7 +85,7 @@ void PyCPUTensorSetFromArray(
     framework::Tensor &self,
     py::array_t<T, py::array::c_style | py::array::forcecast> array,
     paddle::platform::CPUPlace &place) {
-  std::vector<int> dims;
+  std::vector<int64_t> dims;
   dims.reserve(array.ndim());
   for (size_t i = 0; i < array.ndim(); ++i) {
     dims.push_back((int)array.shape()[i]);
@@ -102,7 +102,7 @@ void PyCUDATensorSetFromArray(
     framework::Tensor &self,
     py::array_t<T, py::array::c_style | py::array::forcecast> array,
     paddle::platform::GPUPlace &place) {
-  std::vector<int> dims;
+  std::vector<int64_t> dims;
   dims.reserve(array.ndim());
   for (size_t i = 0; i < array.ndim(); ++i) {
     dims.push_back((int)array.shape()[i]);
diff --git a/python/paddle/trainer/PyDataProvider2.py b/python/paddle/trainer/PyDataProvider2.py
index 7e305e2cd9..248da4ae8d 100644
--- a/python/paddle/trainer/PyDataProvider2.py
+++ b/python/paddle/trainer/PyDataProvider2.py
@@ -27,6 +27,14 @@ class SequenceType(object):
     SEQUENCE = 1
     SUB_SEQUENCE = 2
 
+    @classmethod
+    def tostring(cls, value):
+        for k in cls.__dict__:
+            if not k.startswith('__'):
+                if getattr(cls, k) == value:
+                    return cls.__name__ + '.' + k
+        return 'INVALID(' + str(value) + ')'
+
 
 # TODO(yuyang18): Add string data type here.
 class DataType(object):
@@ -35,6 +43,14 @@ class DataType(object):
     SparseValue = 2
     Index = 3
 
+    @classmethod
+    def tostring(cls, value):
+        for k in cls.__dict__:
+            if not k.startswith('__'):
+                if getattr(cls, k) == value:
+                    return cls.__name__ + '.' + k
+        return 'INVALID(' + str(value) + ')'
+
 
 class CacheType(object):
     NO_CACHE = 0  # No cache at all
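The `tostring` classmethods added above perform a reverse lookup from an enum-style value to its symbolic name; they are used by the `__repr__` added further down in this file. A hypothetical usage sketch, assuming the module is importable as elsewhere in this patch:

```python
from paddle.trainer.PyDataProvider2 import SequenceType, DataType

# Reverse lookup: value -> 'ClassName.MEMBER', with a fallback for unknown values.
assert SequenceType.tostring(SequenceType.SEQUENCE) == 'SequenceType.SEQUENCE'
assert DataType.tostring(DataType.Index) == 'DataType.Index'
assert DataType.tostring(42) == 'INVALID(42)'
```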
@@ -69,6 +85,26 @@ class InputType(object):
         self.seq_type = seq_type
         self.type = tp
 
+    def __repr__(self):
+        """
+        Return a human-readable representation like 'InputType(dim=25921,
+            seq_type=SequenceType.NO_SEQUENCE, type=DataType.Dense)'
+        """
+        repr_str = type(self).__name__
+        repr_str += '('
+        serialize_func_map = {
+            'dim': repr,
+            'seq_type': SequenceType.tostring,
+            'type': DataType.tostring
+        }
+        for idx, k in enumerate(self.__slots__):
+            if idx != 0:
+                repr_str += ', '
+            repr_str += (
+                k + '=' + serialize_func_map.get(k, repr)(getattr(self, k)))
+        repr_str += ')'
+        return repr_str
+
 
 def dense_slot(dim, seq_type=SequenceType.NO_SEQUENCE):
     """
diff --git a/python/paddle/v2/framework/op.py b/python/paddle/v2/framework/op.py
index e7e932f6fe..0349407a85 100644
--- a/python/paddle/v2/framework/op.py
+++ b/python/paddle/v2/framework/op.py
@@ -94,9 +94,14 @@ class OpDescCreationMethod(object):
                     new_attr.floats.extend(user_defined_attr)
                 elif attr.type == framework_pb2.STRINGS:
                     new_attr.strings.extend(user_defined_attr)
+                elif attr.type == framework_pb2.INT_PAIRS:
+                    for p in user_defined_attr:
+                        pair = new_attr.pairs.add()
+                        pair.first = p[0]
+                        pair.second = p[1]
                 else:
                     raise NotImplementedError("Not support attribute type " +
-                                              attr.type)
+                                              str(attr.type))
 
         return op_desc
 
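The `INT_PAIRS` branch above consumes the user-supplied attribute as an iterable of two-element pairs, producing one protobuf `pairs` entry per element. A minimal sketch of the accepted value shape and the mapping performed (the attribute value below is hypothetical, and protobuf is omitted):

```python
# A hypothetical INT_PAIRS attribute value: a list of (first, second) tuples.
user_defined_attr = [(0, 1), (2, 2), (1, 0)]

# Each tuple maps onto one pairs entry with `first` and `second` fields.
pairs = [{"first": p[0], "second": p[1]} for p in user_defined_attr]
assert pairs[0] == {"first": 0, "second": 1}
```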
diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt
index 661ebd8964..e0f77d7973 100644
--- a/python/paddle/v2/framework/tests/CMakeLists.txt
+++ b/python/paddle/v2/framework/tests/CMakeLists.txt
@@ -4,6 +4,7 @@ py_test(test_scope SRCS test_scope.py)
 
 py_test(test_tensor SRCS test_tensor.py)
 py_test(test_mul_op SRCS test_mul_op.py)
+py_test(test_cos_sim_op SRCS test_cos_sim_op.py)
 
 py_test(test_mean_op SRCS test_mean_op.py)
 
diff --git a/python/paddle/v2/framework/tests/gradient_checker.py b/python/paddle/v2/framework/tests/gradient_checker.py
index 518f828bac..fdb06b7988 100644
--- a/python/paddle/v2/framework/tests/gradient_checker.py
+++ b/python/paddle/v2/framework/tests/gradient_checker.py
@@ -36,13 +36,13 @@ def get_numeric_gradient(op,
                          in_place=False):
     """
     Get Numeric Gradient for an operator's input.
-    
-    :param op: C++ operator instance, could be an network 
-    :param input_values: The input variables. Should be an dictionary, key is 
+
+    :param op: C++ operator instance, could be a network
+    :param input_values: The input variables. Should be a dictionary, key is
     variable name. Value is numpy array.
-    :param output_name: The final output variable name. 
+    :param output_name: The final output variable name.
     :param input_to_check: The input variable need to get gradient.
-    :param delta: The perturbation value for numeric gradient method. The 
+    :param delta: The perturbation value for numeric gradient method. The
     smaller delta is, the more accurate result will get. But if that delta is
      too small, it could occur numerical stability problem.
     :param local_scope: The local scope used for get_numeric_gradient.
@@ -229,9 +229,9 @@ class GradientChecker(unittest.TestCase):
         """Use relative error for the comparison.
 
         :param numeric_grads: the numerical graidents.
-        :type numeric_grads: a list of numpy.array 
+        :type numeric_grads: a list of numpy.array
         :param analytic_grads: the analytical graidents.
-        :type analytic_grads: a list of numpy.array 
+        :type analytic_grads: a list of numpy.array
         :param name: the names of gradients, used to print for debug.
         :type names: a list of string
         :param msg_prefix: string info, used to print for debug.
@@ -286,6 +286,9 @@ class GradientChecker(unittest.TestCase):
         for no_grad in no_grad_set:
             if no_grad not in in_names:
                 raise ValueError("no_grad should be in in_names")
+            if no_grad in inputs_to_check:
+                raise ValueError("no_grad should not be in inputs_to_check")
+
         backward_op = core.Operator.backward(forward_op, no_grad_set)
 
         places = [core.CPUPlace()]
@@ -301,7 +304,6 @@ class GradientChecker(unittest.TestCase):
 
         check_names = [grad_var_name(name) for name in inputs_to_check]
         for place in places:
-            # get analytical gradients according to different device
             analytic_grads = self.__get_gradient(forward_op, backward_op,
                                                  input_vars, check_names, place)
             self.__assert_is_close(numeric_grads, analytic_grads, check_names,
diff --git a/python/paddle/v2/framework/tests/op_test_util.py b/python/paddle/v2/framework/tests/op_test_util.py
index 3bc05a0fec..a4899355b5 100644
--- a/python/paddle/v2/framework/tests/op_test_util.py
+++ b/python/paddle/v2/framework/tests/op_test_util.py
@@ -6,13 +6,13 @@ from paddle.v2.framework.op import Operator
 class OpTestMeta(type):
     """
     Operator Test ClassMeta.
-    
-    It injects `test_all` method into user's OperatorTest class, to make Python 
+
+    It injects a `test_all` method into the user's OperatorTest class, to make the Python
     unittest module run that method.
-    
+
     The `test_all` read what value is stored in `self`. It use self's values to
     create and run a operator, and check whether that op is OK or not.
-    
+
     See `test_add_two_op` for example usage.
     """
 
diff --git a/python/paddle/v2/framework/tests/test_cos_sim_op.py b/python/paddle/v2/framework/tests/test_cos_sim_op.py
new file mode 100644
index 0000000000..32013a7999
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_cos_sim_op.py
@@ -0,0 +1,60 @@
+import unittest
+import numpy as np
+from gradient_checker import GradientChecker, create_op
+from op_test_util import OpTestMeta
+
+
+class TestCosSimOp(unittest.TestCase):
+    __metaclass__ = OpTestMeta
+
+    def setUp(self):
+        self.type = "cos_sim"
+        self.inputs = {
+            'X': np.random.random((32, 64)).astype("float32"),
+            'Y': np.random.random((32, 64)).astype("float32")
+        }
+        expect_x_norm = np.linalg.norm(self.inputs['X'], axis=1)
+        expect_y_norm = np.linalg.norm(self.inputs['Y'], axis=1)
+        expect_out = (self.inputs['X'] * self.inputs['Y']).sum(axis=1) / \
+            expect_x_norm / expect_y_norm
+        self.outputs = {
+            'XNorm': np.expand_dims(expect_x_norm, 1),
+            'YNorm': np.expand_dims(expect_y_norm, 1),
+            'Out': np.expand_dims(expect_out, 1)
+        }
+
+
+class TestCosSimGradOp(GradientChecker):
+    def setUp(self):
+        self.op = create_op("cos_sim")
+        self.inputs = {
+            'X': np.random.random((10, 5)).astype("float32"),
+            'Y': np.random.random((10, 5)).astype("float32")
+        }
+
+    def test_cpu_gpu_compare(self):
+        self.compare_grad(self.op, self.inputs)
+
+    def test_normal(self):
+        self.check_grad(
+            self.op, self.inputs, ["X", "Y"], "Out", max_relative_error=0.05)
+
+    def test_ignore_x(self):
+        self.check_grad(
+            self.op,
+            self.inputs, ["Y"],
+            "Out",
+            max_relative_error=0.05,
+            no_grad_set={"X"})
+
+    def test_ignore_y(self):
+        self.check_grad(
+            self.op,
+            self.inputs, ["X"],
+            "Out",
+            max_relative_error=0.05,
+            no_grad_set={"Y"})
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/v2/framework/tests/test_mul_op.py b/python/paddle/v2/framework/tests/test_mul_op.py
index ee0d81a64e..b58e4266d1 100644
--- a/python/paddle/v2/framework/tests/test_mul_op.py
+++ b/python/paddle/v2/framework/tests/test_mul_op.py
@@ -16,16 +16,37 @@ class TestMulOp(unittest.TestCase):
         self.outputs = {'Out': np.dot(self.inputs['X'], self.inputs['Y'])}
 
 
-class MulGradOpTest(GradientChecker):
-    def test_mul(self):
-        op = create_op("mul")
-        inputs = {
+class TestMulGradOp(GradientChecker):
+    def setUp(self):
+        self.op = create_op("mul")
+        self.inputs = {
             'X': np.random.random((32, 84)).astype("float32"),
             'Y': np.random.random((84, 100)).astype("float32")
         }
+
+    def test_cpu_gpu_compare(self):
+        self.compare_grad(self.op, self.inputs)
+
+    def test_normal(self):
         # mul op will enlarge the relative error
         self.check_grad(
-            op, inputs, set(["X", "Y"]), "Out", max_relative_error=0.5)
+            self.op, self.inputs, ["X", "Y"], "Out", max_relative_error=0.5)
+
+    def test_ignore_x(self):
+        self.check_grad(
+            self.op,
+            self.inputs, ["Y"],
+            "Out",
+            max_relative_error=0.5,
+            no_grad_set={"X"})
+
+    def test_ignore_y(self):
+        self.check_grad(
+            self.op,
+            self.inputs, ["X"],
+            "Out",
+            max_relative_error=0.5,
+            no_grad_set={"Y"})
 
 
 # TODO(dzh,qijun) : mulgrad test case need transpose feature of blas library
diff --git a/python/paddle/v2/framework/tests/test_rowwise_add_op.py b/python/paddle/v2/framework/tests/test_rowwise_add_op.py
index 45d569da29..2ddb85e2e7 100644
--- a/python/paddle/v2/framework/tests/test_rowwise_add_op.py
+++ b/python/paddle/v2/framework/tests/test_rowwise_add_op.py
@@ -16,14 +16,22 @@ class TestRowwiseAddOp(unittest.TestCase):
         self.outputs = {'Out': np.add(self.inputs['X'], self.inputs['b'])}
 
 
-class RowwiseAddGradOpTest(GradientChecker):
-    def test_rowwise_add(self):
-        op = create_op("rowwise_add")
-        inputs = {
+class TestRowwiseAddGradOp(GradientChecker):
+    def setUp(self):
+        self.op = create_op("rowwise_add")
+        self.inputs = {
             "X": np.random.uniform(0.1, 1, [5, 10]).astype("float32"),
             "b": np.random.uniform(0.1, 1, [10]).astype("float32")
         }
-        self.check_grad(op, inputs, set(["X", "b"]), "Out")
+
+    def test_normal(self):
+        self.check_grad(self.op, self.inputs, ["X", "b"], "Out")
+
+    def test_ignore_b(self):
+        self.check_grad(self.op, self.inputs, ["X"], "Out", no_grad_set={"b"})
+
+    def test_ignore_x(self):
+        self.check_grad(self.op, self.inputs, ["b"], "Out", no_grad_set={"X"})
 
 
 if __name__ == '__main__':