From dfc5d1f19abe241e1a8e5c1f6bcf26e09d4f0540 Mon Sep 17 00:00:00 2001
From: caoying03
Date: Thu, 16 Nov 2017 17:46:59 +0800
Subject: [PATCH 1/7] add the l2 distance layer.

---
 paddle/gserver/layers/L2DistanceLayer.cpp | 92 +++++++++++++++++++++++
 paddle/gserver/layers/L2DistanceLayer.h   | 53 +++++++++++++
 paddle/gserver/tests/test_LayerGrad.cpp   | 20 +++++
 3 files changed, 165 insertions(+)
 create mode 100644 paddle/gserver/layers/L2DistanceLayer.cpp
 create mode 100644 paddle/gserver/layers/L2DistanceLayer.h

diff --git a/paddle/gserver/layers/L2DistanceLayer.cpp b/paddle/gserver/layers/L2DistanceLayer.cpp
new file mode 100644
index 0000000000..e76e29cbe5
--- /dev/null
+++ b/paddle/gserver/layers/L2DistanceLayer.cpp
@@ -0,0 +1,92 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "L2DistanceLayer.h"
+#include "paddle/utils/Logging.h"
+#include "paddle/utils/Stat.h"
+
+namespace paddle {
+
+REGISTER_LAYER(l2_distance, L2DistanceLayer);
+
+bool L2DistanceLayer::init(const LayerMap& layerMap,
+                           const ParameterMap& parameterMap) {
+  /* Initialize the basic parent class */
+  Layer::init(layerMap, parameterMap);
+
+  CHECK_EQ(inputLayers_.size(), 2UL) << "The L2 distance layer accepts two and "
+                                     << "only two inputs.";
+  CHECK_EQ(getSize(), 1UL) << "The output dimensionality of L2 distance"
+                           << "is fixed to be 1.";
+
+  return true;
+}
+
+void L2DistanceLayer::forward(PassType passType) {
+  Layer::forward(passType);
+
+  const auto inV1 = getInputValue(0);
+  const auto inV2 = getInputValue(1);
+
+  CHECK(inV1 && inV2);
+  CHECK_EQ(inV1->getHeight(), inV2->getHeight())
+      << "The height of two inputs to this layer must be the same.";
+  CHECK_EQ(inV1->getWidth(), inV2->getWidth())
+      << "The width of two inputs to this layer must be the same.";
+
+  int batchSize = inV1->getHeight();
+  int output_dim = getSize();
+  {
+    REGISTER_TIMER_INFO("L2DistanceBpAtvTimer", getName().c_str());
+    reserveOutput(batchSize, output_dim);
+    auto outV = getOutputValue();
+    CHECK(outV) << "The output matrix should not be null.";
+
+    Matrix::resizeOrCreate(
+        inputSub_, inV1->getHeight(), inV1->getWidth(), false, useGpu_);
+
+    inputSub_->assign(*inV1);
+    inputSub_->sub(*inV2);
+    outV->sumOfProducts(*inputSub_, *inputSub_, 1, 0);
+    outV->sqrt2(*outV);
+  }
+}
+
+void L2DistanceLayer::backward(const UpdateCallback& callback) {
+  const auto outG = getOutputGrad();
+  const auto outV = getOutputValue();
+  const auto inV1 = getInputValue(0);
+  const auto inV2 = getInputValue(1);
+  auto inGrad1 = getInputGrad(0);
+  auto inGrad2 = getInputGrad(1);
+  CHECK(outG && outV && inV1 && inV2 && inGrad1 && inGrad2);
+
+  {
+    REGISTER_TIMER_INFO("L2DistanceBpAtvTimer", getName().c_str());
+
+    outV->scalarDiv(*outV, 1.);
+    outV->dotMul(*outG, *outV);
+
+    if (inGrad1) {
+      inGrad1->addRowScale(0, *inputSub_, *outV);
+    }
+
+    if (inGrad2) {
+      inputSub_->mulScalar(-1.);
+      inGrad2->addRowScale(0, *inputSub_, *outV);
+    }
+  }
+}
+
+}  // namespace paddle
diff --git a/paddle/gserver/layers/L2DistanceLayer.h b/paddle/gserver/layers/L2DistanceLayer.h
new file mode 100644
index 0000000000..64731db2bf
--- /dev/null
+++ b/paddle/gserver/layers/L2DistanceLayer.h
@@ -0,0 +1,53 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "Layer.h"
+#include "paddle/math/Matrix.h"
+#include "paddle/utils/ThreadLocal.h"
+
+namespace paddle {
+
+/**
+ * @brief A layer for calculating l2 distance between the two input vectors.
+ * \f[
+ * f(\bf{x}, \bf{y}) = \sqrt{\sum_{i=1}^D(x_i - y_i)^2}
+ * \f]
+ *
+ * - Input1: A vector (batchSize * dataDim)
+ * - Input2: A vector (batchSize * dataDim)
+ * - Output: A vector (batchSize * 1)
+ *
+ * The config file api is l2_distance.
+ */
+
+class L2DistanceLayer : public Layer {
+public:
+  explicit L2DistanceLayer(const LayerConfig& config) : Layer(config) {}
+
+  ~L2DistanceLayer() {}
+
+  bool init(const LayerMap& layerMap,
+            const ParameterMap& parameterMap) override;
+
+  void forward(PassType passType) override;
+  void backward(const UpdateCallback& callback = nullptr) override;
+
+private:
+  // Store result of subtracting Input2 from Input1.
+  MatrixPtr inputSub_;
+};
+
+}  // namespace paddle
diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp
index 3517d293e3..18f8d602b2 100644
--- a/paddle/gserver/tests/test_LayerGrad.cpp
+++ b/paddle/gserver/tests/test_LayerGrad.cpp
@@ -583,6 +583,7 @@ TEST(Layer, maxoutLayer) {
     testLayerGrad(config, "maxout", 10, false, useGpu);
   }
 }
+
 void testFcLayer(string format, size_t nnz) {
   TestConfig config;
   config.biasSize = 1024;
@@ -2429,6 +2430,25 @@ TEST(Layer, ScaleSubRegionLayer) {
   }
 }
 
+TEST(Layer, L2DistanceLayer) {
+  TestConfig config;
+  config.layerConfig.set_type("l2_distance");
+  config.layerConfig.set_size(1);
+  config.biasSize = 0;
+
+  const size_t input_dim = 27;
+  const size_t batch_size = 11;
+
+  config.inputDefs.push_back({INPUT_DATA, "layer_0", input_dim, 0});
+  config.inputDefs.push_back({INPUT_DATA, "layer_1", input_dim, 0});
+  config.layerConfig.add_inputs();
+  config.layerConfig.add_inputs();
+
+  for (auto useGpu : {false, true}) {
+    testLayerGrad(config, "l2_distance", batch_size, false, useGpu);
+  }
+}
+
 int main(int argc, char** argv) {
   testing::InitGoogleTest(&argc, argv);
   initMain(argc, argv);

From 4772b78ced64e8c0382d6ccf2f2ccdfa9022c098 Mon Sep 17 00:00:00 2001
From: caoying03
Date: Fri, 17 Nov 2017 15:04:49 +0800
Subject: [PATCH 2/7] add config_helper.
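
Add the Python side for the new layer: an l2_distance config layer in
config_parser.py, the l2_distance_layer wrapper in trainer_config_helpers,
a doc entry, and a config test. A minimal usage sketch (it mirrors the test
config added in this patch; the layer names and size=128 are simply the
values used there):

    from paddle.trainer_config_helpers import *

    # The two inputs must have the same dimensionality; the layer emits
    # one scalar distance per sample in a batch.
    x = data_layer(name='x', size=128)
    y = data_layer(name='y', size=128)
    outputs(l2_distance_layer(x=x, y=y))

For reference, the backward pass follows from dist = sqrt(sum_i (x_i - y_i)^2),
whose gradients are d(dist)/dx_i = (x_i - y_i) / dist and
d(dist)/dy_i = -(x_i - y_i) / dist; this is why inputSub_ (x - y) is cached in
the forward pass and only scaled by outG / outV in the backward pass.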
---
 doc/api/v2/config/layer.rst                  |  5 ++
 paddle/gserver/layers/L2DistanceLayer.cpp    | 23 +++++----
 paddle/gserver/layers/L2DistanceLayer.h      |  9 ++--
 python/paddle/trainer/config_parser.py       | 38 +++++++-----
 .../paddle/trainer_config_helpers/layers.py  | 49 ++++++++++++++++++-
 .../tests/configs/file_list.sh               |  3 +-
 .../protostr/test_l2_distance_layer.protostr | 39 +++++++++++++++
 .../tests/configs/test_l2_distance_layer.py  |  7 +++
 8 files changed, 142 insertions(+), 31 deletions(-)
 create mode 100644 python/paddle/trainer_config_helpers/tests/configs/protostr/test_l2_distance_layer.protostr
 create mode 100644 python/paddle/trainer_config_helpers/tests/configs/test_l2_distance_layer.py

diff --git a/doc/api/v2/config/layer.rst b/doc/api/v2/config/layer.rst
index 203506d7ab..3bb5270797 100644
--- a/doc/api/v2/config/layer.rst
+++ b/doc/api/v2/config/layer.rst
@@ -372,6 +372,11 @@ cos_sim
 .. autoclass:: paddle.v2.layer.cos_sim
     :noindex:
 
+l2_distance
+-----------
+.. autoclass:: paddle.v2.layer.l2_distance
+    :noindex:
+
 trans
 -----
 .. autoclass:: paddle.v2.layer.trans
diff --git a/paddle/gserver/layers/L2DistanceLayer.cpp b/paddle/gserver/layers/L2DistanceLayer.cpp
index e76e29cbe5..c71df1b92c 100644
--- a/paddle/gserver/layers/L2DistanceLayer.cpp
+++ b/paddle/gserver/layers/L2DistanceLayer.cpp
@@ -25,9 +25,9 @@ bool L2DistanceLayer::init(const LayerMap& layerMap,
   /* Initialize the basic parent class */
   Layer::init(layerMap, parameterMap);
 
-  CHECK_EQ(inputLayers_.size(), 2UL) << "The L2 distance layer accepts two and "
+  CHECK_EQ(inputLayers_.size(), 2UL) << "The L2DistanceLayer accepts two and "
                                      << "only two inputs.";
-  CHECK_EQ(getSize(), 1UL) << "The output dimensionality of L2 distance"
+  CHECK_EQ(getSize(), 1UL) << "The output dimensionality of L2DistanceLayer "
                            << "is fixed to be 1.";
 
   return true;
@@ -41,9 +41,9 @@ void L2DistanceLayer::forward(PassType passType) {
 
   CHECK(inV1 && inV2);
   CHECK_EQ(inV1->getHeight(), inV2->getHeight())
-      << "The height of two inputs to this layer must be the same.";
+      << "The height of two inputs of this layer must be the same.";
   CHECK_EQ(inV1->getWidth(), inV2->getWidth())
-      << "The width of two inputs to this layer must be the same.";
+      << "The width of two inputs of this layer must be the same.";
 
   int batchSize = inV1->getHeight();
   int output_dim = getSize();
@@ -66,22 +66,21 @@ void L2DistanceLayer::backward(const UpdateCallback& callback) {
   const auto outG = getOutputGrad();
   const auto outV = getOutputValue();
-  const auto inV1 = getInputValue(0);
-  const auto inV2 = getInputValue(1);
+  CHECK(outG && outV);
+
   auto inGrad1 = getInputGrad(0);
   auto inGrad2 = getInputGrad(1);
-  CHECK(outG && outV && inV1 && inV2 && inGrad1 && inGrad2);
 
   {
     REGISTER_TIMER_INFO("L2DistanceBpAtvTimer", getName().c_str());
 
-    outV->scalarDiv(*outV, 1.);
-    outV->dotMul(*outG, *outV);
-
-    if (inGrad1) {
-      inGrad1->addRowScale(0, *inputSub_, *outV);
+    if (inGrad1 || inGrad2) {
+      outV->scalarDiv(*outV, 1.);
+      outV->dotMul(*outG, *outV);
     }
 
+    if (inGrad1) inGrad1->addRowScale(0, *inputSub_, *outV);
+
     if (inGrad2) {
       inputSub_->mulScalar(-1.);
       inGrad2->addRowScale(0, *inputSub_, *outV);
diff --git a/paddle/gserver/layers/L2DistanceLayer.h b/paddle/gserver/layers/L2DistanceLayer.h
index 64731db2bf..9b12847a10 100644
--- a/paddle/gserver/layers/L2DistanceLayer.h
+++ b/paddle/gserver/layers/L2DistanceLayer.h
@@ -16,12 +16,11 @@ limitations under the License. */
 
 #include "Layer.h"
 #include "paddle/math/Matrix.h"
-#include "paddle/utils/ThreadLocal.h"
 
 namespace paddle {
 
 /**
- * @brief A layer for calculating l2 distance between the two input vectors.
+ * @brief The layer calculates the l2 distance between two input vectors.
  * \f[
  * f(\bf{x}, \bf{y}) = \sqrt{\sum_{i=1}^D(x_i - y_i)^2}
  * \f]
@@ -30,13 +29,12 @@ namespace paddle {
  * - Input2: A vector (batchSize * dataDim)
  * - Output: A vector (batchSize * 1)
  *
- * The config file api is l2_distance.
+ * The configuration api is: l2_distance_layer.
  */
 
 class L2DistanceLayer : public Layer {
 public:
   explicit L2DistanceLayer(const LayerConfig& config) : Layer(config) {}
-
   ~L2DistanceLayer() {}
 
   bool init(const LayerMap& layerMap,
@@ -46,7 +44,8 @@ public:
   void backward(const UpdateCallback& callback = nullptr) override;
 
 private:
-  // Store result of subtracting Input2 from Input1.
+  // Store the result of subtracting Input2 from Input1 in forward computation,
+  // which will be reused in backward computation.
   MatrixPtr inputSub_;
 };
 
diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py
index 5bd68e211a..7dd4e3d00c 100644
--- a/python/paddle/trainer/config_parser.py
+++ b/python/paddle/trainer/config_parser.py
@@ -3330,6 +3330,18 @@ class RowL2NormLayer(LayerBase):
         self.set_layer_size(input_layer.size)
 
 
+@config_layer('cos')
+class CosSimLayer(LayerBase):
+    def __init__(self, name, inputs, cos_scale=1, device=None):
+        super(CosSimLayer, self).__init__(
+            name, 'cos', 1, inputs=inputs, device=device)
+        config_assert(len(self.inputs) == 2, 'CosSimLayer must have 2 inputs')
+        config_assert(
+            self.get_input_layer(0).size == self.get_input_layer(1).size,
+            'inputs of CosSimLayer must have same dim')
+        self.config.cos_scale = cos_scale
+
+
 @config_layer('cos_vm')
 class CosSimVecMatLayer(LayerBase):
     def __init__(self, name, size, inputs, cos_scale=1.0, device=None):
@@ -3343,6 +3355,20 @@ class CosSimVecMatLayer(LayerBase):
             'Wrong input size for CosSimVecMatLayer')
 
 
+@config_layer('l2_distance')
+class L2DistanceLayer(LayerBase):
+    def __init__(self, name, inputs, device=None):
+        super(L2DistanceLayer, self).__init__(
+            name, 'l2_distance', 1, inputs=inputs, device=device)
+        config_assert(
+            len(self.inputs) == 2, ('The L2DistanceLayer must have '
+                                    'and only have 2 inputs.'))
+        config_assert(
+            self.get_input_layer(0).size == self.get_input_layer(1).size,
+            ('Two inputs of the L2DistanceLayer must have '
+             'the same dimensionality.'))
+
+
 @config_layer('sampling_id')
 class SamplingIdLayer(LayerBase):
     def __init__(self, name, inputs, device=None):
@@ -3384,18 +3410,6 @@ class AverageLayer(LayerBase):
         self.create_bias_parameter(bias, self.config.size)
 
 
-@config_layer('cos')
-class CosSimLayer(LayerBase):
-    def __init__(self, name, inputs, cos_scale=1, device=None):
-        super(CosSimLayer, self).__init__(
-            name, 'cos', 1, inputs=inputs, device=device)
-        config_assert(len(self.inputs) == 2, 'CosSimLayer must have 2 inputs')
-        config_assert(
-            self.get_input_layer(0).size == self.get_input_layer(1).size,
-            'inputs of CosSimLayer must have same dim')
-        self.config.cos_scale = cos_scale
-
-
 @config_layer('tensor')
 class TensorLayer(LayerBase):
     def __init__(self, name, size, inputs, bias=True, **xargs):
diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py
index 5de1c18950..5ed6fe384a 100644
--- a/python/paddle/trainer_config_helpers/layers.py
+++ b/python/paddle/trainer_config_helpers/layers.py
@@ -51,6 +51,7 @@ __all__ = [
     'last_seq',
     'first_seq',
     'cos_sim',
+    'l2_distance_layer',
     'hsigmoid',
     'conv_projection',
     'square_error_cost',
@@ -167,6 +168,7 @@ class LayerType(object):
     COST = 'cost'
     COSINE_SIM_VEC = 'cos_vm'
     COSINE_SIM = 'cos'
+    L2_DISTANCE = 'l2_distance'
     HSIGMOID = 'hsigmoid'
     CONV_LAYER = 'conv'
     CONVTRANS_LAYER = 'convt'
@@ -2332,6 +2334,51 @@ def cos_sim(a, b, scale=1, size=1, name=None, layer_attr=None):
     return LayerOutput(name, LayerType.COSINE_SIM, parents=[a, b], size=size)
 
 
+@wrap_name_default()
+@layer_support()
+def l2_distance_layer(x, y, name=None, layer_attr=None):
+    """
+    This layer calculate and return the Euclidean distance between two input
+    vectors a and b. The equation is as follows:
+
+    .. math::
+        l2_distance(\\mathbf{x}, \\mathbf{y}) = \\sqrt{\\sum_{i=1}^D(x_i - y_i)^2}
+
+    The output size of this layer is fixed to be 1. Note that the above
+    computation is for one sample. Multiple samples are processed in one batch.
+
+    The example usage is:
+
+    .. code-block:: python
+
+       l2_sim = l2_distance_layer(x=layer1, y=layer2)
+
+    :param name: The name of this layer. It is optional.
+    :type name: basestring
+    :param x: The first input x for this layer, whose output is a matrix with
+              dimensionality N x D. N is the sample number in a mini-batch.
+              D is the dimensionality of x's output.
+    :type x: LayerOutput
+    :param y: The second input y for this layer, whose output is a matrix with
+              dimensionality N x D. N is the sample number in a mini-batch.
+              D is the dimensionality of y's output.
+    :type y: LayerOutput
+    :param layer_attr: The extra layer attributes, for example, drop rate.
+                       See ExtraLayerAttribute for more details.
+    :type layer_attr: ExtraLayerAttribute
+    :return: The returned LayerOutput object.
+    :rtype: LayerOutput
+    """
+
+    assert isinstance(x, LayerOutput) and isinstance(x, LayerOutput)
+    Layer(
+        name=name,
+        type=LayerType.L2_DISTANCE,
+        inputs=[x.name, x.name],
+        **ExtraLayerAttribute.to_kwargs(layer_attr))
+    return LayerOutput(name, LayerType.L2_DISTANCE, parents=[x, y], size=1)
+
+
 @wrap_name_default()
 @wrap_bias_attr_default(has_bias=True)
 @wrap_param_attr_default()
@@ -3867,7 +3914,7 @@ def recurrent_layer(input,
     :type input: LayerOutput
     :param act: Activation type. TanhActivation is the default activation.
     :type act: BaseActivation
-    :param bias_attr: The parameter attribute for bias. If this parameter is set to 
+    :param bias_attr: The parameter attribute for bias. If this parameter is set to
                       False or an object whose type is not ParameterAttribute,
                       no bias is defined. If the parameter is set to True,
                       the bias is initialized to zero.
diff --git a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh
index 1c7451e0ab..5014c14b8f 100755
--- a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh
+++ b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh
@@ -10,6 +10,7 @@ test_prelu_layer test_row_conv test_detection_output_layer test_multibox_loss_la
 test_recursive_topology test_gated_unit_layer test_clip_layer test_row_l2_norm_layer
 test_kmax_seq_socre_layer test_sub_nested_seq_select_layer test_scale_shift_layer
 test_seq_slice_layer test_cross_entropy_over_beam test_roi_pool_layer test_pooling3D_layer
-test_conv3d_layer test_deconv3d_layer test_BatchNorm3D test_resize_layer test_scale_sub_region_layer)
+test_conv3d_layer test_deconv3d_layer test_BatchNorm3D test_resize_layer
+test_scale_sub_region_layer test_l2_distance_layer)
 
 export whole_configs=(test_split_datasource)
diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_l2_distance_layer.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_l2_distance_layer.protostr
new file mode 100644
index 0000000000..ad488bfa9f
--- /dev/null
+++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_l2_distance_layer.protostr
@@ -0,0 +1,39 @@
+type: "nn"
+layers {
+  name: "x"
+  type: "data"
+  size: 128
+  active_type: ""
+}
+layers {
+  name: "y"
+  type: "data"
+  size: 128
+  active_type: ""
+}
+layers {
+  name: "__l2_distance_layer_0__"
+  type: "l2_distance"
+  size: 1
+  active_type: ""
+  inputs {
+    input_layer_name: "x"
+  }
+  inputs {
+    input_layer_name: "x"
+  }
+}
+input_layer_names: "x"
+input_layer_names: "y"
+output_layer_names: "__l2_distance_layer_0__"
+sub_models {
+  name: "root"
+  layer_names: "x"
+  layer_names: "y"
+  layer_names: "__l2_distance_layer_0__"
+  input_layer_names: "x"
+  input_layer_names: "y"
+  output_layer_names: "__l2_distance_layer_0__"
+  is_recurrent_layer_group: false
+}
+
diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_l2_distance_layer.py b/python/paddle/trainer_config_helpers/tests/configs/test_l2_distance_layer.py
new file mode 100644
index 0000000000..b36a5c6d12
--- /dev/null
+++ b/python/paddle/trainer_config_helpers/tests/configs/test_l2_distance_layer.py
@@ -0,0 +1,7 @@
+from paddle.trainer_config_helpers import *
+
+outputs(
+    l2_distance_layer(
+        x=data_layer(
+            name='x', size=128), y=data_layer(
+                name='y', size=128)))

From 929efdc592aa3d99e821d07b34234c0e60d0f085 Mon Sep 17 00:00:00 2001
From: caoying03
Date: Fri, 17 Nov 2017 17:53:59 +0800
Subject: [PATCH 3/7] follow comments.
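
Follow the review comments: the helper previously wired both inputs of the
generated layer to x, so the config computed the distance of x against
itself. The fix passes the second input through as

    inputs=[x.name, y.name]

and regenerates the protostr so the second input reads input_layer_name: "y".
The CosSimLayer assert message is also reworded into a full sentence.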
---
 python/paddle/trainer/config_parser.py                      | 2 +-
 python/paddle/trainer_config_helpers/layers.py              | 4 ++--
 .../tests/configs/protostr/test_l2_distance_layer.protostr  | 2 +-
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py
index 7dd4e3d00c..42aac59d22 100644
--- a/python/paddle/trainer/config_parser.py
+++ b/python/paddle/trainer/config_parser.py
@@ -3338,7 +3338,7 @@ class CosSimLayer(LayerBase):
         config_assert(len(self.inputs) == 2, 'CosSimLayer must have 2 inputs')
         config_assert(
             self.get_input_layer(0).size == self.get_input_layer(1).size,
-            'inputs of CosSimLayer must have same dim')
+            'The two inputs of CosSimLayer must have the same dimensionality.')
         self.config.cos_scale = cos_scale
 
 
diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py
index 5ed6fe384a..e8f4f0035d 100644
--- a/python/paddle/trainer_config_helpers/layers.py
+++ b/python/paddle/trainer_config_helpers/layers.py
@@ -2338,7 +2338,7 @@ def cos_sim(a, b, scale=1, size=1, name=None, layer_attr=None):
 @layer_support()
 def l2_distance_layer(x, y, name=None, layer_attr=None):
     """
-    This layer calculate and return the Euclidean distance between two input
+    This layer calculates and returns the Euclidean distance between two input
     vectors a and b. The equation is as follows:
 
     .. math::
@@ -2374,7 +2374,7 @@ def l2_distance_layer(x, y, name=None, layer_attr=None):
     Layer(
         name=name,
         type=LayerType.L2_DISTANCE,
-        inputs=[x.name, x.name],
+        inputs=[x.name, y.name],
         **ExtraLayerAttribute.to_kwargs(layer_attr))
     return LayerOutput(name, LayerType.L2_DISTANCE, parents=[x, y], size=1)
 
diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_l2_distance_layer.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_l2_distance_layer.protostr
index ad488bfa9f..9ba33689ed 100644
--- a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_l2_distance_layer.protostr
+++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_l2_distance_layer.protostr
@@ -20,7 +20,7 @@ layers {
     input_layer_name: "x"
   }
   inputs {
-    input_layer_name: "x"
+    input_layer_name: "y"
   }
 }
 input_layer_names: "x"

From 37190b7c1455de51f0d89f2f12581d41b041b075 Mon Sep 17 00:00:00 2001
From: caoying03
Date: Fri, 17 Nov 2017 18:08:57 +0800
Subject: [PATCH 4/7] small fix.

---
 python/paddle/trainer_config_helpers/layers.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py
index 5b39a65d8c..14cdee4c55 100644
--- a/python/paddle/trainer_config_helpers/layers.py
+++ b/python/paddle/trainer_config_helpers/layers.py
@@ -2341,7 +2341,7 @@ def cos_sim(a, b, scale=1, size=1, name=None, layer_attr=None):
 def l2_distance_layer(x, y, name=None, layer_attr=None):
     """
     This layer calculates and returns the Euclidean distance between two input
-    vectors a and b. The equation is as follows:
+    vectors x and y. The equation is as follows:
 
     .. math::
         l2_distance(\\mathbf{x}, \\mathbf{y}) = \\sqrt{\\sum_{i=1}^D(x_i - y_i)^2}
 
@@ -2372,7 +2372,7 @@ def l2_distance_layer(x, y, name=None, layer_attr=None):
     :rtype: LayerOutput
     """
 
-    assert isinstance(x, LayerOutput) and isinstance(x, LayerOutput)
+    assert isinstance(x, LayerOutput) and isinstance(y, LayerOutput)
     Layer(
         name=name,
         type=LayerType.L2_DISTANCE,

From bf5f94a3cab48a64586d1d4052db0caafac69e27 Mon Sep 17 00:00:00 2001
From: Luo Tao
Date: Fri, 17 Nov 2017 18:36:09 +0800
Subject: [PATCH 5/7] fix compiler error in "WITH_MKL"

---
 CMakeLists.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index ae8728f4d4..65164b8472 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -109,7 +109,7 @@ else()
 endif()
 
 set(WITH_MKLML ${WITH_MKL})
-if (WITH_MKL AND ${AVX2_FOUND})
+if (WITH_MKL AND AVX2_FOUND)
   set(WITH_MKLDNN ON)
 else()
   message(STATUS "Do not have AVX2 intrinsics and disabled MKL-DNN")

From 6cfcf6245a67eb39cf5667adb011069c76e55c03 Mon Sep 17 00:00:00 2001
From: Abhinav Arora
Date: Sat, 18 Nov 2017 19:02:46 +0530
Subject: [PATCH 6/7] Adding logical operators for beam search and control
 flow (#5708)

---
 paddle/framework/data_type.h                    |   5 +
 paddle/operators/CMakeLists.txt                 |   5 +
 paddle/operators/logical_op.cc                  | 153 ++++++++++++++++++
 paddle/operators/logical_op.cu                  |  24 +++
 paddle/operators/logical_op.h                   |  93 +++++++++++
 .../paddle/v2/fluid/tests/test_logical_op.py    |  35 ++++
 6 files changed, 315 insertions(+)
 create mode 100644 paddle/operators/logical_op.cc
 create mode 100644 paddle/operators/logical_op.cu
 create mode 100644 paddle/operators/logical_op.h
 create mode 100644 python/paddle/v2/fluid/tests/test_logical_op.py

diff --git a/paddle/framework/data_type.h b/paddle/framework/data_type.h
index be144d8fc0..c54d2d4ddf 100644
--- a/paddle/framework/data_type.h
+++ b/paddle/framework/data_type.h
@@ -46,6 +46,8 @@ inline std::type_index ToTypeIndex(DataType type) {
       return typeid(int);
     case DataType::INT64:
       return typeid(int64_t);
+    case DataType::BOOL:
+      return typeid(bool);
     default:
       PADDLE_THROW("Not support type %d", type);
   }
@@ -66,6 +68,9 @@ inline void VisitDataType(DataType type, Visitor visitor) {
     case DataType::INT64:
       visitor.template operator()<int64_t>();
       break;
+    case DataType::BOOL:
+      visitor.template operator()<bool>();
+      break;
     default:
       PADDLE_THROW("Not supported");
   }
diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt
index 46c2833030..d0fe5b4635 100644
--- a/paddle/operators/CMakeLists.txt
+++ b/paddle/operators/CMakeLists.txt
@@ -87,6 +87,11 @@ function(op_library TARGET)
     file(APPEND ${pybind_file} "USE_OP(pool2d_cudnn);\n")
   endif()
 
+  if ("${TARGET}" STREQUAL "logical_op")
+    set(pybind_flag 1)
+    file(APPEND ${pybind_file} "USE_OP(logical_and);\n")
+  endif()
+
   # pool_with_index_op contains several operators
   if ("${TARGET}" STREQUAL "pool_with_index_op")
     set(pybind_flag 1)
diff --git a/paddle/operators/logical_op.cc b/paddle/operators/logical_op.cc
new file mode 100644
index 0000000000..a37582c1d8
--- /dev/null
+++ b/paddle/operators/logical_op.cc
@@ -0,0 +1,153 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/operators/logical_op.h"
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+template <typename OpComment>
+class BinaryLogicalOpProtoMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  BinaryLogicalOpProtoMaker(framework::OpProto *proto,
+                            framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    OpComment comment;
+    AddInput("X",
+             string::Sprintf("(LoDTensor) Left hand operand of %s operator",
+                             comment.type));
+    AddInput("Y",
+             string::Sprintf("(LoDTensor) Right hand operand of %s operator",
+                             comment.type));
+    AddOutput("Out", string::Sprintf(
+                         "(LoDTensor) n-dim bool tensor. Each element is %s",
+                         comment.equation));
+    AddComment(string::Sprintf(R"DOC(%s Operator
+
+It operates element-wise on X and Y, and returns the Out. X, Y and Out are N-dim boolean tensors.
+Each element of Out is calculated by %s
+)DOC",
+                               comment.type, comment.equation));
+  }
+};
+
+template <typename OpComment>
+class UnaryLogicalOpProtoMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  UnaryLogicalOpProtoMaker(framework::OpProto *proto,
+                           framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    OpComment comment;
+    AddInput("X", string::Sprintf("(LoDTensor) Operand of %s operator",
+                                  comment.type));
+    AddOutput("Out", string::Sprintf(
+                         "(LoDTensor) n-dim bool tensor. Each element is %s",
+                         comment.equation));
+    AddComment(string::Sprintf(R"DOC(%s Operator
+
+It operates element-wise on X, and returns the Out. X and Out are N-dim boolean tensors.
+Each element of Out is calculated by %s
+)DOC",
+                               comment.type, comment.equation));
+  }
+};
+
+template <typename OpComment>
+class BinaryLogicalOpInferShape : public framework::InferShapeBase {
+ public:
+  void operator()(framework::InferShapeContext *context) const override {
+    OpComment comment;
+    PADDLE_ENFORCE(context->HasInput("X"),
+                   "Input(X) of %s operator must not be null", comment.type);
+    PADDLE_ENFORCE(context->HasInput("Y"),
+                   "Input(Y) of %s operator must not be null", comment.type);
+    auto dim_x = context->GetInputDim("X");
+    auto dim_y = context->GetInputDim("Y");
+    PADDLE_ENFORCE_EQ(framework::product(dim_x), framework::product(dim_y),
+                      "The number of elements in X and Y should be same");
+
+    context->SetOutputDim("Out", context->GetInputDim("X"));
+    context->ShareLoD("X", "Out");
+  }
+};
+
+template <typename OpComment>
+class UnaryLogicalOpInferShape : public framework::InferShapeBase {
+ public:
+  void operator()(framework::InferShapeContext *context) const override {
+    OpComment comment;
+    PADDLE_ENFORCE(context->HasInput("X"),
+                   "Input(X) of %s operator must not be null", comment.type);
+    auto dim_x = context->GetInputDim("X");
+
+    context->SetOutputDim("Out", context->GetInputDim("X"));
+    context->ShareLoD("X", "Out");
+  }
+};
+
+class LogicalOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  framework::OpKernelType GetKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    framework::OpKernelType kt = OperatorWithKernel::GetKernelType(ctx);
+    // LogicalOp kernel's device type is decided by input tensor place
+    kt.place_ = ctx.Input<framework::LoDTensor>("X")->place();
+    return kt;
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+#define REGISTER_BINARY_LOGICAL_OP(op_type, _equation)                     \
+  struct _##op_type##Comment {                                             \
+    static char type[];                                                    \
+    static char equation[];                                                \
+  };                                                                       \
+  char _##op_type##Comment::type[]{#op_type};                              \
+  char _##op_type##Comment::equation[]{_equation};                         \
+  REGISTER_OPERATOR(                                                       \
+      op_type, ::paddle::operators::LogicalOp,                             \
+      ::paddle::operators::BinaryLogicalOpProtoMaker<_##op_type##Comment>, \
+      ::paddle::operators::BinaryLogicalOpInferShape<_##op_type##Comment>, \
+      ::paddle::framework::EmptyGradOpMaker);
+
+#define REGISTER_UNARY_LOGICAL_OP(op_type, _equation)                     \
+  struct _##op_type##Comment {                                            \
+    static char type[];                                                   \
+    static char equation[];                                               \
+  };                                                                      \
+  char _##op_type##Comment::type[]{#op_type};                             \
+  char _##op_type##Comment::equation[]{_equation};                        \
+  REGISTER_OPERATOR(                                                      \
+      op_type, ::paddle::operators::LogicalOp,                            \
+      ::paddle::operators::UnaryLogicalOpProtoMaker<_##op_type##Comment>, \
+      ::paddle::operators::UnaryLogicalOpInferShape<_##op_type##Comment>, \
+      ::paddle::framework::EmptyGradOpMaker);
+
+REGISTER_BINARY_LOGICAL_OP(logical_and, "Out = X && Y");
+REGISTER_BINARY_LOGICAL_KERNEL(logical_and, CPU,
+                               paddle::operators::LogicalAndFunctor);
+REGISTER_BINARY_LOGICAL_OP(logical_or, "Out = X || Y");
+REGISTER_BINARY_LOGICAL_KERNEL(logical_or, CPU,
+                               paddle::operators::LogicalOrFunctor);
+REGISTER_UNARY_LOGICAL_OP(logical_not, "Out = !X");
+REGISTER_UNARY_LOGICAL_KERNEL(logical_not, CPU,
+                              paddle::operators::LogicalNotFunctor);
+REGISTER_BINARY_LOGICAL_OP(logical_xor, "Out = (X || Y) && !(X && Y)");
+REGISTER_BINARY_LOGICAL_KERNEL(logical_xor, CPU,
+                               paddle::operators::LogicalXorFunctor);
diff --git a/paddle/operators/logical_op.cu b/paddle/operators/logical_op.cu
new file mode 100644
index 0000000000..d41239b2ca
--- /dev/null
+++ b/paddle/operators/logical_op.cu
@@ -0,0 +1,24 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/operators/logical_op.h"
+
+REGISTER_BINARY_LOGICAL_KERNEL(logical_and, GPU,
+                               paddle::operators::LogicalAndFunctor);
+REGISTER_BINARY_LOGICAL_KERNEL(logical_or, GPU,
+                               paddle::operators::LogicalOrFunctor);
+REGISTER_UNARY_LOGICAL_KERNEL(logical_not, GPU,
+                              paddle::operators::LogicalNotFunctor);
+REGISTER_BINARY_LOGICAL_KERNEL(logical_xor, GPU,
+                               paddle::operators::LogicalXorFunctor);
diff --git a/paddle/operators/logical_op.h b/paddle/operators/logical_op.h
new file mode 100644
index 0000000000..6e78a7d6ed
--- /dev/null
+++ b/paddle/operators/logical_op.h
@@ -0,0 +1,93 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#pragma once
+#include <math.h>
+#include <type_traits>
+#include "paddle/framework/op_registry.h"
+#include "paddle/platform/transform.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+struct LogicalAndFunctor {
+  using ELEM_TYPE = T;
+  HOSTDEVICE bool operator()(const T& a, const T& b) const { return a && b; }
+};
+
+template <typename T>
+struct LogicalOrFunctor {
+  using ELEM_TYPE = T;
+  HOSTDEVICE bool operator()(const T& a, const T& b) const { return a || b; }
+};
+
+template <typename T>
+struct LogicalNotFunctor {
+  using ELEM_TYPE = T;
+  HOSTDEVICE bool operator()(const T& a) const { return !a; }
+};
+
+template <typename T>
+struct LogicalXorFunctor {
+  using ELEM_TYPE = T;
+  HOSTDEVICE bool operator()(const T& a, const T& b) const {
+    return (a || b) && !(a && b);
+  }
+};
+
+template <typename Place, typename Functor>
+class BinaryLogicalOpKernel
+    : public framework::OpKernel<typename Functor::ELEM_TYPE> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    using T = typename Functor::ELEM_TYPE;
+    auto* x = context.Input<framework::Tensor>("X");
+    auto* y = context.Input<framework::Tensor>("Y");
+    auto* out = context.Output<framework::Tensor>("Out");
+    Functor binary_func;
+    platform::Transform<Place> trans;
+    trans(context.device_context(), x->data<T>(), x->data<T>() + x->numel(),
+          y->data<T>(), out->mutable_data<bool>(context.GetPlace()),
+          binary_func);
+  }
+};
+
+template <typename Place, typename Functor>
+class UnaryLogicalOpKernel
+    : public framework::OpKernel<typename Functor::ELEM_TYPE> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    using T = typename Functor::ELEM_TYPE;
+    auto* x = context.Input<framework::Tensor>("X");
+    auto* out = context.Output<framework::Tensor>("Out");
+    Functor unary_func;
+    platform::Transform<Place> trans;
+    trans(context.device_context(), x->data<T>(), x->data<T>() + x->numel(),
+          out->mutable_data<bool>(context.GetPlace()), unary_func);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+#define REGISTER_BINARY_LOGICAL_KERNEL(op_type, dev, functor)  \
+  REGISTER_OP_##dev##_KERNEL(                                  \
+      op_type, ::paddle::operators::BinaryLogicalOpKernel<     \
+                   ::paddle::platform::dev##Place, functor<bool>>);
+
+#define REGISTER_UNARY_LOGICAL_KERNEL(op_type, dev, functor)  \
+  REGISTER_OP_##dev##_KERNEL(                                 \
+      op_type, ::paddle::operators::UnaryLogicalOpKernel<     \
+                   ::paddle::platform::dev##Place, functor<bool>>);
diff --git a/python/paddle/v2/fluid/tests/test_logical_op.py b/python/paddle/v2/fluid/tests/test_logical_op.py
new file mode 100644
index 0000000000..ac90bf839c
--- /dev/null
+++ b/python/paddle/v2/fluid/tests/test_logical_op.py
@@ -0,0 +1,35 @@
+import op_test
+import unittest
+import numpy as np
+
+
+def create_test_class(op_type, callback, binary_op=True):
+    class Cls(op_test.OpTest):
+        def setUp(self):
+            a = np.random.choice(a=[True, False], size=(10, 7)).astype(bool)
+            if binary_op:
+                b = np.random.choice(a=[True, False], size=(10, 7)).astype(bool)
+                c = callback(a, b)
+            else:
+                c = callback(a)
+            self.outputs = {'Out': c}
+            self.op_type = op_type
+            if binary_op:
+                self.inputs = {'X': a, 'Y': b}
+            else:
+                self.inputs = {'X': a}
+
+        def test_output(self):
+            self.check_output()
+
+    Cls.__name__ = op_type
+    globals()[op_type] = Cls
+
+
+create_test_class('logical_and', lambda _a, _b: np.logical_and(_a, _b))
+create_test_class('logical_or', lambda _a, _b: np.logical_or(_a, _b))
+create_test_class('logical_not', lambda _a: np.logical_not(_a), False)
+create_test_class('logical_xor', lambda _a, _b: np.logical_xor(_a, _b))
+
+if __name__ == '__main__':
+    unittest.main()

From 569f7c4773e877d120017d3b22b7df793c02e3ec Mon Sep 17 00:00:00 2001
From: Qiao Longfei
Date: Sat, 18 Nov 2017 09:35:21 -0600
Subject: [PATCH 7/7] enforce shape of backward target to be {1} (#5745)

* enforce shape of backward target to be {1}
* fix test_regularizer.py
* rm unused code
* fix backward_test
* fix a type bug
* fix test_program
---
 paddle/framework/backward.cc                      | 11 ++---
 paddle/framework/backward_test.cc                 |  7 +++
 python/paddle/v2/fluid/tests/test_optimizer.py    | 48 +++++++++++++++----
 python/paddle/v2/fluid/tests/test_program.py      | 14 ++++--
 python/paddle/v2/fluid/tests/test_regularizer.py  | 12 ++++-
 5 files changed, 69 insertions(+), 23 deletions(-)

diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc
index 00d9dd238e..b9018ecdba 100644
--- a/paddle/framework/backward.cc
+++ b/paddle/framework/backward.cc
@@ -513,19 +513,14 @@ ParamGradInfoMap AppendBackward(
   const int root_block_idx = 0;
   auto root_block = program_desc.MutableBlock(root_block_idx);
 
-  // insert fill one op for target
-  // TODO(qiao) add some check to the target.
   std::string fill_one_op_out = GradVarName(target.Name());
-  std::vector<int64_t> target_shape_desc = target.Shape();
-  std::vector<int> target_shape;
-  std::transform(target_shape_desc.begin(), target_shape_desc.end(),
-                 std::back_inserter(target_shape),
-                 [](int64_t dim) { return static_cast<int>(dim); });
+  bool is_scalar = target.Shape() == std::vector<int64_t>{1};
+  PADDLE_ENFORCE(is_scalar, "target should be scalar");
   VLOG(3) << "backward from loss=" << target.Name()
           << " data_type=" << target.GetDataType();
   std::unique_ptr<OpDescBind> fill_one_op(
       new OpDescBind("fill_constant", {}, {{"Out", {fill_one_op_out}}},
-                     {{"shape", target_shape},
+                     {{"shape", std::vector<int>{1}},
                       {"value", static_cast<float>(1.0)},
                       {"data_type", target.GetDataType()}}));
   // infer var type of fill_one_op
diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc
index d485cdf610..2b858f5ea0 100644
--- a/paddle/framework/backward_test.cc
+++ b/paddle/framework/backward_test.cc
@@ -508,6 +508,7 @@ TEST(Backward, simple_single_op) {
   op->SetOutput("Out", {"out"});
 
   auto target = f::VarDescBind("out");
+  target.SetShape({1});
   auto var_to_grad = AppendBackward(program, target, {});
 
   ASSERT_EQ(block->AllOps().size(), 3UL);
@@ -544,6 +545,7 @@ TEST(Backward, default_attribute) {
   op->CheckAttrs();
 
   auto target = f::VarDescBind("out");
+  target.SetShape({1});
  AppendBackward(program, target, {});
 
   ASSERT_EQ(block->AllOps().size(), 3UL);
@@ -581,6 +583,7 @@ TEST(Backward, simple_mult_op) {
   op3->SetOutput("Out", {"out3"});
 
   auto target = f::VarDescBind("out3");
+  target.SetShape({1});
   size_t forward_len = block->AllOps().size();
   auto var_to_grad = AppendBackward(program, target, {});
 
@@ -670,6 +673,7 @@ TEST(Backward, intermedia_var_no_grad) {
   op4->SetOutput("Out", {"out4"});
 
   auto target = f::VarDescBind("out4");
+  target.SetShape({1});
   size_t forward_len = block->AllOps().size();
   auto var_to_grad = AppendBackward(program, target, {"out3"});
 
@@ -730,6 +734,7 @@ TEST(Backward, var_no_grad) {
   op2->SetOutput("Z", {"z2"});
 
   auto target = f::VarDescBind("z2");
+  target.SetShape({1});
   size_t forward_len = block->AllOps().size();
   auto var_to_grad = AppendBackward(program, target, {"z1"});
 
@@ -810,6 +815,7 @@ TEST(Backward, shared_var) {
   op3->SetOutput("Out", {"out3"});
 
   auto target = f::VarDescBind("out3");
+  target.SetShape({1});
   size_t forward_len = block->AllOps().size();
   auto var_to_grad = AppendBackward(program, target, {});
 
@@ -888,6 +894,7 @@ TEST(Backward, half_backward) {
   op1->SetOutput("Out", {"out"});
 
   auto target = f::VarDescBind("out");
+  target.SetShape({1});
   size_t forward_len = block->AllOps().size();
   auto var_to_grad = AppendBackward(program, target, {"b"});
   f::OpDescBind *fill_op = block->AllOps()[forward_len];
diff --git a/python/paddle/v2/fluid/tests/test_optimizer.py b/python/paddle/v2/fluid/tests/test_optimizer.py
index 7b4237e7fd..2459dfd664 100644
--- a/python/paddle/v2/fluid/tests/test_optimizer.py
+++ b/python/paddle/v2/fluid/tests/test_optimizer.py
@@ -16,14 +16,18 @@ class TestOptimizer(unittest.TestCase):
             dtype="float32", shape=[10, 8], lod_level=0, name="mul.y")
         mul_out = block.create_var(
             dtype="float32", shape=[5, 8], lod_level=0, name="mul.out")
+        mean_out = block.create_var(
+            dtype="float32", shape=[1], lod_level=0, name="mean.out")
         block.append_op(
             type="mul",
             inputs={"X": mul_x,
                     "Y": mul_y},
             outputs={"Out": mul_out},
             attrs={"x_num_col_dims": 1})
+        block.append_op(
+            type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out})
         sgd_optimizer = optimizer.SGDOptimizer(learning_rate=0.01)
-        opts = sgd_optimizer.minimize(mul_out, init_program)
+        opts = sgd_optimizer.minimize(mean_out, init_program)
         self.assertEqual(len(opts), 1)
         sgd_op = opts[0]
         self.assertEqual(sgd_op.type, "sgd")
@@ -44,12 +48,16 @@ class TestOptimizer(unittest.TestCase):
                     "Y": mul_y},
             outputs={"Out": mul_out},
             attrs={"x_num_col_dims": 1})
+        mean_out = block.create_var(
+            dtype="float32", shape=[1], lod_level=0, name="mean.out")
+        block.append_op(
+            type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out})
         global_step = block.create_var(
             dtype="float32", shape=[1], lod_level=0, name="step")
         learning_rate = 0.01
         sgd_optimizer = optimizer.SGDOptimizer(
             learning_rate=learning_rate, global_step=global_step)
-        opts = sgd_optimizer.minimize(mul_out, init_program)
+        opts = sgd_optimizer.minimize(mean_out, init_program)
         self.assertEqual(len(opts), 2)
         sgd_op = opts[0]
         self.assertEqual(sgd_op.type, "sgd")
@@ -90,7 +98,11 @@ class TestMomentumOptimizer(unittest.TestCase):
         learning_rate = 0.01
         momentum_optimizer = self.MockMomentum(
             learning_rate=learning_rate, momentum=0.2)
-        params_grads = append_backward_ops(mul_out)
+        mean_out = block.create_var(
+            dtype="float32", shape=[1], lod_level=0, name="mean.out")
+        block.append_op(
+            type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out})
+        params_grads = append_backward_ops(mean_out)
         self.assertEqual(len(params_grads), 1)
         self.assertEqual(len(momentum_optimizer.get_accumulators()), 0)
         opts = momentum_optimizer.create_optimization_pass(
@@ -132,10 +144,14 @@ class TestMomentumOptimizer(unittest.TestCase):
                     "Y": mul_y},
             outputs={"Out": mul_out},
             attrs={"x_num_col_dims": 1})
+        mean_out = block.create_var(
+            dtype="float32", shape=[1], lod_level=0, name="mean.out")
+        block.append_op(
+            type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out})
         learning_rate = 0.01
         momentum_optimizer = self.MockMomentum(
             learning_rate=learning_rate, momentum=0.2, use_nesterov=True)
-        params_grads = append_backward_ops(mul_out)
+        params_grads = append_backward_ops(mean_out)
         self.assertEqual(len(params_grads), 1)
         self.assertEqual(len(momentum_optimizer.get_accumulators()), 0)
         opts = momentum_optimizer.create_optimization_pass(
@@ -186,10 +202,14 @@ class TestAdagradOptimizer(unittest.TestCase):
                     "Y": mul_y},
             outputs={"Out": mul_out},
             attrs={"x_num_col_dims": 1})
+        mean_out = block.create_var(
+            dtype="float32", shape=[1], lod_level=0, name="mean.out")
+        block.append_op(
+            type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out})
         learning_rate = 0.01
         adagrad_optimizer = self.MockAdagrad(
             learning_rate=learning_rate, epsilon=1.0e-6)
-        params_grads = append_backward_ops(mul_out)
+        params_grads = append_backward_ops(mean_out)
         self.assertEqual(len(params_grads), 1)
         self.assertEqual(len(adagrad_optimizer.get_accumulators()), 0)
         opts = adagrad_optimizer.create_optimization_pass(params_grads, mul_out,
@@ -242,10 +262,14 @@ class TestAdamOptimizer(unittest.TestCase):
                     "Y": mul_y},
             outputs={"Out": mul_out},
             attrs={"x_num_col_dims": 1})
+        mean_out = block.create_var(
+            dtype="float32", shape=[1], lod_level=0, name="mean.out")
+        block.append_op(
+            type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out})
         learning_rate = 0.01
         adam_optimizer = self.MockAdam(
             learning_rate=learning_rate, beta1=0.9, beta2=0.999)
-        params_grads = append_backward_ops(mul_out)
+        params_grads = append_backward_ops(mean_out)
         self.assertEqual(len(params_grads), 1)
         self.assertEqual(len(adam_optimizer.get_accumulators()), 0)
         opts = adam_optimizer.create_optimization_pass(params_grads, mul_out,
@@ -300,10 +324,14 @@ class TestAdamaxOptimizer(unittest.TestCase):
                     "Y": mul_y},
             outputs={"Out": mul_out},
             attrs={"x_num_col_dims": 1})
+        mean_out = block.create_var(
+            dtype="float32", shape=[1], lod_level=0, name="mean.out")
+        block.append_op(
+            type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out})
         learning_rate = 0.01
         adamax_optimizer = self.MockAdamax(
             learning_rate=learning_rate, beta1=0.9, beta2=0.999)
-        params_grads = append_backward_ops(mul_out)
+        params_grads = append_backward_ops(mean_out)
         self.assertEqual(len(params_grads), 1)
         self.assertEqual(len(adamax_optimizer.get_accumulators()), 0)
         opts = adamax_optimizer.create_optimization_pass(params_grads, mul_out,
@@ -355,10 +383,14 @@ class TestDecayedAdagradOptimizer(unittest.TestCase):
                     "Y": mul_y},
             outputs={"Out": mul_out},
             attrs={"x_num_col_dims": 1})
+        mean_out = block.create_var(
+            dtype="float32", shape=[1], lod_level=0, name="mean.out")
+        block.append_op(
+            type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out})
         learning_rate = 0.01
         decayed_adagrad_optimizer = self.MockDecayedAdagrad(
             learning_rate=learning_rate, decay=0.95, epsilon=1.0e-6)
-        params_grads = append_backward_ops(mul_out)
+        params_grads = append_backward_ops(mean_out)
         self.assertEqual(len(params_grads), 1)
         self.assertEqual(len(decayed_adagrad_optimizer.get_accumulators()), 0)
         opts = decayed_adagrad_optimizer.create_optimization_pass(
diff --git a/python/paddle/v2/fluid/tests/test_program.py b/python/paddle/v2/fluid/tests/test_program.py
index ef2daf6916..e9bcefd215 100644
--- a/python/paddle/v2/fluid/tests/test_program.py
+++ b/python/paddle/v2/fluid/tests/test_program.py
@@ -1,6 +1,5 @@
 import unittest
 
-import paddle.v2.fluid.core as core
 from paddle.v2.fluid.framework import Program
 from paddle.v2.fluid.framework import g_main_program
 
@@ -98,21 +97,26 @@ class TestProgram(unittest.TestCase):
                     "Y": add_y},
             outputs={"Out": add_out},
             attrs={"x_num_col_dims": 1})
+        mean_out = block.create_var(
+            dtype="float32", shape=[1], lod_level=0, name="mean.out")
+        block.append_op(
+            type="mean", inputs={"X": add_out}, outputs={"Out": mean_out})
 
         self.assertEqual(mul_op.idx, 0)
         self.assertEqual(add_op.idx, 1)
-        param_to_grad = prog.append_backward(add_out, set())
+        param_to_grad = prog.append_backward(mean_out, set())
 
         def grad_name(name):
             return name + "@GRAD"
 
-        for var_name in ("mul.x", "mul.y", "mul.out", "add.y", "add.out"):
+        for var_name in ("mul.x", "mul.y", "mul.out", "add.y", "add.out",
+                         "mean.out"):
             self.assertEqual(param_to_grad[var_name][0], grad_name(var_name))
             self.assertEqual(param_to_grad[var_name][1], 0)
 
         expect_ops = [
-            "mul", "elementwise_add", "fill_constant", "elementwise_add_grad",
-            "mul_grad"
+            "mul", "elementwise_add", "mean", "fill_constant", "mean_grad",
+            "elementwise_add_grad", "mul_grad"
         ]
         actual_ops = []
         for op in block.ops:
diff --git a/python/paddle/v2/fluid/tests/test_regularizer.py b/python/paddle/v2/fluid/tests/test_regularizer.py
index f5d1eb3b96..24baf55e90 100644
--- a/python/paddle/v2/fluid/tests/test_regularizer.py
+++ b/python/paddle/v2/fluid/tests/test_regularizer.py
@@ -29,7 +29,11 @@ class TestL2DecayRegularizer(unittest.TestCase):
                     "Y": mul_y},
             outputs={"Out": mul_out},
             attrs={"x_num_col_dims": 1})
-        params_grads = append_backward_ops(mul_out)
+        mean_out = block.create_var(
+            dtype="float32", shape=[1], lod_level=0, name="mean.out")
+        block.append_op(
+            type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out})
+        params_grads = append_backward_ops(mean_out)
         self.assertEqual(len(params_grads), 1)
         count_ops = len(block.ops)
         params_grads = optimizer.append_regularization_ops(params_grads)
@@ -62,7 +66,11 @@ class TestL1DecayRegularizer(unittest.TestCase):
                     "Y": mul_y},
             outputs={"Out": mul_out},
             attrs={"x_num_col_dims": 1})
-        params_grads = append_backward_ops(mul_out)
+        mean_out = block.create_var(
+            dtype="float32", shape=[1], lod_level=0, name="mean.out")
+        block.append_op(
+            type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out})
+        params_grads = append_backward_ops(mean_out)
         self.assertEqual(len(params_grads), 1)
         count_ops = len(block.ops)
         params_grads = optimizer.append_regularization_ops(params_grads)