From 734e87e55b00418aed0fac5a879b2704d62cf3ab Mon Sep 17 00:00:00 2001 From: yangyaming Date: Fri, 15 Dec 2017 20:08:55 +0800 Subject: [PATCH 01/11] Add python wrapper for lstm unit op. --- doc/api/v2/fluid/layers.rst | 11 +- python/paddle/v2/fluid/layers/nn.py | 112 +++++++++++++++++++- python/paddle/v2/fluid/tests/test_layers.py | 17 +++ 3 files changed, 132 insertions(+), 8 deletions(-) diff --git a/doc/api/v2/fluid/layers.rst b/doc/api/v2/fluid/layers.rst index 89e5fec13b..0ab36402fa 100644 --- a/doc/api/v2/fluid/layers.rst +++ b/doc/api/v2/fluid/layers.rst @@ -188,12 +188,6 @@ beam_search_decode :noindex: -lstm ---------- -.. autofunction:: paddle.v2.fluid.layers.lstm - :noindex: - - lod_rank_table --------- .. autofunction:: paddle.v2.fluid.layers.lod_rank_table @@ -300,3 +294,8 @@ conv2d_transpose .. autofunction:: paddle.v2.fluid.layers.conv2d_transpose :noindex: + +lstm_unit +--------- +.. autofunction:: paddle.v2.fluid.layers.lstm_unit + :noindex: diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py index bad7dbd84e..84e62d988c 100644 --- a/python/paddle/v2/fluid/layers/nn.py +++ b/python/paddle/v2/fluid/layers/nn.py @@ -5,12 +5,13 @@ All layers just related to the neural network. from ..layer_helper import LayerHelper from ..initializer import Normal, Constant from ..framework import Variable +from tensor import concat __all__ = [ 'fc', 'embedding', 'dynamic_lstm', 'gru_unit', 'linear_chain_crf', 'crf_decoding', 'cos_sim', 'cross_entropy', 'square_error_cost', 'accuracy', 'chunk_eval', 'sequence_conv', 'conv2d', 'sequence_pool', 'pool2d', - 'batch_norm', 'beam_search_decode', 'conv2d_transpose' + 'batch_norm', 'beam_search_decode', 'conv2d_transpose', 'lstm_unit' ] @@ -392,7 +393,7 @@ def chunk_eval(input, excluded_chunk_types=None, **kwargs): """ - This function computes and outputs the precision, recall and + This function computes and outputs the precision, recall and F1-score of chunk detection. """ helper = LayerHelper("chunk_eval", **kwargs) @@ -789,3 +790,110 @@ def conv2d_transpose(input, attrs=op_attr) return out + + +def lstm_unit(x_t, + hidden_t_prev, + cell_t_prev, + forget_bias=0.0, + main_program=None, + startup_program=None): + """Lstm unit layer. The equation of a lstm step is: + + .. math:: + + i_t & = \sigma(W_{x_i}x_{t} + W_{h_i}h_{t-1} + W_{c_i}c_{t-1} + b_i) + + f_t & = \sigma(W_{x_f}x_{t} + W_{h_f}h_{t-1} + W_{c_f}c_{t-1} + b_f) + + c_t & = f_tc_{t-1} + i_t tanh (W_{x_c}x_t+W_{h_c}h_{t-1} + b_c) + + o_t & = \sigma(W_{x_o}x_{t} + W_{h_o}h_{t-1} + W_{c_o}c_t + b_o) + + h_t & = o_t tanh(c_t) + + The inputs of lstm unit includes :math:`x_t`, :math:`h_{t-1}` and + :math:`c_{t-1}`. The implementation separates the linear transformation + and non-linear transformation apart. Here, we take :math:`i_t` as an + example. The linear transformation is applied by calling a `fc` layer and + the equation is: + + .. math:: + + L_{i_t} = W_{x_i}x_{t} + W_{h_i}h_{t-1} + W_{c_i}c_{t-1} + b_i + + The non-linear transformation is applied by calling `lstm_unit_op` and the + equation is: + + .. math:: + + i_t = \sigma(L_{i_t}) + + This layer has two outputs including :math:`o_t` and :math:`h_t`. + + Args: + x_t (Variable): The input value of current step. + hidden_t_prev (Variable): The hidden value of lstm unit. + cell_t_prev (Variable): The cell value of lstm unit. + forget_bias (float): The forget bias of lstm unit. + main_program (Program): The main program. + startup_program (Program): the startup program. + + Returns: + tuple: The cell value and hidden value of lstm unit. + + Raises: + ValueError: The ranks of **x_t**, **hidden_t_prev** and **cell_t_prev**\ + not be 2 or the 1st dimensions of **x_t**, **hidden_t_prev** \ + and **cell_t_prev** not be the same. + + Examples: + + .. code-block:: python + + x_t = fluid.layers.fc(input=x_t_data, size=10) + prev_hidden = fluid.layers.fc(input=prev_hidden_data, size=20) + prev_cell = fluid.layers.fc(input=prev_cell_data, size=30) + cell_value, hidden_value = fluid.layers.lstm_unit(x_t=x_t, + hidden_t_prev=prev_hidden, + cell_t_prev=prev_cell) + """ + helper = LayerHelper('lstm_unit', **locals()) + + if len(x_t.shape) != 2: + raise ValueError("Rank of x_t must be 2.") + + if len(hidden_t_prev.shape) != 2: + raise ValueError("Rank of hidden_t_prev must be 2.") + + if len(cell_t_prev.shape) != 2: + raise ValueError("Rank of cell_t_prev must be 2.") + + if x_t.shape[0] != hidden_t_prev.shape[0] or x_t.shape[ + 0] != cell_t_prev.shape[0]: + raise ValueError("The 1s dimension of x_t, hidden_t_prev and " + "cell_t_prev must be the same.") + + size = cell_t_prev.shape[1] + concat_out = concat( + input=[x_t, hidden_t_prev], + axis=1, + main_program=main_program, + startup_program=startup_program) + fc_out = fc(input=concat_out, + size=4 * size, + main_program=main_program, + startup_program=startup_program) + dtype = x_t.dtype + c = helper.create_tmp_variable(dtype) + h = helper.create_tmp_variable(dtype) + + helper.append_op( + type='lstm_unit', + inputs={"X": fc_out, + "C_prev": cell_t_prev}, + outputs={"C": c, + "H": h}, + attrs={"forget_bias": forget_bias}) + + return c, h diff --git a/python/paddle/v2/fluid/tests/test_layers.py b/python/paddle/v2/fluid/tests/test_layers.py index 9b88080158..468bd41285 100644 --- a/python/paddle/v2/fluid/tests/test_layers.py +++ b/python/paddle/v2/fluid/tests/test_layers.py @@ -161,6 +161,23 @@ class TestBook(unittest.TestCase): x=dat, label=lbl)) print(str(program)) + def test_lstm_unit(self): + program = Program() + with program_guard(program): + x_t_data = layers.data( + name='x_t_data', shape=[10, 10], dtype='float32') + x_t = layers.fc(input=x_t_data, size=10) + prev_hidden_data = layers.data( + name='prev_hidden_data', shape=[10, 20], dtype='float32') + prev_hidden = layers.fc(input=prev_hidden_data, size=20) + prev_cell_data = layers.data( + name='prev_cell', shape=[10, 30], dtype='float32') + prev_cell = layers.fc(input=prev_cell_data, size=30) + self.assertIsNotNone( + layers.lstm_unit( + x_t=x_t, hidden_t_prev=prev_hidden, cell_t_prev=prev_cell)) + print(str(program)) + if __name__ == '__main__': unittest.main() From a398e25d6ac786e14aa18be79438b8d2d1b191d0 Mon Sep 17 00:00:00 2001 From: yangyaming Date: Mon, 18 Dec 2017 20:09:36 +0800 Subject: [PATCH 02/11] Expose param_attr and bias_attr. --- paddle/operators/lstm_unit_op.cc | 5 ++++- python/paddle/v2/fluid/layers/nn.py | 9 +++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/paddle/operators/lstm_unit_op.cc b/paddle/operators/lstm_unit_op.cc index 18b9cdf2a3..b6eb33bafe 100644 --- a/paddle/operators/lstm_unit_op.cc +++ b/paddle/operators/lstm_unit_op.cc @@ -51,7 +51,10 @@ class LstmUnitOpMaker : public framework::OpProtoAndCheckerMaker { LstmUnitOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "FC input before the non-linear activation."); + AddInput("X", + "Lstm unit only applies non-linear activations, please make sure" + "that linear tranformation has already been applied to `X`. " + "Linear tranformation can be applied by adding a `fc` layer"); AddInput( "C_prev", "The cell state tensor of last time-step in the Lstm Unit operator."); diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py index 84e62d988c..1c101c62c2 100644 --- a/python/paddle/v2/fluid/layers/nn.py +++ b/python/paddle/v2/fluid/layers/nn.py @@ -5,6 +5,7 @@ All layers just related to the neural network. from ..layer_helper import LayerHelper from ..initializer import Normal, Constant from ..framework import Variable +from ..param_attr import ParamAttr from tensor import concat __all__ = [ @@ -796,6 +797,8 @@ def lstm_unit(x_t, hidden_t_prev, cell_t_prev, forget_bias=0.0, + param_attr=None, + bias_attr=ParamAttr(), main_program=None, startup_program=None): """Lstm unit layer. The equation of a lstm step is: @@ -836,6 +839,10 @@ def lstm_unit(x_t, hidden_t_prev (Variable): The hidden value of lstm unit. cell_t_prev (Variable): The cell value of lstm unit. forget_bias (float): The forget bias of lstm unit. + param_attr (ParamAttr): The attributes of parameter weights, used to set + initializer, name etc. + bias_attr (ParamAttr): The attributes of bias weights, used to set + initializer, name etc. main_program (Program): The main program. startup_program (Program): the startup program. @@ -882,6 +889,8 @@ def lstm_unit(x_t, startup_program=startup_program) fc_out = fc(input=concat_out, size=4 * size, + param_attr=param_attr, + bias_attr=bias_attr, main_program=main_program, startup_program=startup_program) dtype = x_t.dtype From 58d6946c874bbe539ace4fde05e7fb4693f30ca1 Mon Sep 17 00:00:00 2001 From: yangyaming Date: Tue, 19 Dec 2017 11:03:20 +0800 Subject: [PATCH 03/11] Set the act to 'linear'. --- python/paddle/v2/fluid/layers/nn.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py index 1c101c62c2..ab443826bd 100644 --- a/python/paddle/v2/fluid/layers/nn.py +++ b/python/paddle/v2/fluid/layers/nn.py @@ -891,6 +891,7 @@ def lstm_unit(x_t, size=4 * size, param_attr=param_attr, bias_attr=bias_attr, + act='linear', main_program=main_program, startup_program=startup_program) dtype = x_t.dtype From d993a4f58b7e2be4a76fda406e964229edff2dcb Mon Sep 17 00:00:00 2001 From: yangyaming Date: Tue, 19 Dec 2017 11:19:24 +0800 Subject: [PATCH 04/11] Change default value for bias_attr. --- python/paddle/v2/fluid/layers/nn.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py index 9728adba73..31a0a312db 100644 --- a/python/paddle/v2/fluid/layers/nn.py +++ b/python/paddle/v2/fluid/layers/nn.py @@ -866,7 +866,7 @@ def lstm_unit(x_t, cell_t_prev, forget_bias=0.0, param_attr=None, - bias_attr=ParamAttr(), + bias_attr=None, main_program=None, startup_program=None): """Lstm unit layer. The equation of a lstm step is: @@ -909,8 +909,8 @@ def lstm_unit(x_t, forget_bias (float): The forget bias of lstm unit. param_attr (ParamAttr): The attributes of parameter weights, used to set initializer, name etc. - bias_attr (ParamAttr): The attributes of bias weights, used to set - initializer, name etc. + bias_attr (ParamAttr): The attributes of bias weights, if not False, + bias weights will be created and be set to default value. main_program (Program): The main program. startup_program (Program): the startup program. @@ -949,6 +949,9 @@ def lstm_unit(x_t, raise ValueError("The 1s dimension of x_t, hidden_t_prev and " "cell_t_prev must be the same.") + if bias_attr is None: + bias_attr = ParamAttr() + size = cell_t_prev.shape[1] concat_out = concat( input=[x_t, hidden_t_prev], From 9ee9fefd2de46f2383309f489033fc6d94cd8628 Mon Sep 17 00:00:00 2001 From: yangyaming Date: Tue, 19 Dec 2017 11:27:35 +0800 Subject: [PATCH 05/11] Change the return order to h, c. --- python/paddle/v2/fluid/layers/nn.py | 8 ++++---- python/paddle/v2/fluid/tests/test_layers.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py index 31a0a312db..dd6bb54599 100644 --- a/python/paddle/v2/fluid/layers/nn.py +++ b/python/paddle/v2/fluid/layers/nn.py @@ -900,7 +900,7 @@ def lstm_unit(x_t, i_t = \sigma(L_{i_t}) - This layer has two outputs including :math:`o_t` and :math:`h_t`. + This layer has two outputs including :math:`h_t` and :math:`o_t`. Args: x_t (Variable): The input value of current step. @@ -915,7 +915,7 @@ def lstm_unit(x_t, startup_program (Program): the startup program. Returns: - tuple: The cell value and hidden value of lstm unit. + tuple: The hidden value and cell value of lstm unit. Raises: ValueError: The ranks of **x_t**, **hidden_t_prev** and **cell_t_prev**\ @@ -929,7 +929,7 @@ def lstm_unit(x_t, x_t = fluid.layers.fc(input=x_t_data, size=10) prev_hidden = fluid.layers.fc(input=prev_hidden_data, size=20) prev_cell = fluid.layers.fc(input=prev_cell_data, size=30) - cell_value, hidden_value = fluid.layers.lstm_unit(x_t=x_t, + hidden_value, cell_value = fluid.layers.lstm_unit(x_t=x_t, hidden_t_prev=prev_hidden, cell_t_prev=prev_cell) """ @@ -977,4 +977,4 @@ def lstm_unit(x_t, "H": h}, attrs={"forget_bias": forget_bias}) - return c, h + return h, c diff --git a/python/paddle/v2/fluid/tests/test_layers.py b/python/paddle/v2/fluid/tests/test_layers.py index 7b56ae464c..d4a95bf6fc 100644 --- a/python/paddle/v2/fluid/tests/test_layers.py +++ b/python/paddle/v2/fluid/tests/test_layers.py @@ -161,7 +161,7 @@ class TestBook(unittest.TestCase): x=dat, label=lbl)) print(str(program)) - def test_seq_expand(self): + def test_sequence_expand(self): program = Program() with program_guard(program): x = layers.data(name='x', shape=[10], dtype='float32') From fa5cdd8f74cecac9d5350a544aa1ea1de73772bd Mon Sep 17 00:00:00 2001 From: yangyaming Date: Tue, 19 Dec 2017 11:47:43 +0800 Subject: [PATCH 06/11] Expose sequence_softmax_op. --- doc/api/v2/fluid/layers.rst | 7 +++++++ python/paddle/v2/fluid/layers/ops.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/doc/api/v2/fluid/layers.rst b/doc/api/v2/fluid/layers.rst index 9f3669e115..cf4bf4afd2 100644 --- a/doc/api/v2/fluid/layers.rst +++ b/doc/api/v2/fluid/layers.rst @@ -304,3 +304,10 @@ sequence_expand --------- .. autofunction:: paddle.v2.fluid.layers.sequence_expand :noindex: + + +sequence_softmax +--------- +.. autofunction:: paddle.v2.fluid.layers.sequence_softmax + :noindex: + diff --git a/python/paddle/v2/fluid/layers/ops.py b/python/paddle/v2/fluid/layers/ops.py index fa312ace60..d2ff6841a3 100644 --- a/python/paddle/v2/fluid/layers/ops.py +++ b/python/paddle/v2/fluid/layers/ops.py @@ -2,7 +2,7 @@ from ..registry import register_layer __all__ = [ 'mean', 'mul', 'dropout', 'reshape', 'sigmoid', 'scale', 'transpose', 'sigmoid_cross_entropy_with_logits', 'elementwise_add', 'elementwise_div', - 'elementwise_sub', 'elementwise_mul', 'clip', 'abs' + 'elementwise_sub', 'elementwise_mul', 'clip', 'abs', 'sequence_softmax' ] for _OP in set(__all__): From 9573256f9d802dfe1daf9f6887044931ff03f636 Mon Sep 17 00:00:00 2001 From: yangyaming Date: Tue, 19 Dec 2017 13:24:12 +0800 Subject: [PATCH 07/11] Remove main_program and startup_program. --- python/paddle/v2/fluid/layers/nn.py | 21 ++++----------------- 1 file changed, 4 insertions(+), 17 deletions(-) diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py index 1d03f357eb..2c38c23224 100644 --- a/python/paddle/v2/fluid/layers/nn.py +++ b/python/paddle/v2/fluid/layers/nn.py @@ -764,7 +764,7 @@ def conv2d_transpose(input, return out -def sequence_expand(x, y, main_program=None, startup_program=None): +def sequence_expand(x, y): """Sequence Expand Layer. This layer will expand the input variable **x** according to LoD information of **y**. And the following examples will explain how sequence_expand works: @@ -808,8 +808,6 @@ def sequence_expand(x, y, main_program=None, startup_program=None): Args: x (Variable): The input variable which is a Tensor or LoDTensor. y (Variable): The input variable which is a LoDTensor. - main_program (Program): The main program. - startup_program (Program): The startup program. Returns: Variable: The expanded variable which is a LoDTensor. @@ -836,9 +834,7 @@ def lstm_unit(x_t, cell_t_prev, forget_bias=0.0, param_attr=None, - bias_attr=None, - main_program=None, - startup_program=None): + bias_attr=None): """Lstm unit layer. The equation of a lstm step is: .. math:: @@ -881,8 +877,6 @@ def lstm_unit(x_t, initializer, name etc. bias_attr (ParamAttr): The attributes of bias weights, if not False, bias weights will be created and be set to default value. - main_program (Program): The main program. - startup_program (Program): the startup program. Returns: tuple: The hidden value and cell value of lstm unit. @@ -923,18 +917,11 @@ def lstm_unit(x_t, bias_attr = ParamAttr() size = cell_t_prev.shape[1] - concat_out = concat( - input=[x_t, hidden_t_prev], - axis=1, - main_program=main_program, - startup_program=startup_program) + concat_out = concat(input=[x_t, hidden_t_prev], axis=1) fc_out = fc(input=concat_out, size=4 * size, param_attr=param_attr, - bias_attr=bias_attr, - act='linear', - main_program=main_program, - startup_program=startup_program) + bias_attr=bias_attr) dtype = x_t.dtype c = helper.create_tmp_variable(dtype) h = helper.create_tmp_variable(dtype) From 760d20de92dfb45e95aa2c3d8d86cb69b1ab5c56 Mon Sep 17 00:00:00 2001 From: yangyaming Date: Tue, 19 Dec 2017 15:19:26 +0800 Subject: [PATCH 08/11] Add test for sequence_softmax. --- python/paddle/v2/fluid/tests/test_layers.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/python/paddle/v2/fluid/tests/test_layers.py b/python/paddle/v2/fluid/tests/test_layers.py index d4a95bf6fc..9d2dcca56d 100644 --- a/python/paddle/v2/fluid/tests/test_layers.py +++ b/python/paddle/v2/fluid/tests/test_layers.py @@ -187,6 +187,15 @@ class TestBook(unittest.TestCase): x_t=x_t, hidden_t_prev=prev_hidden, cell_t_prev=prev_cell)) print(str(program)) + def test_sequence_softmax(self): + program = Program() + with program_guard(program): + seq_data = layers.data( + name='seq_data', shape=[10, 10], dtype='float32', lod_level=1) + seq = layers.fc(input=seq_data, size=20) + self.assertIsNotNone(layers.sequence_softmax(x=seq)) + print(str(program)) + if __name__ == '__main__': unittest.main() From 028604498d511658061b863de2fd88ccc26c71dc Mon Sep 17 00:00:00 2001 From: Luo Tao Date: Tue, 19 Dec 2017 15:36:15 +0800 Subject: [PATCH 09/11] update the link of doc.paddlepaddle.org in README.md --- README.md | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index ceeb6d9e51..577528e7aa 100644 --- a/README.md +++ b/README.md @@ -61,32 +61,32 @@ Please refer to our [release announcement](https://github.com/PaddlePaddle/Paddl ## Installation It is recommended to check out the -[Docker installation guide](http://doc.paddlepaddle.org/develop/doc/getstarted/build_and_install/docker_install_en.html) +[Docker installation guide](http://www.paddlepaddle.org/docs/develop/documentation/en/getstarted/build_and_install/docker_install_en.html) before looking into the -[build from source guide](http://doc.paddlepaddle.org/develop/doc/getstarted/build_and_install/build_from_source_en.html). +[build from source guide](http://www.paddlepaddle.org/docs/develop/documentation/en/getstarted/build_and_install/build_from_source_en.html). ## Documentation -We provide [English](http://doc.paddlepaddle.org/develop/doc/) and -[Chinese](http://doc.paddlepaddle.org/doc_cn/) documentation. +We provide [English](http://www.paddlepaddle.org/docs/develop/documentation/en/getstarted/index_en.html) and +[Chinese](http://www.paddlepaddle.org/docs/develop/documentation/zh/getstarted/index_cn.html) documentation. -- [Deep Learning 101](http://book.paddlepaddle.org/index.html) +- [Deep Learning 101](http://www.paddlepaddle.org/docs/develop/book/01.fit_a_line/index.html) You might want to start from this online interactive book that can run in a Jupyter Notebook. -- [Distributed Training](http://doc.paddlepaddle.org/develop/doc/howto/usage/cluster/cluster_train_en.html) +- [Distributed Training](http://www.paddlepaddle.org/docs/develop/documentation/en/howto/usage/cluster/cluster_train_en.html) You can run distributed training jobs on MPI clusters. -- [Distributed Training on Kubernetes](http://doc.paddlepaddle.org/develop/doc/howto/usage/k8s/k8s_en.html) +- [Distributed Training on Kubernetes](http://www.paddlepaddle.org/docs/develop/documentation/en/howto/usage/cluster/k8s_en.html) You can also run distributed training jobs on Kubernetes clusters. -- [Python API](http://doc.paddlepaddle.org/develop/doc/api/index_en.html) +- [Python API](http://www.paddlepaddle.org/docs/develop/documentation/en/api/index_en.html) Our new API enables much shorter programs. -- [How to Contribute](http://doc.paddlepaddle.org/develop/doc/howto/dev/contribute_to_paddle_en.html) +- [How to Contribute](http://www.paddlepaddle.org/docs/develop/documentation/en/howto/dev/contribute_to_paddle_en.html) We appreciate your contributions! From de85470d78014d89a64705dc10091aa94d112979 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Tue, 19 Dec 2017 16:53:09 +0800 Subject: [PATCH 10/11] Support Clip in param_attr (#6729) * Support Clip in param_attr * Fix the order of clip & regular Regular is not need to be clipped --- python/paddle/v2/fluid/__init__.py | 3 +- python/paddle/v2/fluid/clip.py | 61 +++++++++++++++++++ python/paddle/v2/fluid/framework.py | 3 + python/paddle/v2/fluid/optimizer.py | 5 ++ python/paddle/v2/fluid/param_attr.py | 9 ++- .../tests/book/test_recognize_digits_mlp.py | 4 +- 6 files changed, 81 insertions(+), 4 deletions(-) create mode 100644 python/paddle/v2/fluid/clip.py diff --git a/python/paddle/v2/fluid/__init__.py b/python/paddle/v2/fluid/__init__.py index 59986c9f0c..9b3792ee9e 100644 --- a/python/paddle/v2/fluid/__init__.py +++ b/python/paddle/v2/fluid/__init__.py @@ -16,12 +16,13 @@ import regularizer from param_attr import ParamAttr from data_feeder import DataFeeder from core import LoDTensor, CPUPlace, GPUPlace +import clip Tensor = LoDTensor __all__ = framework.__all__ + executor.__all__ + [ 'io', 'initializer', 'layers', 'nets', 'optimizer', 'backward', 'regularizer', 'LoDTensor', 'CPUPlace', 'GPUPlace', 'Tensor', 'ParamAttr' - 'DataFeeder' + 'DataFeeder', 'clip' ] diff --git a/python/paddle/v2/fluid/clip.py b/python/paddle/v2/fluid/clip.py new file mode 100644 index 0000000000..d7ec2fbe13 --- /dev/null +++ b/python/paddle/v2/fluid/clip.py @@ -0,0 +1,61 @@ +import functools +import layers + +__all__ = ['GradientClipByValue', 'append_gradient_clip_ops'] + + +class BaseGradientClipAttr(object): + def process_context(self, context, p_g): + raise NotImplementedError() + + def create_operators(self, param, grad): + raise NotImplementedError() + + +class NullGradientClipAttr(BaseGradientClipAttr): + def process_context(self, context, p_g): + pass + + def create_operators(self, param, grad): + return param, grad + + +class GradientClipByValue(BaseGradientClipAttr): + def __init__(self, max, min=None): + max = float(max) + if min is None: + min = -max + else: + min = float(min) + self.max = max + self.min = min + + def process_context(self, context, p_g): + pass + + def create_operators(self, param, grad): + new_grad = layers.clip(x=grad, min=self.min, max=self.max) + return param, new_grad + + +def append_gradient_clip_ops(param_grad): + context = dict() + create_op_callbacks = [] + for p, g in param_grad: + clip_attr = getattr(p, 'clip_attr', NullGradientClipAttr()) + if clip_attr is None: + clip_attr = NullGradientClipAttr() + if not isinstance(clip_attr, BaseGradientClipAttr): + raise TypeError( + "clip attribute should be an instance of BaseGradientClippingAttr" + ) + + clip_attr.process_context(context=context, p_g=param_grad) + create_op_callbacks.append( + functools.partial( + clip_attr.create_operators, param=p, grad=g)) + + return [each_callback() for each_callback in create_op_callbacks] + + +ClipByValue = GradientClipByValue diff --git a/python/paddle/v2/fluid/framework.py b/python/paddle/v2/fluid/framework.py index bf0cd275b6..973672e6e4 100644 --- a/python/paddle/v2/fluid/framework.py +++ b/python/paddle/v2/fluid/framework.py @@ -704,6 +704,7 @@ class Block(object): trainable=p.trainable, optimize_attr=p.optimize_attr, regularizer=p.regularizer, + clip_attr=p.clip_attr, name=v.name) self.vars[new_p.name] = new_p @@ -866,6 +867,8 @@ class Parameter(Variable): self.regularizer = kwargs.get('regularizer', None) + self.clip_attr = kwargs.get('clip_attr', None) + # program is a global instance. _main_program_ = Program() diff --git a/python/paddle/v2/fluid/optimizer.py b/python/paddle/v2/fluid/optimizer.py index 9f03eeea83..84fcbcdc2f 100644 --- a/python/paddle/v2/fluid/optimizer.py +++ b/python/paddle/v2/fluid/optimizer.py @@ -6,6 +6,7 @@ from framework import unique_name, program_guard from initializer import Constant from layer_helper import LayerHelper from regularizer import append_regularization_ops +from clip import append_gradient_clip_ops __all__ = ['SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad'] @@ -197,9 +198,13 @@ class Optimizer(object): `create_optimization_pass()` into one. """ params_grads = append_backward_ops(loss, parameter_list, no_grad_set) + + params_grads = append_gradient_clip_ops(params_grads) + # Add regularization if any params_grads = append_regularization_ops(params_grads, self.regularization) + optimize_ops = self.create_optimization_pass(params_grads, loss, startup_program) return optimize_ops diff --git a/python/paddle/v2/fluid/param_attr.py b/python/paddle/v2/fluid/param_attr.py index 7952a5ea51..f6f320c788 100644 --- a/python/paddle/v2/fluid/param_attr.py +++ b/python/paddle/v2/fluid/param_attr.py @@ -1,6 +1,8 @@ from initializer import Initializer, Xavier, Constant from regularizer import WeightDecayRegularizer +__all__ = ['ParamAttr'] + class ParamAttr(object): def __init__(self, @@ -8,12 +10,14 @@ class ParamAttr(object): initializer=None, learning_rate=1.0, regularizer=None, - trainable=True): + trainable=True, + clip=None): self.name = name self.initializer = initializer self.learning_rate = learning_rate self.regularizer = regularizer self.trainable = trainable + self.clip = clip def set_default_initializer(self, initializer): if initializer is None: @@ -56,7 +60,8 @@ class ParamAttr(object): 'name': self.name, 'learning_rate': self.learning_rate, 'regularizer': self.regularizer, - 'trainable': self.trainable + 'trainable': self.trainable, + 'clip_attr': self.clip } if with_initializer: kwargs['initializer'] = self.initializer diff --git a/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py b/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py index d77f19660e..fc073f6be8 100644 --- a/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py +++ b/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py @@ -11,7 +11,9 @@ regularizer = fluid.regularizer.L2Decay(0.0005 * BATCH_SIZE) hidden1 = fluid.layers.fc(input=image, size=128, act='relu', - param_attr=regularizer) + param_attr=fluid.ParamAttr( + regularizer=regularizer, + clip=fluid.clip.ClipByValue(10))) hidden2 = fluid.layers.fc(input=hidden1, size=64, act='relu', From 495259703c8c01b5dd24d25f4ce42c0fe0cd5882 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Tue, 19 Dec 2017 17:01:35 +0800 Subject: [PATCH 11/11] fix some doc errors --- doc/howto/usage/cluster/cluster_train_cn.md | 12 ++++++------ doc/howto/usage/cluster/cluster_train_en.md | 3 +-- doc/howto/usage/cluster/k8s_cn.md | 14 +++++++------- doc/howto/usage/cluster/k8s_en.md | 14 +++++++------- 4 files changed, 21 insertions(+), 22 deletions(-) diff --git a/doc/howto/usage/cluster/cluster_train_cn.md b/doc/howto/usage/cluster/cluster_train_cn.md index c9f90538a6..659bae9c0c 100644 --- a/doc/howto/usage/cluster/cluster_train_cn.md +++ b/doc/howto/usage/cluster/cluster_train_cn.md @@ -1,4 +1,4 @@ -# PaddlePaddle分布式训练 +# 分布式训练 ## 概述 @@ -181,8 +181,8 @@ PaddlePaddle可以使用多种分布式计算平台构建分布式计算任务 ## 在不同集群中运行 - - [fabric](fabric_cn.md) - - [openmpi](openmpi_cn.md) - - [kubernetes](k8s_cn.md) - - [kubernetes distributed](k8s_distributed_cn.md) - - [kubernetes on AWS](k8s_aws_cn.md) + - [fabric集群](fabric_cn.md) + - [openmpi集群](openmpi_cn.md) + - [kubernetes单机](k8s_cn.md) + - [kubernetes distributed分布式](k8s_distributed_cn.md) + - [AWS上运行kubernetes集群训练](k8s_aws_cn.md) diff --git a/doc/howto/usage/cluster/cluster_train_en.md b/doc/howto/usage/cluster/cluster_train_en.md index f9819470c0..915405ca5b 100644 --- a/doc/howto/usage/cluster/cluster_train_en.md +++ b/doc/howto/usage/cluster/cluster_train_en.md @@ -1,4 +1,4 @@ -# PaddlePaddle Distributed Training +# Distributed Training ## Introduction @@ -188,5 +188,4 @@ These cluster platforms provide API or environment variables for training proces - [fabric](fabric_en.md) - [openmpi](openmpi_en.md) - [kubernetes](k8s_en.md) - - kubernetes distributed - [kubernetes on AWS](k8s_aws_en.md) diff --git a/doc/howto/usage/cluster/k8s_cn.md b/doc/howto/usage/cluster/k8s_cn.md index ab07cb9cd5..9d49d0fa8c 100644 --- a/doc/howto/usage/cluster/k8s_cn.md +++ b/doc/howto/usage/cluster/k8s_cn.md @@ -1,16 +1,16 @@ # Kubernetes单机训练 -在这篇文档里,我们介绍如何在 Kubernetes 集群上启动一个单机使用CPU的Paddle训练作业。在下一篇中,我们将介绍如何启动分布式训练作业。 +在这篇文档里,我们介绍如何在 Kubernetes 集群上启动一个单机使用CPU的PaddlePaddle训练作业。在下一篇中,我们将介绍如何启动分布式训练作业。 ## 制作Docker镜像 -在一个功能齐全的Kubernetes机群里,通常我们会安装Ceph等分布式文件系统来存储训练数据。这样的话,一个分布式Paddle训练任务中的每个进程都可以从Ceph读取数据。在这个例子里,我们只演示一个单机作业,所以可以简化对环境的要求,把训练数据直接放在 -Paddle的Docker image里。为此,我们需要制作一个包含训练数据的Paddle镜像。 +在一个功能齐全的Kubernetes机群里,通常我们会安装Ceph等分布式文件系统来存储训练数据。这样的话,一个分布式PaddlePaddle训练任务中的每个进程都可以从Ceph读取数据。在这个例子里,我们只演示一个单机作业,所以可以简化对环境的要求,把训练数据直接放在 +PaddlePaddle的Docker image里。为此,我们需要制作一个包含训练数据的PaddlePaddle镜像。 -Paddle 的 [Quick Start Tutorial](http://www.paddlepaddle.org/doc/demo/quick_start/index_en.html) +Paddle 的 [Quick Start Tutorial](http://www.paddlepaddle.org/docs/develop/documentation/zh/getstarted/index_cn.html) 里介绍了用Paddle源码中的脚本下载训练数据的过程。 -而 `paddledev/paddle:cpu-demo-latest` 镜像里有 Paddle 源码与demo,( 请注意,默认的 -Paddle镜像 `paddledev/paddle:cpu-latest` 是不包括源码的, Paddle的各版本镜像可以参考 [Docker installation guide](http://www.paddlepaddle.org/doc/build/docker_install.html) ),所以我们使用这个镜像来下载训练数据到Docker container中,然后把这个包含了训练数据的container保存为一个新的镜像。 +而 `paddledev/paddle:cpu-demo-latest` 镜像里有 PaddlePaddle 源码与demo,( 请注意,默认的 +PaddlePaddle镜像 `paddledev/paddle:cpu-latest` 是不包括源码的, PaddlePaddle的各版本镜像可以参考 [Docker installation guide](http://www.paddlepaddle.org/doc/build/docker_install.html) ),所以我们使用这个镜像来下载训练数据到Docker container中,然后把这个包含了训练数据的container保存为一个新的镜像。 ### 运行容器 @@ -103,7 +103,7 @@ spec: restartPolicy: Never ``` -### 创建Paddle Job +### 创建PaddlePaddle Job 使用上文创建的yaml文件创建Kubernetes Job,命令为: diff --git a/doc/howto/usage/cluster/k8s_en.md b/doc/howto/usage/cluster/k8s_en.md index 0c3ab05b70..5a3ebfd8dc 100644 --- a/doc/howto/usage/cluster/k8s_en.md +++ b/doc/howto/usage/cluster/k8s_en.md @@ -1,13 +1,13 @@ -# Paddle On Kubernetes +# PaddlePaddle On Kubernetes ->In this article, we will introduce how to run Paddle training job on single CPU machine using Kubernetes. In next article, we will introduce how to run Paddle training job on distributed cluster. +In this article, we will introduce how to run PaddlePaddle training job on single CPU machine using Kubernetes. In next article, we will introduce how to run PaddlePaddle training job on distributed cluster. ## Build Docker Image -In distributed Kubernetes cluster, we will use Ceph or other shared storage system for storing training related data so that all processes in Paddle training can retrieve data from Ceph. In this example, we will only demo training job on single machine. In order to simplify the requirement of the environment, we will directly put training data into Paddle's Docker Image, so we need to create a Paddle Docker image that already includes the training data. +In distributed Kubernetes cluster, we will use Ceph or other shared storage system for storing training data so that all processes in the training job can retrieve data from Ceph. In this example, we will only demo training job on single machine. In order to simplify the requirement of the environment, we will directly put training data into PaddlePaddle's Docker Image, so we need to create a PaddlePaddle Docker image that already includes the training data. -Paddle's [Quick Start Tutorial](http://www.paddlepaddle.org/doc/demo/quick_start/index_en.html) introduces how to download and train data by using script from Paddle's source code. -And `paddledev/paddle:cpu-demo-latest` image has the Paddle source code and demo. (Caution: Default Paddle image `paddledev/paddle:cpu-latest` doesn't include the source code, Paddle's different versions of image can be referred here: [Docker installation guide](http://www.paddlepaddle.org/doc/build/docker_install.html)), so we run this container and download the training data, and then commit the whole container to be a new Docker image. +PaddlePaddle's [Quick Start Tutorial](http://www.paddlepaddle.org/docs/develop/documentation/en/getstarted/index_en.html) introduces how to download and train data by using script from PaddlePaddle's source code. +And `paddledev/paddle:cpu-demo-latest` image has the PaddlePaddle source code and demo. (Caution: Default PaddlePaddle image `paddledev/paddle:cpu-latest` doesn't include the source code, PaddlePaddle's different versions of image can be referred here: [Docker installation guide](http://www.paddlepaddle.org/doc/build/docker_install.html)), so we run this container and download the training data, and then commit the whole container to be a new Docker image. ### Run Docker Container @@ -67,7 +67,7 @@ $ docker commit quick_start_data mypaddle/paddle:quickstart ## Use Kubernetes For Training ->We will use Kubernetes job for training process, following steps shows how to do the training with Kubernetes. +We will use Kubernetes job for training process, following steps shows how to do the training with Kubernetes. ### Create Yaml Files @@ -99,7 +99,7 @@ spec: restartPolicy: Never ``` -### Start Paddle Job +### Start PaddlePaddle Job Using the above yaml file to start the Kubernetes job.