From 99e43d1d07aa6dc63dba09141529764c7db83198 Mon Sep 17 00:00:00 2001
From: liaogang <liaogang@baidu.com>
Date: Fri, 23 Dec 2016 13:20:42 +0800
Subject: [PATCH 1/5] Add c++11 build python binding package

---
 paddle/api/paddle_ld_flags.py |  7 +++++--
 paddle/setup.py.in            | 11 ++++-------
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/paddle/api/paddle_ld_flags.py b/paddle/api/paddle_ld_flags.py
index 7c8206e3fe..b4d27b1cc7 100644
--- a/paddle/api/paddle_ld_flags.py
+++ b/paddle/api/paddle_ld_flags.py
@@ -141,9 +141,12 @@ try:
 
         def c_flag(self):
             if self.with_coverage:
-                return ["-fprofile-arcs", "-ftest-coverage", "-O0", "-g"]
+                return [
+                    "-fprofile-arcs", "-ftest-coverage", "-O0", "-g",
+                    "-std=c++11"
+                ]
             else:
-                return None
+                return ["-std=c++11"]
 except ImportError:
 
     class PaddleLDFlag(object):
diff --git a/paddle/setup.py.in b/paddle/setup.py.in
index b4c38a41b8..464ad63286 100644
--- a/paddle/setup.py.in
+++ b/paddle/setup.py.in
@@ -30,8 +30,10 @@ is_lin = (system == 'linux')
 # The extra links will passed from COMAKE
 #   because generate paddle LDFLAGS is too complicated to do in setup.py
 #   it just read COMAKE generated LDFLAGS.
+extra_comps = []
 extra_links = []
 obj = api.paddle_ld_flags.PaddleLDFlag()
+extra_comps = obj.c_flag()
 ldflags = obj.ldflag_str()
 if ldflags is not None:
   extra_links.extend(ldflags.split(" "))
@@ -51,20 +53,15 @@ elif is_osx == True:
 
 include_dirs = [np.get_include(), "../"]    # include numpy and paddle.
 
-extra_c = obj.c_flag()
-
-attr=dict()
-if extra_c is not None:
-  attr["extra_compile_args"] = extra_c
-
 setup(name="py_paddle",
   version="@PADDLE_VERSION@",
   ext_modules=[
     Extension('py_paddle._swig_paddle',      # Build SWIG Extension.
        ['Paddle_wrap.cxx'],
+       language = "c++",
        include_dirs = include_dirs,
        extra_link_args = extra_links,
-       **attr
+       extra_compile_args = extra_comps
     )
   ],
   packages=['py_paddle'],

From c8d0791accb7fbceda308756e6271e12e233c063 Mon Sep 17 00:00:00 2001
From: liaogang <liaogang@baidu.com>
Date: Fri, 23 Dec 2016 13:21:48 +0800
Subject: [PATCH 2/5] Add common.h and remove DisableCopy and Typedefs

---
 .../image_classification/index_cn.md          | 205 ++++++++++++++++++
 .../image_classification/index_en.md          |   2 +-
 paddle/api/PaddleAPI.h                        |  34 ++-
 paddle/cuda/include/hl_base.h                 |  66 +++---
 paddle/gserver/dataproviders/DataProvider.h   |   2 +-
 .../gserver/layers/BatchNormalizationLayer.h  |   2 +
 paddle/gserver/layers/GruCompute.h            |   2 +-
 paddle/gserver/layers/LstmCompute.h           |   2 +-
 paddle/gserver/layers/MultinomialSampler.h    |   2 +-
 paddle/math/BaseMatrix.h                      |   2 +-
 paddle/math/Matrix.h                          |   2 +-
 paddle/math/TensorExpression.h                |   2 +-
 paddle/math/Vector.h                          |   2 +-
 paddle/parameter/ParallelParameter.h          |   2 +-
 paddle/parameter/Parameter.h                  |   2 +-
 paddle/parameter/ParameterUpdateFunctions.h   |   2 +-
 paddle/pserver/BaseClient.h                   |   2 +-
 paddle/pserver/ParameterClient2.h             |   2 +-
 paddle/pserver/ParameterServer2.h             |   2 +-
 paddle/utils/CpuId.h                          |   2 +-
 paddle/utils/DisableCopy.h                    |  23 --
 paddle/utils/Locks.h                          |   2 +-
 paddle/utils/Util.h                           |   3 +-
 paddle/utils/Version.h                        |   2 +-
 paddle/utils/{TypeDefs.h => common.h}         |  15 +-
 25 files changed, 277 insertions(+), 107 deletions(-)
 create mode 100644 doc/tutorials/image_classification/index_cn.md
 delete mode 100644 paddle/utils/DisableCopy.h
 rename paddle/utils/{TypeDefs.h => common.h} (71%)

diff --git a/doc/tutorials/image_classification/index_cn.md b/doc/tutorials/image_classification/index_cn.md
new file mode 100644
index 0000000000..87f465522a
--- /dev/null
+++ b/doc/tutorials/image_classification/index_cn.md
@@ -0,0 +1,205 @@
+图像分类教程
+==========
+
+在本教程中,我们将使用CIFAR-10数据集训练一个卷积神经网络,并使用这个神经网络来对图片进行分类。如下图所示,卷积神经网络可以辨识图片中的主体,并给出分类结果。
+<center>![Image Classification](./image_classification.png)</center>
+
+## 数据准备
+首先下载CIFAR-10数据集。下面是CIFAR-10数据集的官方网址:
+
+<https://www.cs.toronto.edu/~kriz/cifar.html>
+
+我们准备了一个脚本,可以用于从官方网站上下载CIFAR-10数据集,转为jpeg文件并存入特定的目录。使用这个脚本前请确认已经安装了pillow及相关依赖模块。可以参照下面的命令进行安装:
+
+1. 安装pillow
+
+```bash
+sudo apt-get install libjpeg-dev
+pip install pillow
+```
+
+2. 下载数据集
+
+```bash
+cd demo/image_classification/data/
+sh download_cifar.sh
+```
+
+CIFAR-10数据集包含60000张32x32的彩色图片。图片分为10类,每个类包含6000张。其中50000张图片作为训练集,10000张作为测试集。
+
+下图展示了所有的图片类别,每个类别中随机抽取了10张图片。
+<center>![Image Classification](./cifar.png)</center>
+
+脚本运行完成后,我们应当会得到一个名为cifar-out的文件夹,其下子文件夹的结构如下
+
+
+```
+train
+---airplane
+---automobile
+---bird
+---cat
+---deer
+---dog
+---frog
+---horse
+---ship
+---truck
+test
+---airplane
+---automobile
+---bird
+---cat
+---deer
+---dog
+---frog
+---horse
+---ship
+---truck
+```
+
+cifar-out下包含`train`和`test`两个文件夹,其中分别包含了CIFAR-10中的训练集和测试集。这两个文件夹下各自有10个子文件夹,每个子文件夹下存储相应分类的图片。将图片按照上述结构存储好之后,我们就可以着手对分类模型进行训练了。
+
+## 预处理
+数据下载之后,还需要进行预处理,将数据转换为Paddle的格式。我们可以通过如下命令进行预处理工作:
+
+```
+cd demo/image_classification/
+sh preprocess.sh
+```
+
+其中`preprocess.sh` 调用 `./demo/image_classification/preprocess.py` 对图片进行预处理
+```sh
+export PYTHONPATH=$PYTHONPATH:../../
+data_dir=./data/cifar-out
+python preprocess.py -i $data_dir -s 32 -c 1
+```
+
+`./demo/image_classification/preprocess.py` 使用如下参数:
+
+- `-i` 或 `--input` 给出输入数据所在路径;
+- `-s` 或 `--size` 给出图片尺寸;
+- `-c` 或 `--color` 标示图片是彩色图或灰度图
+
+## 模型训练
+在开始训练之前,我们需要先创建一个模型配置文件。下面我们给出了一个配置示例。**注意**,这里的列出的和`vgg_16_cifar.py`文件稍有差别,因为该文件可适用于预测。
+
+```python
+from paddle.trainer_config_helpers import *
+data_dir='data/cifar-out/batches/'
+meta_path=data_dir+'batches.meta'
+args = {'meta':meta_path, 'mean_img_size': 32,
+        'img_size': 32, 'num_classes': 10,
+        'use_jpeg': 1, 'color': "color"}
+define_py_data_sources2(train_list=data_dir+"train.list",
+                        test_list=data_dir+'test.list',
+                        module='image_provider',
+                        obj='processData',
+                        args=args)
+settings(
+    batch_size = 128,
+    learning_rate = 0.1 / 128.0,
+    learning_method = MomentumOptimizer(0.9),
+    regularization = L2Regularization(0.0005 * 128))
+
+img = data_layer(name='image', size=3*32*32)
+lbl = data_layer(name="label", size=10)
+# small_vgg is predined in trainer_config_helpers.network
+predict = small_vgg(input_image=img, num_channels=3)
+outputs(classification_cost(input=predict, label=lbl))
+```
+
+在第一行中我们载入用于定义网络的函数。
+```python
+from paddle.trainer_config_helpers import *
+```
+
+之后定义的`define_py_data_sources2`使用Python数据提供器,其中 `args`将在`image_provider.py`进行使用,该文件负责产生图片数据并传递给Paddle系统
+ - `meta`: 训练集平均值。
+ - `mean_img_size`: 平均特征图的高度及宽度。
+ - `img_size`:输入图片的高度及宽度。
+ - `num_classes`:类别个数。
+ - `use_jpeg`:处理过程中数据存储格式。
+ - `color`:标示是否为彩色图片。
+ 
+ `settings`用于设置训练算法。在下面的例子中,learning rate被设置为0.1除以batch size,而weight decay则为0.0005乘以batch size。
+ 
+ ```python
+settings(
+    batch_size = 128,
+    learning_rate = 0.1 / 128.0,
+    learning_method = MomentumOptimizer(0.9),
+    regularization = L2Regularization(0.0005 * 128)
+)
+```
+
+`small_vgg`定义了网络结构。这里我们使用的是一个小的VGG网络。关于VGG卷积神经网络的描述可以参考:[http://www.robots.ox.ac.uk/~vgg/research/very_deep/](http://www.robots.ox.ac.uk/~vgg/research/very_deep/)。
+```python
+# small_vgg is predined in trainer_config_helpers.network
+predict = small_vgg(input_image=img, num_channels=3)
+```
+配置创建完毕后,可以运行脚本train.sh来训练模型。
+
+```bash
+config=vgg_16_cifar.py
+output=./cifar_vgg_model
+log=train.log
+
+paddle train \
+--config=$config \
+--dot_period=10 \
+--log_period=100 \
+--test_all_data_in_one_period=1 \
+--use_gpu=1 \
+--save_dir=$output \
+2>&1 | tee $log
+
+python -m paddle.utils.plotcurve -i $log > plot.png
+```
+- 这里我们使用的是GPU模式进行训练。如果你没有GPU环境,可以设置`use_gpu=0`。
+- `./demo/image_classification/vgg_16_cifar.py`是网络和数据配置文件。各项参数的详细说明可以在命令行参数相关文档中找到。
+- 脚本`plotcurve.py`依赖于python的`matplotlib`模块。因此如果这个脚本运行失败,也许是因为需要安装`matplotlib`。
+在训练完成后,训练及测试误差曲线图会被`plotcurve.py`脚本保存在 `plot.png`中。下面是一个误差曲线图的示例:
+
+<center>![Training and testing curves.](./plot.png)</center>
+
+## 预测
+在训练完成后,模型及参数会被保存在路径`./cifar_vgg_model/pass-%05d`下。例如第300个pass的模型会被保存在`./cifar_vgg_model/pass-00299`。
+
+要对一个图片的进行分类预测,我们可以使用`predict.sh`,该脚本将输出预测分类的标签:
+
+```
+sh predict.sh
+```
+
+predict.sh:
+```
+model=cifar_vgg_model/pass-00299/
+image=data/cifar-out/test/airplane/seaplane_s_000978.png
+use_gpu=1
+python prediction.py $model $image $use_gpu
+```
+
+## 练习
+在CUB-200数据集上使用VGG模型训练一个鸟类图片分类模型。相关的鸟类数据集可以从如下地址下载,其中包含了200种鸟类的照片(主要来自北美洲)。
+
+<http://www.vision.caltech.edu/visipedia/CUB-200.html>
+
+
+
+
+## 细节探究
+### 卷积神经网络
+卷积神经网络是一种使用卷积层的前向神经网络,很适合构建用于理解图片内容的模型。一个典型的神经网络如下图所示:
+
+![Convolutional Neural Network](./lenet.png)
+
+一个卷积神经网络包含如下层:
+
+- 卷积层:通过卷积操作从图片或特征图中提取特征
+- 池化层:使用max-pooling对特征图下采样
+- 全连接层:使输入层到隐藏层的神经元是全部连接的。
+
+卷积神经网络在图片分类上有着惊人的性能,这是因为它发掘出了图片的两类重要信息:局部关联性质和空间不变性质。通过交替使用卷积和池化处理, 卷积神经网络能够很好的表示这两类信息。
+
+关于如何定义网络中的层,以及如何在层之间进行连接,请参考Layer文档。
diff --git a/doc/tutorials/image_classification/index_en.md b/doc/tutorials/image_classification/index_en.md
index 29cfc99702..60c81a6a53 100644
--- a/doc/tutorials/image_classification/index_en.md
+++ b/doc/tutorials/image_classification/index_en.md
@@ -147,7 +147,7 @@ for classification. A description of VGG network can be found here [http://www.r
 # small_vgg is predined in trainer_config_helpers.network
 predict = small_vgg(input_image=img, num_channels=3)
 ```
-After writing the config, we can train the model by running the script train.sh. Notice that the following script assumes the you run the script in the `./demo/image_classification` folder. If you run the script in a different folder, you need to change the paths of the scripts and the configuration files accordingly.
+After writing the config, we can train the model by running the script train.sh.
 
 ```bash
 config=vgg_16_cifar.py
diff --git a/paddle/api/PaddleAPI.h b/paddle/api/PaddleAPI.h
index 84a66719c3..5c4c25e770 100644
--- a/paddle/api/PaddleAPI.h
+++ b/paddle/api/PaddleAPI.h
@@ -20,15 +20,11 @@ limitations under the License. */
 #include <string>
 #include <vector>
 #include "paddle/utils/GlobalConstants.h"
-#include "paddle/utils/TypeDefs.h"
+#include "paddle/utils/common.h"
 
 /// Import PaddlePaddle's enumeration into global namespace.
 using namespace paddle::enumeration_wrapper;  // NOLINT
 
-#define DISABLE_COPY_AND_ASSIGN(classname) \
-  classname(const classname& other);       \
-  classname& operator=(const classname& other)
-
 /**
  * @brief Initialize paddle.
  *
@@ -102,7 +98,7 @@ const size_t NO_SPARSE_ID = -1UL;
 struct MatrixPrivate;
 class Matrix {
   Matrix();  // User Cannot Create Matrix.
-  DISABLE_COPY_AND_ASSIGN(Matrix);
+  DISABLE_COPY(Matrix);
   static Matrix* createByPaddleMatrixPtr(void* sharedPtr);
 
 public:
@@ -242,7 +238,7 @@ private:
 
 struct VectorPrivate;
 class Vector {
-  DISABLE_COPY_AND_ASSIGN(Vector);
+  DISABLE_COPY(Vector);
   Vector();
   static Vector* createByPaddleVectorPtr(void* ptr);
 
@@ -322,7 +318,7 @@ private:
 struct IVectorPrivate;
 class IVector {
   IVector();
-  DISABLE_COPY_AND_ASSIGN(IVector);
+  DISABLE_COPY(IVector);
   static IVector* createByPaddleVectorPtr(void* ptr);
 
 public:
@@ -402,7 +398,7 @@ struct ArgumentsPrivate;
 class Arguments {
 private:
   Arguments();  // Internal Create.
-  DISABLE_COPY_AND_ASSIGN(Arguments);
+  DISABLE_COPY(Arguments);
 
 public:
   /**
@@ -472,7 +468,7 @@ enum GradientMatchineCreateMode {
 
 struct ParameterConfigPrivate;
 class ParameterConfig {
-  DISABLE_COPY_AND_ASSIGN(ParameterConfig);
+  DISABLE_COPY(ParameterConfig);
   ParameterConfig();
 
   /**
@@ -502,7 +498,7 @@ private:
 
 struct OptimizationConfigPrivate;
 class OptimizationConfig {
-  DISABLE_COPY_AND_ASSIGN(OptimizationConfig);
+  DISABLE_COPY(OptimizationConfig);
   OptimizationConfig();
 
 public:
@@ -526,7 +522,7 @@ struct ParameterPrivate;
 class Parameter {
 private:
   Parameter();
-  DISABLE_COPY_AND_ASSIGN(Parameter);
+  DISABLE_COPY(Parameter);
 
 public:
   virtual ~Parameter();
@@ -568,7 +564,7 @@ struct ModelConfigPrivate;
 class ModelConfig {
 private:
   ModelConfig();
-  DISABLE_COPY_AND_ASSIGN(ModelConfig);
+  DISABLE_COPY(ModelConfig);
 
 public:
   virtual ~ModelConfig();
@@ -589,7 +585,7 @@ struct TrainerConfigPrivate;
 class TrainerConfig {
 private:
   TrainerConfig();
-  DISABLE_COPY_AND_ASSIGN(TrainerConfig);
+  DISABLE_COPY(TrainerConfig);
 
 public:
   virtual ~TrainerConfig();
@@ -629,7 +625,7 @@ public:
 
 struct ParameterTraverseCallbackPrivate;
 class ParameterTraverseCallback {
-  DISABLE_COPY_AND_ASSIGN(ParameterTraverseCallback);
+  DISABLE_COPY(ParameterTraverseCallback);
   ParameterTraverseCallback();
 
 public:
@@ -651,7 +647,7 @@ private:
  */
 struct ParameterOptimizerPrivate;
 class ParameterOptimizer {
-  DISABLE_COPY_AND_ASSIGN(ParameterOptimizer);
+  DISABLE_COPY(ParameterOptimizer);
   ParameterOptimizer();
 
 public:
@@ -688,7 +684,7 @@ struct GradientMachinePrivate;
 class GradientMachine {
 private:
   GradientMachine();
-  DISABLE_COPY_AND_ASSIGN(GradientMachine);
+  DISABLE_COPY(GradientMachine);
 
 public:
   virtual ~GradientMachine();
@@ -780,7 +776,7 @@ private:
   TrainerPrivate* m;
   Trainer();
   Trainer(TrainerConfig* optConfig, GradientMachine* gm);
-  DISABLE_COPY_AND_ASSIGN(Trainer);
+  DISABLE_COPY(Trainer);
 
 public:
   virtual ~Trainer();
@@ -846,7 +842,7 @@ public:
 
 struct SequenceGeneratorPrivate;
 class SequenceGenerator {
-  DISABLE_COPY_AND_ASSIGN(SequenceGenerator);
+  DISABLE_COPY(SequenceGenerator);
   SequenceGenerator();
 
 public:
diff --git a/paddle/cuda/include/hl_base.h b/paddle/cuda/include/hl_base.h
index 84c5f2d5c9..5b9884b786 100644
--- a/paddle/cuda/include/hl_base.h
+++ b/paddle/cuda/include/hl_base.h
@@ -16,7 +16,31 @@ limitations under the License. */
 #define HL_BASE_H_
 
 #include <cstddef>
-#include "paddle/utils/TypeDefs.h"
+
+#ifdef PADDLE_TYPE_DOUBLE
+#define HL_FLOAT_MAX 3.40282347e+38F
+#define HL_FLOAT_MIN 1.17549435e-38F
+using real = double;
+#else
+#define HL_FLOAT_MAX 1.7976931348623157e+308
+#define HL_FLOAT_MIN 2.2250738585072014e-308
+using real = float;
+#endif
+
+/**
+ * The maximum input value for exp, used to avoid overflow problem.
+ * currently only used for tanh function.
+ */
+#define EXP_MAX_INPUT 40.0
+
+/**
+ * @brief DIVUP(x, y) is similar to ceil(x / y).
+ * @note  For CUDA, DIVUP will be used to specify
+ *        the size of blockDim.
+ */
+#ifndef DIVUP
+#define DIVUP(x, y) (((x) + (y)-1) / (y))
+#endif
 
 /**
  * HPPL is an internal high performance parallel computing library
@@ -181,46 +205,6 @@ typedef struct {
   size_t nnz;
 } _hl_sparse_matrix_s, *hl_sparse_matrix_s;
 
-#ifndef PADDLE_TYPE_DOUBLE
-/**
- * HPPL data type: real (float or double)
- *
- * if real == float
- *
- * HL_FLOAT_MAX: 3.40282347e+38F
- *
- * HL_FLOAT_MIN: 1.17549435e-38F
- */
-#define HL_FLOAT_MAX 3.40282347e+38F
-/**
- * if real == double
- *
- * HL_FLOAT_MAX: 1.7976931348623157e+308
- *
- * HL_FLOAT_MIN: 2.2250738585072014e-308
- */
-#define HL_FLOAT_MIN 1.17549435e-38F
-#else
-#define HL_FLOAT_MAX 1.7976931348623157e+308
-#define HL_FLOAT_MIN 2.2250738585072014e-308
-#endif
-
-/**
- * The maximum input value for exp, used to avoid overflow problem.
- *
- * Currently only used for tanh function.
- */
-#define EXP_MAX_INPUT 40.0
-
-/**
- * @brief DIVUP(x, y) is similar to ceil(x / y).
- * @note  For CUDA, DIVUP will be used to specify
- *        the size of blockDim.
- */
-#ifndef DIVUP
-#define DIVUP(x, y) (((x) + (y)-1) / (y))
-#endif
-
 #ifdef __NVCC__
 
 #include "cuda_runtime.h"
diff --git a/paddle/gserver/dataproviders/DataProvider.h b/paddle/gserver/dataproviders/DataProvider.h
index 9b7f7e36ce..5f031fc7c0 100644
--- a/paddle/gserver/dataproviders/DataProvider.h
+++ b/paddle/gserver/dataproviders/DataProvider.h
@@ -34,8 +34,8 @@ limitations under the License. */
 #include "paddle/utils/Logging.h"
 #include "paddle/utils/Queue.h"
 #include "paddle/utils/ThreadLocal.h"
-#include "paddle/utils/TypeDefs.h"
 #include "paddle/utils/Util.h"
+#include "paddle/utils/common.h"
 
 namespace paddle {
 /**
diff --git a/paddle/gserver/layers/BatchNormalizationLayer.h b/paddle/gserver/layers/BatchNormalizationLayer.h
index 052c207732..195acbbfc5 100644
--- a/paddle/gserver/layers/BatchNormalizationLayer.h
+++ b/paddle/gserver/layers/BatchNormalizationLayer.h
@@ -58,6 +58,8 @@ protected:
   /// to batch, channels* imagePixels.
   void shrinkMat(const MatrixPtr& in, MatrixPtr& out);
 
+  void onPassEnd() { firstTest_ = true; }
+
   MatrixPtr tmpMat_, tmpGrad_;
   MatrixPtr expandedIn_, expandedOut_;
   MatrixPtr expandedInGrad_, expandedOutGrad_, inGrad_;
diff --git a/paddle/gserver/layers/GruCompute.h b/paddle/gserver/layers/GruCompute.h
index 42c0019319..a56af21317 100644
--- a/paddle/gserver/layers/GruCompute.h
+++ b/paddle/gserver/layers/GruCompute.h
@@ -16,7 +16,7 @@ limitations under the License. */
 
 #include "ModelConfig.pb.h"
 #include "hl_gpu.h"
-#include "paddle/utils/TypeDefs.h"
+#include "paddle/utils/common.h"
 
 namespace paddle {
 
diff --git a/paddle/gserver/layers/LstmCompute.h b/paddle/gserver/layers/LstmCompute.h
index 140a4c6ecf..0d65b4158e 100644
--- a/paddle/gserver/layers/LstmCompute.h
+++ b/paddle/gserver/layers/LstmCompute.h
@@ -16,7 +16,7 @@ limitations under the License. */
 
 #include "ModelConfig.pb.h"
 #include "hl_gpu.h"
-#include "paddle/utils/TypeDefs.h"
+#include "paddle/utils/common.h"
 
 namespace paddle {
 
diff --git a/paddle/gserver/layers/MultinomialSampler.h b/paddle/gserver/layers/MultinomialSampler.h
index 677b047029..b48073c80b 100644
--- a/paddle/gserver/layers/MultinomialSampler.h
+++ b/paddle/gserver/layers/MultinomialSampler.h
@@ -16,7 +16,7 @@ limitations under the License. */
 
 #include <memory>
 #include <random>
-#include "paddle/utils/TypeDefs.h"
+#include "paddle/utils/common.h"
 
 namespace paddle {
 
diff --git a/paddle/math/BaseMatrix.h b/paddle/math/BaseMatrix.h
index 2933c20fba..8f9bc9e823 100644
--- a/paddle/math/BaseMatrix.h
+++ b/paddle/math/BaseMatrix.h
@@ -16,7 +16,7 @@ limitations under the License. */
 #include <stdint.h>
 #include <cstddef>
 #include "TensorExpression.h"
-#include "paddle/utils/TypeDefs.h"
+#include "paddle/utils/common.h"
 
 namespace paddle {
 
diff --git a/paddle/math/Matrix.h b/paddle/math/Matrix.h
index 25ce09e346..bda863de38 100644
--- a/paddle/math/Matrix.h
+++ b/paddle/math/Matrix.h
@@ -27,7 +27,7 @@ limitations under the License. */
 #include "MemoryHandle.h"
 #include "Vector.h"
 #include "paddle/utils/ThreadLocal.h"
-#include "paddle/utils/TypeDefs.h"
+#include "paddle/utils/common.h"
 
 namespace paddle {
 
diff --git a/paddle/math/TensorExpression.h b/paddle/math/TensorExpression.h
index 9bd789e8c5..f3d60e4003 100644
--- a/paddle/math/TensorExpression.h
+++ b/paddle/math/TensorExpression.h
@@ -17,7 +17,7 @@ limitations under the License. */
 #include <cstddef>
 #include "hl_tensor_ops.h"
 #include "paddle/utils/Logging.h"
-#include "paddle/utils/TypeDefs.h"
+#include "paddle/utils/common.h"
 
 namespace paddle {
 
diff --git a/paddle/math/Vector.h b/paddle/math/Vector.h
index 8a24103bd4..b4347a70f8 100644
--- a/paddle/math/Vector.h
+++ b/paddle/math/Vector.h
@@ -22,7 +22,7 @@ limitations under the License. */
 #include "BaseMatrix.h"
 #include "MemoryHandle.h"
 #include "paddle/utils/Thread.h"
-#include "paddle/utils/TypeDefs.h"
+#include "paddle/utils/common.h"
 
 namespace paddle {
 
diff --git a/paddle/parameter/ParallelParameter.h b/paddle/parameter/ParallelParameter.h
index 417e386dc7..1ee220d2dc 100644
--- a/paddle/parameter/ParallelParameter.h
+++ b/paddle/parameter/ParallelParameter.h
@@ -28,7 +28,7 @@ limitations under the License. */
 #include "paddle/parameter/ParameterUpdateFunctions.h"
 #include "paddle/utils/Flags.h"
 #include "paddle/utils/Locks.h"
-#include "paddle/utils/TypeDefs.h"
+#include "paddle/utils/common.h"
 
 #include "ParameterConfig.pb.h"
 
diff --git a/paddle/parameter/Parameter.h b/paddle/parameter/Parameter.h
index 532c6770e5..e05137b315 100644
--- a/paddle/parameter/Parameter.h
+++ b/paddle/parameter/Parameter.h
@@ -29,8 +29,8 @@ limitations under the License. */
 #include "paddle/utils/GlobalConstants.h"
 #include "paddle/utils/Locks.h"
 #include "paddle/utils/ThreadLocal.h"
-#include "paddle/utils/TypeDefs.h"
 #include "paddle/utils/Util.h"
+#include "paddle/utils/common.h"
 
 namespace paddle {
 
diff --git a/paddle/parameter/ParameterUpdateFunctions.h b/paddle/parameter/ParameterUpdateFunctions.h
index 2d277e47e7..2cb3798717 100644
--- a/paddle/parameter/ParameterUpdateFunctions.h
+++ b/paddle/parameter/ParameterUpdateFunctions.h
@@ -15,7 +15,7 @@ limitations under the License. */
 #pragma once
 
 #include "paddle/math/Vector.h"
-#include "paddle/utils/TypeDefs.h"
+#include "paddle/utils/common.h"
 
 namespace paddle {
 
diff --git a/paddle/pserver/BaseClient.h b/paddle/pserver/BaseClient.h
index 262afafbe2..ccf05ae1ca 100644
--- a/paddle/pserver/BaseClient.h
+++ b/paddle/pserver/BaseClient.h
@@ -18,7 +18,7 @@ limitations under the License. */
 #include "paddle/math/Matrix.h"
 #include "paddle/pserver/ProtoServer.h"
 #include "paddle/utils/Queue.h"
-#include "paddle/utils/TypeDefs.h"
+#include "paddle/utils/common.h"
 
 namespace paddle {
 
diff --git a/paddle/pserver/ParameterClient2.h b/paddle/pserver/ParameterClient2.h
index eed71ccb43..70cfc6d700 100644
--- a/paddle/pserver/ParameterClient2.h
+++ b/paddle/pserver/ParameterClient2.h
@@ -26,8 +26,8 @@ limitations under the License. */
 #include "paddle/utils/Flags.h"
 #include "paddle/utils/Locks.h"
 #include "paddle/utils/Queue.h"
-#include "paddle/utils/TypeDefs.h"
 #include "paddle/utils/Util.h"
+#include "paddle/utils/common.h"
 
 #include "ParameterService.pb.h"
 
diff --git a/paddle/pserver/ParameterServer2.h b/paddle/pserver/ParameterServer2.h
index b0cf22e1fb..79d1eb97ff 100644
--- a/paddle/pserver/ParameterServer2.h
+++ b/paddle/pserver/ParameterServer2.h
@@ -32,7 +32,7 @@ limitations under the License. */
 #include "paddle/utils/Locks.h"
 #include "paddle/utils/Stat.h"
 #include "paddle/utils/ThreadLocal.h"
-#include "paddle/utils/TypeDefs.h"
+#include "paddle/utils/common.h"
 
 #include "ParameterService.pb.h"
 
diff --git a/paddle/utils/CpuId.h b/paddle/utils/CpuId.h
index 7a354da758..1218e8194c 100644
--- a/paddle/utils/CpuId.h
+++ b/paddle/utils/CpuId.h
@@ -11,7 +11,7 @@ limitations under the License. */
 
 #pragma once
 
-#include "DisableCopy.h"
+#include "common.h"
 
 namespace paddle {
 
diff --git a/paddle/utils/DisableCopy.h b/paddle/utils/DisableCopy.h
deleted file mode 100644
index 41de98bbde..0000000000
--- a/paddle/utils/DisableCopy.h
+++ /dev/null
@@ -1,23 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-
-/**
- * Disable copy macro.
- */
-#define DISABLE_COPY(CLASS_NAME)                \
-  CLASS_NAME(CLASS_NAME &&) = delete;           \
-  CLASS_NAME(const CLASS_NAME &other) = delete; \
-  CLASS_NAME &operator=(const CLASS_NAME &other) = delete
diff --git a/paddle/utils/Locks.h b/paddle/utils/Locks.h
index 0f922f3548..a21872e89e 100644
--- a/paddle/utils/Locks.h
+++ b/paddle/utils/Locks.h
@@ -19,7 +19,7 @@ limitations under the License. */
 #include <condition_variable>
 #include <mutex>
 
-#include "DisableCopy.h"
+#include "common.h"
 
 namespace paddle {
 
diff --git a/paddle/utils/Util.h b/paddle/utils/Util.h
index e5a89070f1..dc15ada586 100644
--- a/paddle/utils/Util.h
+++ b/paddle/utils/Util.h
@@ -26,12 +26,11 @@ limitations under the License. */
 #include <unordered_map>
 #include <vector>
 
-#include "DisableCopy.h"
 #include "Logging.h"
 #include "TrainerConfig.pb.h"
+#include "common.h"
 
 #include "Flags.h"
-#include "TypeDefs.h"
 #include "hl_gpu.h"
 
 /**
diff --git a/paddle/utils/Version.h b/paddle/utils/Version.h
index d1a07d9485..aa5df32438 100644
--- a/paddle/utils/Version.h
+++ b/paddle/utils/Version.h
@@ -15,7 +15,7 @@ limitations under the License. */
 #pragma once
 #include <stddef.h>
 #include <iostream>
-#include "TypeDefs.h"
+#include "common.h"
 
 namespace paddle {
 
diff --git a/paddle/utils/TypeDefs.h b/paddle/utils/common.h
similarity index 71%
rename from paddle/utils/TypeDefs.h
rename to paddle/utils/common.h
index c50a05e82d..3ff0b86947 100644
--- a/paddle/utils/TypeDefs.h
+++ b/paddle/utils/common.h
@@ -15,12 +15,19 @@ limitations under the License. */
 #pragma once
 
 namespace paddle {
+
+/**
+ * Disable copy macro.
+ */
+#define DISABLE_COPY(class_name)                \
+  class_name(class_name &&) = delete;           \
+  class_name(const class_name &other) = delete; \
+  class_name &operator=(const class_name &other) = delete
+
 #ifdef PADDLE_TYPE_DOUBLE
-typedef double real;
+using real = double;
 #else
-typedef float real;
+using real = float;
 #endif
 
 }  // namespace paddle
-
-using paddle::real;

From 224e5fcc77306705260d8f54f2994706cd8ee0ef Mon Sep 17 00:00:00 2001
From: wangyanfei01 <wangyanfei01@baidu.com>
Date: Sun, 25 Dec 2016 11:35:49 +0800
Subject: [PATCH 3/5] fix bug:  * gradient_clipping_threshold should be allowed
 to set with parameter-grain

---
 python/paddle/trainer_config_helpers/attrs.py | 40 ++++++++++++-------
 1 file changed, 25 insertions(+), 15 deletions(-)

diff --git a/python/paddle/trainer_config_helpers/attrs.py b/python/paddle/trainer_config_helpers/attrs.py
index 59bb18bfca..bf02088346 100644
--- a/python/paddle/trainer_config_helpers/attrs.py
+++ b/python/paddle/trainer_config_helpers/attrs.py
@@ -19,34 +19,34 @@ __all__ = [
 
 
 def convert_and_compare(x, Type):
-    """                                                                                                                                                                                                
-    Convert x to be the same type as Type and then convert back to                                                                                                                                      
-    check whether there is a loss of information                                                                                                                                                        
-    :param x: object to be checked                                                                                                                                                                      
-    :param Type: target type to check x over                                                                                                                                                           
-    
+    """
+    Convert x to be the same type as Type and then convert back to
+    check whether there is a loss of information
+    :param x: object to be checked
+    :param Type: target type to check x over
+
     """
     return type(x)(Type(x)) == x
 
 
 def is_compatible_with(x, Type):
-    """                                                                                                                                                                                                
-    Check if x has a type compatible with Type                                                                                                                                                         
-    :param x: object to be checked                                                                                                                                                                     
-    :param Type: target type to check x over                                                                                                                                                           
-    
+    """
+    Check if x has a type compatible with Type
+    :param x: object to be checked
+    :param Type: target type to check x over
+
     """
     if type(x) == Type:
         return True
     try:
         if float == Type or int == Type:
-            # avoid those types that can be converted to float/int but not very                                                                                                                            
-            # meaningful and  could potentially lead to error                                                                                                                                              
-            # i.e., str and bool typed value should not be used for initializing float/int variable                                                                                                        
+            # avoid those types that can be converted to float/int but not very
+            # meaningful and  could potentially lead to error
+            # i.e., str and bool typed value should not be used for initializing float/int variable
             if not isinstance(x, str) and not isinstance(x, bool):
                 return convert_and_compare(x, Type)
         elif bool == Type:
-            # should not use string type to initialize bool variable                                                                                                                                   
+            # should not use string type to initialize bool variable
             if not isinstance(x, str):
                 return convert_and_compare(x, Type)
         else:
@@ -88,6 +88,10 @@ class ParameterAttribute(object):
     :type learning_rate: float or None
     :param momentum: The parameter momentum. None means use global value.
     :type momentum: float or None
+    :param gradient_clipping_threshold: gradient clipping threshold. If gradient
+                                        value larger than some value, will be
+                                        clipped.
+    :type gradient_clipping_threshold: float
     :param sparse_update: Enable sparse update for this parameter. It will
                           enable both local and remote sparse update.
     :type sparse_update: bool
@@ -104,6 +108,7 @@ class ParameterAttribute(object):
                  l2_rate=None,
                  learning_rate=None,
                  momentum=None,
+                 gradient_clipping_threshold=None,
                  sparse_update=False):
         # initialize strategy.
         if is_static:
@@ -152,6 +157,11 @@ class ParameterAttribute(object):
             self.attr['sparse_update'] = True
             self.attr['sparse_remote_update'] = True
 
+        if gradient_clipping_threshold is not None and \
+                is_compatible_with(gradient_clipping_threshold, float):
+            self.attr['gradient_clipping_threshold'] = \
+                gradient_clipping_threshold
+
     def set_default_parameter_name(self, name):
         """
         Set default parameter name. If parameter not set, then will use default

From 027aaf9ef26ad89b33a2d094cb8196926a911cc2 Mon Sep 17 00:00:00 2001
From: qiaolongfei <qiaolongfei@baidu.com>
Date: Sun, 25 Dec 2016 19:32:36 +0800
Subject: [PATCH 4/5] add cluster train for quick_start

---
 demo/quick_start/api_predict.sh           |  2 +-
 demo/quick_start/cluster/cluster_train.sh | 44 +++++++++++++++++++++++
 demo/quick_start/cluster/env.sh           | 28 +++++++++++++++
 demo/quick_start/cluster/pserver.sh       | 26 ++++++++++++++
 paddle/trainer/ThreadParameterUpdater.h   |  4 +--
 5 files changed, 101 insertions(+), 3 deletions(-)
 create mode 100755 demo/quick_start/cluster/cluster_train.sh
 create mode 100644 demo/quick_start/cluster/env.sh
 create mode 100755 demo/quick_start/cluster/pserver.sh

diff --git a/demo/quick_start/api_predict.sh b/demo/quick_start/api_predict.sh
index c90d3b7054..4d9aa9e885 100755
--- a/demo/quick_start/api_predict.sh
+++ b/demo/quick_start/api_predict.sh
@@ -17,7 +17,7 @@ set -e
 #Note the default model is pass-00002, you shold make sure the model path
 #exists or change the mode path.
 #only test on trainer_config.lr.py
-model=output/pass-00001/
+model=output/model/pass-00001/
 config=trainer_config.lr.py
 label=data/labels.list
 dict=data/dict.txt
diff --git a/demo/quick_start/cluster/cluster_train.sh b/demo/quick_start/cluster/cluster_train.sh
new file mode 100755
index 0000000000..aac9b89b14
--- /dev/null
+++ b/demo/quick_start/cluster/cluster_train.sh
@@ -0,0 +1,44 @@
+#!/bin/bash
+# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+set -e
+
+# Should run pserver.sh before run this script.
+bin_dir=$(cd `dirname $0`; pwd)
+home_dir=$(cd "${bin_dir}/.."; pwd)
+source "$bin_dir/env.sh"
+
+model_dir="$bin_dir/output"
+log_file="$bin_dir/train.log"
+
+pushd "$home_dir"
+cfg=trainer_config.lr.py
+paddle train \
+  --config=$cfg \
+  --save_dir=${model_dir} \
+  --trainer_count=4 \
+  --local=0 \
+  --log_period=100 \
+  --num_passes=15 \
+  --use_gpu=false \
+  --show_parameter_stats_period=100 \
+  --test_all_data_in_one_period=1 \
+  --num_gradient_servers=1 \
+  --nics=`get_nics` \
+  --port=7164 \
+  --ports_num=1 \
+  --pservers="127.0.0.1" \
+  --comment="paddle_trainer" \
+  2>&1 | tee "$log_file"
+popd
diff --git a/demo/quick_start/cluster/env.sh b/demo/quick_start/cluster/env.sh
new file mode 100644
index 0000000000..a404993835
--- /dev/null
+++ b/demo/quick_start/cluster/env.sh
@@ -0,0 +1,28 @@
+#!/bin/bash
+# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+set -e
+
+function get_nics() {
+  machine=`uname -s`
+  local nics=""
+  if [ "$machine" == "Linux" ]; then
+    nics="lo"
+  elif [ "$machine" == "Darwin" ]; then
+    nics="lo0"
+  else
+    nics="unsupport"
+  fi
+  echo $nics
+}
diff --git a/demo/quick_start/cluster/pserver.sh b/demo/quick_start/cluster/pserver.sh
new file mode 100755
index 0000000000..b187c1d9b9
--- /dev/null
+++ b/demo/quick_start/cluster/pserver.sh
@@ -0,0 +1,26 @@
+#!/bin/bash
+# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+set -e
+bin_dir=$(cd `dirname $0`; pwd)
+source "$bin_dir/env.sh"
+
+paddle pserver \
+  --nics=`get_nics` \
+  --port=7164 \
+  --ports_num=1 \
+  --ports_num_for_sparse=1 \
+  --num_gradient_servers=1 \
+  --comment="paddle_pserver" \
+  2>&1 | tee 'pserver.log'
diff --git a/paddle/trainer/ThreadParameterUpdater.h b/paddle/trainer/ThreadParameterUpdater.h
index 880f1f9ddc..bc08a9e9f0 100644
--- a/paddle/trainer/ThreadParameterUpdater.h
+++ b/paddle/trainer/ThreadParameterUpdater.h
@@ -33,8 +33,8 @@ namespace paddle {
    because at the current moment, the merging on CPU is happening on the
    main thread, and the its parameter size can be much larger than the one GPU.
    Thus, for GPU, the parameter updates happens in updateImpl() function, which
-   is called by gradient machines as a callback function as a callback function
-   supplied to backward() and forwardBackward().
+   is called by gradient machines as a callback function supplied to backward()
+   and forwardBackward().
    For CPU, the parameter updates happens in separate threads maintained by this
    class.
  */

From 685299c3c54bb8fc10c3d38cb26445d899f32c5d Mon Sep 17 00:00:00 2001
From: Yu Yang <yuyang18@baidu.com>
Date: Mon, 26 Dec 2016 14:38:31 +0800
Subject: [PATCH 5/5] Rename math.py to layer_math.py

* Fix #903
---
 python/paddle/trainer_config_helpers/__init__.py              | 4 +---
 .../paddle/trainer_config_helpers/{math.py => layer_math.py}  | 0
 2 files changed, 1 insertion(+), 3 deletions(-)
 rename python/paddle/trainer_config_helpers/{math.py => layer_math.py} (100%)

diff --git a/python/paddle/trainer_config_helpers/__init__.py b/python/paddle/trainer_config_helpers/__init__.py
index a2335768b9..0ff5edf825 100644
--- a/python/paddle/trainer_config_helpers/__init__.py
+++ b/python/paddle/trainer_config_helpers/__init__.py
@@ -20,6 +20,4 @@ from layers import *
 from networks import *
 from optimizers import *
 from attrs import *
-
-# This will enable operator overload for LayerOutput
-import math as layer_math
+import layer_math
diff --git a/python/paddle/trainer_config_helpers/math.py b/python/paddle/trainer_config_helpers/layer_math.py
similarity index 100%
rename from python/paddle/trainer_config_helpers/math.py
rename to python/paddle/trainer_config_helpers/layer_math.py