From 7d2ef02a993a378921a006d3575a802e5e9c5e9d Mon Sep 17 00:00:00 2001 From: guosheng Date: Thu, 17 Aug 2017 21:18:58 +0800 Subject: [PATCH 01/18] Add ScaleShiftLayer --- doc/api/v2/config/layer.rst | 5 + paddle/gserver/layers/ScaleShiftLayer.cpp | 106 ++++++++++++++++++ paddle/gserver/tests/test_LayerGrad.cpp | 15 +++ python/paddle/trainer/config_parser.py | 14 +++ .../paddle/trainer_config_helpers/layers.py | 37 ++++++ .../tests/configs/file_list.sh | 2 +- .../protostr/test_scale_shift_layer.protostr | 72 ++++++++++++ .../tests/configs/test_scale_shift_layer.py | 11 ++ 8 files changed, 261 insertions(+), 1 deletion(-) create mode 100644 paddle/gserver/layers/ScaleShiftLayer.cpp create mode 100644 python/paddle/trainer_config_helpers/tests/configs/protostr/test_scale_shift_layer.protostr create mode 100644 python/paddle/trainer_config_helpers/tests/configs/test_scale_shift_layer.py diff --git a/doc/api/v2/config/layer.rst b/doc/api/v2/config/layer.rst index cb330ea5e1..a4a843c610 100644 --- a/doc/api/v2/config/layer.rst +++ b/doc/api/v2/config/layer.rst @@ -362,6 +362,11 @@ trans .. autoclass:: paddle.v2.layer.trans :noindex: +scale_shift +----------- +.. autoclass:: paddle.v2.layer.scale_shift + :noindex: + Sampling Layers =============== diff --git a/paddle/gserver/layers/ScaleShiftLayer.cpp b/paddle/gserver/layers/ScaleShiftLayer.cpp new file mode 100644 index 0000000000..4f5b1c6225 --- /dev/null +++ b/paddle/gserver/layers/ScaleShiftLayer.cpp @@ -0,0 +1,106 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "Layer.h" + +namespace paddle { + +/** + * A layer does scaling and shifting to the input by appling a slope and + * an intercept which are trainable to the input element-wise. + * + * \f[ + * y = wx + b + * \f] + * + * Here, w is scale and b is offset, which are scalars and trainable. + * + */ + +class ScaleShiftLayer : public Layer { +protected: + std::unique_ptr scale_; + std::unique_ptr offset_; + +public: + explicit ScaleShiftLayer(const LayerConfig& config) : Layer(config) {} + + bool init(const LayerMap& layerMap, + const ParameterMap& parameterMap) override; + + void forward(PassType passType) override; + void backward(const UpdateCallback& callback = nullptr) override; +}; + +REGISTER_LAYER(scale_shift, ScaleShiftLayer); + +bool ScaleShiftLayer::init(const LayerMap& layerMap, + const ParameterMap& parameterMap) { + Layer::init(layerMap, parameterMap); + CHECK_EQ(inputLayers_.size(), 1U); + scale_.reset(new Weight(1, 1, parameters_[0])); + if (biasParameter_.get() != NULL) { + offset_ = std::unique_ptr(new Weight(1, 1, biasParameter_)); + } + return true; +} + +void ScaleShiftLayer::forward(PassType passType) { + Layer::forward(passType); + + MatrixPtr inV = getInputValue(0); + resetOutput(inV->getHeight(), inV->getWidth()); + MatrixPtr outV = getOutputValue(); + real scaleValue = scale_->getW()->getElement(0, 0); + outV->mulScalar(*inV, scaleValue); + if (offset_) { + real offsetValue = offset_->getW()->getElement(0, 0); + outV->add(offsetValue); + } +} + +void ScaleShiftLayer::backward(const UpdateCallback& callback) { + MatrixPtr inV = getInputValue(0); + MatrixPtr inG = getInputGrad(0); + MatrixPtr outV = getOutputValue(); + MatrixPtr outG = getOutputGrad(); + + /* Calculate the parameter gradient for the current layer */ + if (scale_->getWGrad()) { + MatrixPtr rowSumMtx; + Matrix::resizeOrCreate(rowSumMtx, outG->getHeight(), 1, false, useGpu_); + // this_i = scaleDest * this_i + scaleSum * \sum_j b_{ij} * c_{ij} + rowSumMtx->sumOfProducts( + /* b= */ *inV, /* c= */ *outG, /* scaleSum= */ 1, /* scaleDest= */ 0.); + // this_i = scaleDest * this_i + scaleSum * \sum_j b_{ji} + scale_->getWGrad()->sumCols( + /* b= */ *rowSumMtx, /* scaleSum= */ 1., /* scaleDest= */ 1.); + scale_->getParameterPtr()->incUpdate(callback); + } + if (offset_ && offset_->getWGrad()) { + MatrixPtr rowSumMtx; + Matrix::resizeOrCreate(rowSumMtx, outG->getHeight(), 1, false, useGpu_); + rowSumMtx->sumRows(*outG, 1., 0.); + offset_->getWGrad()->sumCols(*rowSumMtx, 1., 1.); + offset_->getParameterPtr()->incUpdate(callback); + } + + /* Calculate the input layers error */ + if (inG) { + real scaleValue = scale_->getW()->getElement(0, 0); + inG->add(*outG, scaleValue); + } +} + +} // namespace paddle diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index 0f312b6ca5..65429ebada 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -2007,6 +2007,21 @@ TEST(Layer, RowL2NormLayer) { } } +TEST(Layer, ScaleShiftLayer) { + const size_t batchSize = 128; + const size_t size = 512; + TestConfig config; + config.layerConfig.set_type("scale_shift"); + config.layerConfig.set_size(size); + config.biasSize = 1; + config.inputDefs.push_back( + {INPUT_DATA, "input", /* dim= */ size, /* paraSize= */ 1}); + config.layerConfig.add_inputs(); + for (auto useGpu : {false, true}) { + testLayerGrad(config, "scale_shift", batchSize, false, useGpu, false); + } +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); initMain(argc, argv); diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index da99e5bd53..8d71629faa 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -2232,6 +2232,20 @@ class ClipLayer(LayerBase): self.config.inputs[0].clip_conf.max = max +@config_layer('scale_shift') +class ScaleShiftLayer(LayerBase): + def __init__(self, name, inputs, bias=True, **xargs): + super(ScaleShiftLayer, self).__init__( + name, 'scale_shift', 0, inputs=inputs, **xargs) + config_assert( + len(self.inputs) == 1, + 'ScaleShiftLayer must have one and only one input.') + input_layer = self.get_input_layer(0) + self.set_layer_size(input_layer.size) + self.create_input_parameter(0, 1, [1, 1]) + self.create_bias_parameter(bias, 1) + + # key: cost type # value: cost class g_cost_map = {} diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index 1bc55c8696..4c7217024a 100755 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -133,6 +133,7 @@ __all__ = [ 'clip_layer', 'slice_projection', 'kmax_sequence_score_layer', + 'scale_shift_layer', ] @@ -230,6 +231,7 @@ class LayerType(object): CLIP_LAYER = 'clip' KMAX_SEQ_SCORE = 'kmax_seq_score' + SCALE_SHIFT_LAYER = 'scale_shift' @staticmethod def is_layer_type(type_name): @@ -6210,3 +6212,38 @@ def kmax_sequence_score_layer(input, name=None, beam_size=1): return LayerOutput( name, LayerType.KMAX_SEQ_SCORE, parents=[input], size=input.size) + + +@wrap_name_default("scale_shift") +@wrap_param_attr_default() +@wrap_bias_attr_default() +def scale_shift_layer(input, name=None, param_attr=None, bias_attr=None): + """ + A layer does scaling and shifting to the input by appling a slope and + an intercept which are trainable to the input element-wise. + .. math:: + + y = w * x + b + + .. code-block:: python + + scale_shift = scale_shift_layer(input=input_layer, bias_attr=False) + + :param name: The Layer Name. + :type name: basestring + :param input: The input layer. + :type input: LayerOutput. + :param param_attr: The parameter attribute of scaling. + :type param_attr: ParameterAttribute + :param bias_attr: The parameter attribute of shifting. + :type bias_attr: ParameterAttribute + :return: LayerOutput object. + :rtype: LayerOutput + """ + Layer( + name=name, + type=LayerType.SCALE_SHIFT_LAYER, + inputs=Input(input.name, **param_attr.attr), + bias=ParamAttr.to_bias(bias_attr)) + return LayerOutput( + name, LayerType.SCALE_SHIFT_LAYER, parents=[input], size=input.size) diff --git a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh index a61beb871a..3860699f6f 100755 --- a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh +++ b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh @@ -8,6 +8,6 @@ test_spp_layer test_bilinear_interp test_maxout test_bi_grumemory math_ops test_seq_concat_reshape test_pad test_smooth_l1 test_multiplex_layer test_prelu_layer test_row_conv test_detection_output_layer test_multibox_loss_layer test_recursive_topology test_gated_unit_layer test_clip_layer test_row_l2_norm_layer -test_kmax_seq_socre_layer test_seq_select_layers) +test_kmax_seq_socre_layer test_seq_select_layers test_scale_shift_layer) export whole_configs=(test_split_datasource) diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_scale_shift_layer.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_scale_shift_layer.protostr new file mode 100644 index 0000000000..efaf20f8a7 --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_scale_shift_layer.protostr @@ -0,0 +1,72 @@ +type: "nn" +layers { + name: "data" + type: "data" + size: 100 + active_type: "" +} +layers { + name: "__scale_shift_0__" + type: "scale_shift" + size: 100 + active_type: "" + inputs { + input_layer_name: "data" + input_parameter_name: "___scale_shift_0__.w0" + } + bias_parameter_name: "___scale_shift_0__.wbias" +} +layers { + name: "__scale_shift_1__" + type: "scale_shift" + size: 100 + active_type: "" + inputs { + input_layer_name: "data" + input_parameter_name: "___scale_shift_1__.w0" + } +} +parameters { + name: "___scale_shift_0__.w0" + size: 1 + initial_mean: 0.0 + initial_std: 1.0 + dims: 1 + dims: 1 + initial_strategy: 0 + initial_smart: true +} +parameters { + name: "___scale_shift_0__.wbias" + size: 1 + initial_mean: 0.0 + initial_std: 0.0 + dims: 1 + dims: 1 + initial_strategy: 0 + initial_smart: false +} +parameters { + name: "___scale_shift_1__.w0" + size: 1 + initial_mean: 0.0 + initial_std: 1.0 + dims: 1 + dims: 1 + initial_strategy: 0 + initial_smart: true +} +input_layer_names: "data" +output_layer_names: "__scale_shift_0__" +output_layer_names: "__scale_shift_1__" +sub_models { + name: "root" + layer_names: "data" + layer_names: "__scale_shift_0__" + layer_names: "__scale_shift_1__" + input_layer_names: "data" + output_layer_names: "__scale_shift_0__" + output_layer_names: "__scale_shift_1__" + is_recurrent_layer_group: false +} + diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_scale_shift_layer.py b/python/paddle/trainer_config_helpers/tests/configs/test_scale_shift_layer.py new file mode 100644 index 0000000000..818d71f15d --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/test_scale_shift_layer.py @@ -0,0 +1,11 @@ +from paddle.trainer_config_helpers import * + +settings(batch_size=1000, learning_rate=1e-5) + +data = data_layer(name='data', size=100) + +scale = scale_shift_layer(input=data) + +scale_shift = scale_shift_layer(input=data, bias_attr=False) + +outputs(scale, scale_shift) From f1e553354186c44508565ad89d4b526bdb3a705a Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Sun, 20 Aug 2017 13:57:26 +0800 Subject: [PATCH 02/18] Rename `Net::AddOp` to `Net::AppendOp` Fix #3582 --- paddle/framework/backward.cc | 9 +++--- paddle/framework/backward_test.cc | 30 +++++++++---------- paddle/framework/pybind.cc | 4 +-- paddle/operators/net_op.h | 7 +++-- paddle/operators/net_op_test.cc | 10 +++---- python/paddle/v2/framework/tests/test_net.py | 10 +++---- .../v2/framework/tests/test_recurrent_op.py | 2 +- 7 files changed, 37 insertions(+), 35 deletions(-) diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc index 9d30887224..bfda18724c 100644 --- a/paddle/framework/backward.cc +++ b/paddle/framework/backward.cc @@ -110,7 +110,7 @@ static std::unique_ptr BackwardRecursive( dup_output_ops[out].emplace_back(local_op_id); return false; }); - net->AddOp(std::move(bwd)); + net->AppendOp(std::move(bwd)); } // Get unique ID for this method. auto uid = uniq_id++; @@ -163,8 +163,9 @@ static std::unique_ptr BackwardRecursive( // If part of input gradient of that operator is not calculated, fill // zero variables to that input gradient. - net->AddOp(OpRegistry::CreateOp("fill_zeros_like", {{"Src", {prefix}}}, - {{"Dst", {grad_input}}}, {})); + net->AppendOp(OpRegistry::CreateOp("fill_zeros_like", + {{"Src", {prefix}}}, + {{"Dst", {grad_input}}}, {})); } return false; }); @@ -195,7 +196,7 @@ static std::unique_ptr BackwardRecursive( if (net->ops_.empty()) { // Current no aux op is added to network return grad_op; } - net->AddOp(std::move(grad_op)); + net->AppendOp(std::move(grad_op)); } net->SetType("@GENERATED_BACKWARD@"); net->CompleteAddOp(); diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index 2c5ec76dfe..b93ab66f2f 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -75,13 +75,13 @@ class FcOp : public operators::NetOp { FcOp(const std::string &type, const VarNameMap &inputs, const VarNameMap &outputs, const AttributeMap &attrs) : NetOp(type, inputs, outputs, attrs) { - AddOp(OpRegistry::CreateOp("mul", - {{"X", {Input("X")}}, {"Y", {Input("W")}}}, - {{"Out", {Output("mul_result")}}}, {})); + AppendOp(OpRegistry::CreateOp("mul", + {{"X", {Input("X")}}, {"Y", {Input("W")}}}, + {{"Out", {Output("mul_result")}}}, {})); auto input_b = Inputs("b"); std::string before_act = "mul_result"; if (input_b.size() != 0) { - AddOp(OpRegistry::CreateOp( + AppendOp(OpRegistry::CreateOp( "rowwise_add", {{"X", {Output("mul_result")}}, {"b", {input_b[0]}}}, {{"Out", {Output("add_result")}}}, {})); before_act = "add_result"; @@ -92,8 +92,8 @@ class FcOp : public operators::NetOp { } } - AddOp(OpRegistry::CreateOp("sigmoid", {{"X", {Output(before_act)}}}, - {{"Out", {Output("Out")}}}, {})); + AppendOp(OpRegistry::CreateOp("sigmoid", {{"X", {Output(before_act)}}}, + {{"Out", {Output("Out")}}}, {})); CompleteAddOp(false); } }; @@ -234,13 +234,13 @@ TEST(Backward, net_fc_backward_not_have_b) { TEST(Backward, net_input_of_network_not_need_grad) { ops::NetOp net; - net.AddOp(f::OpRegistry::CreateOp( + net.AppendOp(f::OpRegistry::CreateOp( "fc", {{"X", {"x"}}, {"W", {"W1"}}, {"b", {"b1"}}}, {{"mul_result", {"mul_tmp_0"}}, {"add_result", {"add_tmp_0"}}, {"Out", {"hidden0"}}}, {})); - net.AddOp(f::OpRegistry::CreateOp( + net.AppendOp(f::OpRegistry::CreateOp( "fc", {{"X", {"hidden0"}}, {"W", {"W2"}}, {"b", {"b2"}}}, {{"mul_result", {"mul_tmp_1"}}, {"add_result", {"add_tmp_1"}}, @@ -273,10 +273,10 @@ TEST(Backward, net_input_of_network_not_need_grad) { TEST(Backward, net_shared_weight) { ops::NetOp net; - net.AddOp(f::OpRegistry::CreateOp("mul", {{"X", {"x"}}, {"Y", {"w"}}}, - {{"Out", {"out"}}}, {})); - net.AddOp(f::OpRegistry::CreateOp("mul", {{"X", {"out"}}, {"Y", {"w"}}}, - {{"Out", {"FinalOut"}}}, {})); + net.AppendOp(f::OpRegistry::CreateOp("mul", {{"X", {"x"}}, {"Y", {"w"}}}, + {{"Out", {"out"}}}, {})); + net.AppendOp(f::OpRegistry::CreateOp("mul", {{"X", {"out"}}, {"Y", {"w"}}}, + {{"Out", {"FinalOut"}}}, {})); net.CompleteAddOp(); auto bwd = f::Backward(net, {}); @@ -357,19 +357,19 @@ TEST(Backward, op_part_of_input_are_not_need) { TEST(Backward, linear_net_intermediate_variable_has_no_grad) { ops::NetOp net; - net.AddOp(f::OpRegistry::CreateOp( + net.AppendOp(f::OpRegistry::CreateOp( "fc", {{"X", {"x1"}}, {"W", {"w1"}}, {"b", {"b1"}}}, {{"mul_result", {"mul_out1"}}, {"add_result", {"add_out1"}}, {"Out", {"out1"}}}, {})); - net.AddOp(f::OpRegistry::CreateOp( + net.AppendOp(f::OpRegistry::CreateOp( "fc", {{"X", {"out1"}}, {"W", {"w2"}}, {"b", {"b2"}}}, {{"mul_result", {"mul_out2"}}, {"add_result", {"tmp_out2"}}, {"Out", {"out2"}}}, {})); - net.AddOp(f::OpRegistry::CreateOp( + net.AppendOp(f::OpRegistry::CreateOp( "fc", {{"X", {"out2"}}, {"W", {"w3"}}, {"b", {"b3"}}}, {{"mul_result", {"mul_out3"}}, {"add_result", {"tmp_out3"}}, diff --git a/paddle/framework/pybind.cc b/paddle/framework/pybind.cc index f0114b9e49..89219a77c3 100644 --- a/paddle/framework/pybind.cc +++ b/paddle/framework/pybind.cc @@ -222,8 +222,8 @@ All parameter, weight, gradient are variables in Paddle. retv->SetType("plain_net"); return retv; }) - .def("add_op", [](operators::NetOp &self, - const OperatorBase &op) { self.AddOp(op); }) + .def("append_op", [](operators::NetOp &self, + const OperatorBase &op) { self.AppendOp(op); }) .def("complete_add_op", &operators::NetOp::CompleteAddOp) .def("complete_add_op", [](std::shared_ptr &self) { self->CompleteAddOp(); diff --git a/paddle/operators/net_op.h b/paddle/operators/net_op.h index 885ac6eeca..3d3f996ef5 100644 --- a/paddle/operators/net_op.h +++ b/paddle/operators/net_op.h @@ -84,13 +84,14 @@ class NetOp : public framework::OperatorBase { return true; } - void AddOp(const framework::OperatorBase& op) { AddOp(op.Clone()); } + void AppendOp(const framework::OperatorBase& op) { AppendOp(op.Clone()); } /** * @brief Add an operator by ptr */ - void AddOp(std::unique_ptr op) { - PADDLE_ENFORCE(!add_op_done_, "Cannot AddOp when this network is sealed"); + void AppendOp(std::unique_ptr op) { + PADDLE_ENFORCE(!add_op_done_, + "Cannot AppendOp when this network is sealed"); PADDLE_ENFORCE_NOT_NULL(op, "Cannot Insert Null op"); ops_.push_back(std::move(op)); } diff --git a/paddle/operators/net_op_test.cc b/paddle/operators/net_op_test.cc index e9598610c0..99019754a9 100644 --- a/paddle/operators/net_op_test.cc +++ b/paddle/operators/net_op_test.cc @@ -38,10 +38,10 @@ TEST(OpKernel, all) { auto net = std::make_shared(); ASSERT_NE(net, nullptr); - net->AddOp(std::unique_ptr( + net->AppendOp(std::unique_ptr( new TestOp("test", {{"X", {"x"}}, {"W", {"w1"}}, {"b", {"b1"}}}, {{"Out", {"y"}}}, {}))); - net->AddOp(std::unique_ptr( + net->AppendOp(std::unique_ptr( new TestOp("test", {{"X", {"y"}}, {"W", {"w2"}}, {"b", {"b2"}}}, {{"Out", {"z"}}}, {}))); @@ -61,7 +61,7 @@ TEST(NetOp, insert_op) { auto op1 = std::unique_ptr( new framework::NOP("empty", {{"X", {"x"}}, {"W", {"w1"}}, {"b", {"b1"}}}, {{"Out", {"y"}}}, {})); - net.AddOp(*op1); + net.AppendOp(*op1); net.InsertOp(0, *op1); ASSERT_EQ(2UL, net.ops_.size()); net.InsertOp(2, std::move(op1)); @@ -70,9 +70,9 @@ TEST(NetOp, insert_op) { TEST(NetOp, Clone) { NetOp net; - net.AddOp( + net.AppendOp( std::unique_ptr(new framework::NOP{"empty", {}, {}, {}})); - net.AddOp(std::unique_ptr( + net.AppendOp(std::unique_ptr( new framework::NOP{"empty2", {}, {}, {}})); net.CompleteAddOp(true); auto new_net_op = net.Clone(); diff --git a/python/paddle/v2/framework/tests/test_net.py b/python/paddle/v2/framework/tests/test_net.py index b42cadd11a..9339cf28da 100644 --- a/python/paddle/v2/framework/tests/test_net.py +++ b/python/paddle/v2/framework/tests/test_net.py @@ -6,8 +6,8 @@ import unittest def fc(X, W, Y): ret_v = core.Net.create() - ret_v.add_op(Operator("mul", X="X", Y="W", Out="pre_activation")) - ret_v.add_op(Operator("sigmoid", X="pre_activation", Y=Y)) + ret_v.append_op(Operator("mul", X="X", Y="W", Out="pre_activation")) + ret_v.append_op(Operator("sigmoid", X="pre_activation", Y=Y)) ret_v.complete_add_op(True) return ret_v @@ -16,12 +16,12 @@ class TestNet(unittest.TestCase): def test_net_all(self): net = core.Net.create() op1 = Operator("add_two", X="X", Y="Y", Out="Out") - net.add_op(op1) + net.append_op(op1) net2 = core.Net.create() - net2.add_op(fc(X="X", W="w", Y="fc.out")) + net2.append_op(fc(X="X", W="w", Y="fc.out")) net2.complete_add_op(True) - net.add_op(net2) + net.append_op(net2) net.complete_add_op(True) expected = ''' diff --git a/python/paddle/v2/framework/tests/test_recurrent_op.py b/python/paddle/v2/framework/tests/test_recurrent_op.py index 3d4a34d8d7..d6000ab9f9 100644 --- a/python/paddle/v2/framework/tests/test_recurrent_op.py +++ b/python/paddle/v2/framework/tests/test_recurrent_op.py @@ -150,7 +150,7 @@ class TestRecurrentOp(unittest.TestCase): sig_op = Operator("sigmoid", X="sum", Y="h@alias") for op in [x_fc_op, h_fc_op, sum_op, sig_op]: - stepnet.add_op(op) + stepnet.append_op(op) stepnet.complete_add_op(True) self.rnnop.set_stepnet(stepnet) From d525abed955b5dd2e6c711205c11ac6a3bcca789 Mon Sep 17 00:00:00 2001 From: qijun Date: Mon, 21 Aug 2017 13:43:07 +0800 Subject: [PATCH 03/18] refine random related ops --- paddle/operators/CMakeLists.txt | 4 +- paddle/operators/gaussian_random_op.cc | 35 ++---------- paddle/operators/gaussian_random_op.cu | 41 ++------------ paddle/operators/gaussian_random_op.h | 38 +++++++++++++ paddle/operators/math/math_function.cc | 22 ++++++++ paddle/operators/math/math_function.cu | 36 ++++++++++++ paddle/operators/math/math_function.h | 8 +++ paddle/operators/mul_op.cc | 1 - paddle/operators/uniform_random_op.cc | 39 ++----------- paddle/operators/uniform_random_op.cu | 55 +------------------ paddle/operators/uniform_random_op.h | 38 +++++++++++++ paddle/platform/device_context.cc | 36 ++++++------ paddle/platform/device_context.h | 20 ++++--- .../paddle/v2/framework/tests/CMakeLists.txt | 2 +- .../tests/test_gaussian_random_op.py | 7 +-- .../framework/tests/test_uniform_random_op.py | 7 +-- 16 files changed, 192 insertions(+), 197 deletions(-) create mode 100644 paddle/operators/gaussian_random_op.h create mode 100644 paddle/operators/uniform_random_op.h diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index a7c89787e4..8f22a5fbc3 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -58,7 +58,7 @@ op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc) op_library(sigmoid_op SRCS sigmoid_op.cc sigmoid_op.cu) op_library(softmax_op SRCS softmax_op.cc softmax_op.cu) -op_library(gaussian_random_op SRCS gaussian_random_op.cc gaussian_random_op.cu) +op_library(gaussian_random_op SRCS gaussian_random_op.cc gaussian_random_op.cu DEPS math_function) op_library(cross_entropy_op SRCS cross_entropy_op.cc cross_entropy_op.cu) op_library(fill_zeros_like_op SRCS fill_zeros_like_op.cc fill_zeros_like_op.cu) @@ -67,4 +67,4 @@ op_library(sgd_op SRCS sgd_op.cc sgd_op.cu) op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc DEPS framework_proto tensor op_registry operator net_op) op_library(uniform_random_op - SRCS uniform_random_op.cc uniform_random_op.cu) + SRCS uniform_random_op.cc uniform_random_op.cu DEPS math_function) diff --git a/paddle/operators/gaussian_random_op.cc b/paddle/operators/gaussian_random_op.cc index f30bbce958..aba8c6e5cd 100644 --- a/paddle/operators/gaussian_random_op.cc +++ b/paddle/operators/gaussian_random_op.cc @@ -12,36 +12,11 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include -#include "paddle/framework/op_registry.h" +#include "paddle/operators/gaussian_random_op.h" namespace paddle { namespace operators { -template -class GaussianRandomKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - float mean = context.op_.GetAttr("mean"); - float std = context.op_.GetAttr("std"); - auto* tensor = context.Output(0); - T* data = tensor->mutable_data(context.GetPlace()); - - // TODO(dzh): attribute does not support unsigned int. - // And we need a global random seed configuration. - int seed = context.op_.GetAttr("seed"); - if (seed == 0) { - seed = std::random_device()(); - } - std::mt19937 g(seed); - std::normal_distribution distribution(mean, std); - ssize_t size = framework::product(tensor->dims()); - for (int i = 0; i < size; ++i) { - data[i] = distribution(g); - } - } -}; - class GaussianRandomOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -70,10 +45,6 @@ Use to initialize tensor with gaussian random generator. AddAttr>("dims", "The dimension of random tensor."); AddAttr("mean", "mean value of random.").SetDefault(.0f); AddAttr("std", "minimum value of random value.").SetDefault(1.0f); - AddAttr("seed", - "Random seed of generator." - "0 means use system wide seed") - .SetDefault(0); } }; @@ -83,4 +54,6 @@ Use to initialize tensor with gaussian random generator. namespace ops = paddle::operators; REGISTER_OP_WITHOUT_GRADIENT(gaussian_random, ops::GaussianRandomOp, ops::GaussianRandomOpMaker); -REGISTER_OP_CPU_KERNEL(gaussian_random, ops::GaussianRandomKernel); +REGISTER_OP_CPU_KERNEL( + gaussian_random, + ops::GaussianRandomKernel); diff --git a/paddle/operators/gaussian_random_op.cu b/paddle/operators/gaussian_random_op.cu index 1340b1e1e9..31be16fdc8 100644 --- a/paddle/operators/gaussian_random_op.cu +++ b/paddle/operators/gaussian_random_op.cu @@ -12,42 +12,9 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include -#include -#include "paddle/platform/dynload/curand.h" -#include "paddle/platform/gpu_info.h" - -#include "paddle/framework/op_registry.h" - -namespace paddle { -namespace operators { - -template -class GaussianRandomKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - float mean = context.op_.GetAttr("mean"); - float std = context.op_.GetAttr("std"); - auto* tensor = context.Output(0); - T* data = tensor->mutable_data(context.GetPlace()); - - int seed = context.op_.GetAttr("seed"); - if (seed == 0) { - std::random_device rd; - seed = rd(); - } - curandGenerator_t g; - PADDLE_ENFORCE(platform::dynload::curandCreateGenerator( - &g, CURAND_RNG_PSEUDO_DEFAULT)); - PADDLE_ENFORCE( - platform::dynload::curandSetPseudoRandomGeneratorSeed(g, seed)); - platform::dynload::curandGenerateNormal( - g, data, framework::product(tensor->dims()), mean, std); - } -}; - -} // namespace operators -} // namespace paddle +#include "paddle/operators/gaussian_random_op.h" namespace ops = paddle::operators; -REGISTER_OP_GPU_KERNEL(gaussian_random, ops::GaussianRandomKernel); +REGISTER_OP_GPU_KERNEL( + gaussian_random, + ops::GaussianRandomKernel); diff --git a/paddle/operators/gaussian_random_op.h b/paddle/operators/gaussian_random_op.h new file mode 100644 index 0000000000..041390e954 --- /dev/null +++ b/paddle/operators/gaussian_random_op.h @@ -0,0 +1,38 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/math_function.h" + +namespace paddle { +namespace operators { +template +class GaussianRandomKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* tensor = context.Output("Out"); + T* data = tensor->mutable_data(context.GetPlace()); + T mean = static_cast(context.op_.GetAttr("mean")); + T std = static_cast(context.op_.GetAttr("std")); + auto n = framework::product(tensor->dims()); + + auto* device_context = + const_cast(context.device_context_); + math::RandGaussian(n, mean, std, data, device_context); + } +}; +} +} diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc index 1e86fc3d16..da59044899 100644 --- a/paddle/operators/math/math_function.cc +++ b/paddle/operators/math/math_function.cc @@ -109,6 +109,28 @@ void matmul(const framework::Tensor& matrix_a, matrix_b.data(), beta, matrix_out->data(), context); } +template <> +void RandUniform(const int n, const float min, + const float max, float* output, + platform::DeviceContext* context) { + auto* cpu_context = reinterpret_cast(context); + std::uniform_real_distribution distribution(min, max); + for (int i = 0; i < n; i++) { + output[i] = distribution(cpu_context->rand_engine()); + } +} + +template <> +void RandGaussian(const int n, const float mean, + const float std, float* output, + platform::DeviceContext* context) { + auto* cpu_context = reinterpret_cast(context); + std::normal_distribution distribution(mean, std); + for (int i = 0; i < n; i++) { + output[i] = distribution(cpu_context->rand_engine()); + } +} + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index da40b27c94..5a400d4445 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -12,6 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include +#include +#include +#include #include "paddle/operators/math/math_function.h" namespace paddle { @@ -122,6 +126,38 @@ void matmul(const framework::Tensor& matrix_a, matrix_b.data(), beta, matrix_out->data(), context); } +template <> +void RandUniform(const int n, const float min, + const float max, float* output, + platform::DeviceContext* context) { + auto* cuda_context = reinterpret_cast(context); + thrust::uniform_real_distribution distribution(min, max); + thrust::minstd_rand engine = cuda_context->rand_enigne(); + engine->discard(n); + + thrust::counting_iterator index_sequence_begin(0); + + thrust::transform(thrust::cuda::par.on(cuda_context->stream()), + index_sequence_begin, index_sequence_begin + n, + thrust::device_ptr(output), distribution(engine)); +} + +template <> +void RandGaussian(const int n, const float mean, + const float std, float* output, + platform::DeviceContext* context) { + auto* cuda_context = reinterpret_cast(context); + thrust::normal_distribution distribution(mean, std); + thrust::minstd_rand engine = cuda_context->rand_enigne(); + engine->discard(n); + + thrust::counting_iterator index_sequence_begin(0); + + thrust::transform(thrust::cuda::par.on(cuda_context->stream()), + index_sequence_begin, index_sequence_begin + n, + thrust::device_ptr(output), distribution(engine)); +} + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h index 155589fadb..ea15e8fd2b 100644 --- a/paddle/operators/math/math_function.h +++ b/paddle/operators/math/math_function.h @@ -77,6 +77,14 @@ void matmul(const framework::Tensor& matrix_a, bool trans_a, framework::Tensor* matrix_out, T beta, platform::DeviceContext* context); +template +void RandUniform(const int n, const T min, const T max, T* output, + platform::DeviceContext* context); + +template +void RandGaussian(const int n, const T mean, const T std, T* output, + platform::DeviceContext* context); + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc index 460e458ca4..173cc3850c 100644 --- a/paddle/operators/mul_op.cc +++ b/paddle/operators/mul_op.cc @@ -13,7 +13,6 @@ limitations under the License. */ #include "paddle/operators/mul_op.h" -#include "paddle/operators/math/math_function.h" namespace paddle { namespace operators { diff --git a/paddle/operators/uniform_random_op.cc b/paddle/operators/uniform_random_op.cc index a0a0d4d914..81487a6bd8 100644 --- a/paddle/operators/uniform_random_op.cc +++ b/paddle/operators/uniform_random_op.cc @@ -12,39 +12,11 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include -#include -#include "paddle/framework/op_registry.h" -#include "paddle/framework/operator.h" +#include "paddle/operators/uniform_random_op.h" namespace paddle { namespace operators { -// It seems that Eigen::Tensor::random in GPU will SEGFAULT. -// Use std::random and thrust::random(thrust is a std library in CUDA) to -// implement uniform random. -template -class CPUUniformRandomKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* tensor = context.Output("Out"); - T* data = tensor->mutable_data(context.GetPlace()); - unsigned int seed = - static_cast(context.op_.GetAttr("seed")); - std::minstd_rand engine; - if (seed == 0) { - seed = std::random_device()(); - } - engine.seed(seed); - std::uniform_real_distribution dist( - static_cast(context.op_.GetAttr("min")), - static_cast(context.op_.GetAttr("max"))); - for (ssize_t i = 0; i < framework::product(tensor->dims()); ++i) { - data[i] = dist(engine); - } - } -}; - class UniformRandomOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -72,10 +44,6 @@ Used to initialize tensor with uniform random generator. AddAttr>("dims", "the dimension of random tensor"); AddAttr("min", "Minimum value of uniform random").SetDefault(-1.0f); AddAttr("max", "Maximun value of uniform random").SetDefault(1.0f); - AddAttr("seed", - "Random seed of uniform random. " - "0 means generate a seed by system") - .SetDefault(0); } }; } // namespace operators @@ -83,5 +51,6 @@ Used to initialize tensor with uniform random generator. REGISTER_OP_WITHOUT_GRADIENT(uniform_random, paddle::operators::UniformRandomOp, paddle::operators::UniformRandomOpMaker); -REGISTER_OP_CPU_KERNEL(uniform_random, - paddle::operators::CPUUniformRandomKernel); +REGISTER_OP_CPU_KERNEL( + uniform_random, + paddle::operators::UniformRandomKernel); diff --git a/paddle/operators/uniform_random_op.cu b/paddle/operators/uniform_random_op.cu index 7a243555b6..91368fa73e 100644 --- a/paddle/operators/uniform_random_op.cu +++ b/paddle/operators/uniform_random_op.cu @@ -12,60 +12,11 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include -#include -#include -#include -#include "paddle/framework/op_registry.h" -#include "paddle/framework/operator.h" +#include "paddle/operators/uniform_random_op.h" namespace paddle { namespace operators { -template -struct UniformGenerator { - T min_, max_; - unsigned int seed_; - - __host__ __device__ UniformGenerator(T min, T max, int seed) - : min_(min), max_(max), seed_(seed) {} - - __host__ __device__ T operator()(const unsigned int n) const { - thrust::minstd_rand rng; - rng.seed(seed_); - thrust::uniform_real_distribution dist(min_, max_); - rng.discard(n); - return dist(rng); - } -}; - -// It seems that Eigen::Tensor::random in GPU will SEGFAULT. -// Use std::random and thrust::random(thrust is a std library in CUDA) to -// implement uniform random. -template -class GPUUniformRandomKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* tensor = context.Output("Out"); - T* data = tensor->mutable_data(context.GetPlace()); - unsigned int seed = - static_cast(context.op_.GetAttr("seed")); - if (seed == 0) { - std::random_device rd; - seed = rd(); - } - T min = static_cast(context.op_.GetAttr("min")); - T max = static_cast(context.op_.GetAttr("max")); - thrust::counting_iterator index_sequence_begin(0); - ssize_t N = framework::product(tensor->dims()); - thrust::transform(index_sequence_begin, index_sequence_begin + N, - thrust::device_ptr(data), - UniformGenerator(min, max, seed)); - } -}; - -} // namespace operators -} // namespace paddle - REGISTER_OP_GPU_KERNEL(uniform_random, - paddle::operators::GPUUniformRandomKernel); + paddle::operators::GPUUniformRandomKernel< + paddle::platform::GPUPlace, float>); diff --git a/paddle/operators/uniform_random_op.h b/paddle/operators/uniform_random_op.h new file mode 100644 index 0000000000..ec009b025e --- /dev/null +++ b/paddle/operators/uniform_random_op.h @@ -0,0 +1,38 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/math_function.h" + +namespace paddle { +namespace operators { +template +class UniformRandomKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* tensor = context.Output("Out"); + T* data = tensor->mutable_data(context.GetPlace()); + T min = static_cast(context.op_.GetAttr("min")); + T max = static_cast(context.op_.GetAttr("max")); + auto n = framework::product(tensor->dims()); + + auto* device_context = + const_cast(context.device_context_); + math::RandUniform(n, min, max, data, device_context); + } +}; +} +} diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc index f92c15ae45..fabbb55443 100644 --- a/paddle/platform/device_context.cc +++ b/paddle/platform/device_context.cc @@ -25,8 +25,17 @@ CPUDeviceContext::CPUDeviceContext() { eigen_device_.reset(new Eigen::DefaultDevice()); } -CPUDeviceContext::CPUDeviceContext(CPUPlace place) { +CPUDeviceContext::CPUDeviceContext(CPUPlace place, int rand_seed) { eigen_device_.reset(new Eigen::DefaultDevice()); + rand_seed_ = rand_seed; +} + +std::minstd_rand& CPUDeviceContext::rand_engine() { + if (!rand_engine_) { + rand_engine_.reset(new std::minstd_rand()); + rand_engine_->seed(rand_seed_); + } + return *(rand_engine_.get()); } Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const { @@ -95,7 +104,8 @@ Eigen::GpuDevice* DeviceContext::get_eigen_device() const { return reinterpret_cast(this)->eigen_device(); } -CUDADeviceContext::CUDADeviceContext(GPUPlace place) : place_(place) { +CUDADeviceContext::CUDADeviceContext(GPUPlace place, uint64_t seed) + : place_(place), seed_(seed) { SetDeviceId(place_.device); PADDLE_ENFORCE(cudaStreamCreate(&stream_)); eigen_stream_.reset(new EigenCudaStreamDevice()); @@ -114,9 +124,6 @@ CUDADeviceContext::~CUDADeviceContext() { PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_)); } - if (curand_generator_) { - PADDLE_ENFORCE(dynload::curandDestroyGenerator(curand_generator_)); - } eigen_stream_.reset(); eigen_device_.reset(); PADDLE_ENFORCE(cudaStreamDestroy(stream_)); @@ -150,21 +157,16 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() { return cudnn_handle_; } -cudaStream_t CUDADeviceContext::stream() { return stream_; } - -curandGenerator_t CUDADeviceContext::curand_generator() { - if (!curand_generator_) { - SetDeviceId(place_.device); - PADDLE_ENFORCE(dynload::curandCreateGenerator(&curand_generator_, - CURAND_RNG_PSEUDO_DEFAULT)); - PADDLE_ENFORCE( - dynload::curandSetPseudoRandomGeneratorSeed(curand_generator_, seed_)); - - PADDLE_ENFORCE(dynload::curandSetStream(curand_generator_, stream_)); +thrust::minstd_rand& CPUDeviceContext::rand_engine() { + if (!rand_engine_) { + rand_engine_.reset(new thrust::minstd_rand()); + rand_engine_->seed(rand_seed_); } - return curand_generator_; + return *(rand_engine_.get()); } +cudaStream_t CUDADeviceContext::stream() { return stream_; } + #endif // PADDLE_ONLY_CPU } // namespace platform diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index c5042ae33e..e4de3807cd 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -15,9 +15,10 @@ limitations under the License. */ #include "paddle/platform/place.h" #ifndef PADDLE_ONLY_CPU +#include +#include #include "paddle/platform/dynload/cublas.h" #include "paddle/platform/dynload/cudnn.h" -#include "paddle/platform/dynload/curand.h" #include "paddle/platform/gpu_info.h" #define EIGEN_USE_GPU #endif @@ -40,14 +41,18 @@ class DeviceContext { class CPUDeviceContext : public DeviceContext { public: CPUDeviceContext(); - explicit CPUDeviceContext(CPUPlace); + explicit CPUDeviceContext(CPUPlace place, int rand_seed = 0); virtual ~CPUDeviceContext() {} Eigen::DefaultDevice* eigen_device() const; + std::minstd_rand& rand_engine(); + Place GetPlace() const override; private: + int rand_seed_; + std::unique_ptr rand_engine_; std::unique_ptr eigen_device_; }; @@ -56,7 +61,7 @@ class EigenCudaStreamDevice; class CUDADeviceContext : public DeviceContext { public: - explicit CUDADeviceContext(GPUPlace); + explicit CUDADeviceContext(GPUPlace place, uint64_t rand_seed = 0); virtual ~CUDADeviceContext(); /*! \brief Wait for all operations completion in the stream. */ @@ -75,8 +80,7 @@ class CUDADeviceContext : public DeviceContext { /*! \brief Return cudnn handle in the device context. */ cudnnHandle_t cudnn_handle(); - /*! \brief Return curand handle in the device context. */ - curandGenerator_t curand_generator(); + thrust::minstd_rand& CPUDeviceContext::rand_engine(); /*! \brief Return cuda stream in the device context. */ cudaStream_t stream(); @@ -85,18 +89,16 @@ class CUDADeviceContext : public DeviceContext { private: GPUPlace place_; - private: std::unique_ptr eigen_device_; std::unique_ptr eigen_stream_; - private: - uint64_t seed_; + uint64_t rand_seed_; + std::unique_ptr rand_engine_; // clang-format off cudaStream_t stream_{nullptr}; cudnnHandle_t cudnn_handle_{nullptr}; cublasHandle_t cublas_handle_{nullptr}; - curandGenerator_t curand_generator_{nullptr}; // clang-format on }; diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt index ce57a07130..b07a65f4d1 100644 --- a/python/paddle/v2/framework/tests/CMakeLists.txt +++ b/python/paddle/v2/framework/tests/CMakeLists.txt @@ -22,7 +22,7 @@ py_test(test_rowwise_add_op SRCS test_rowwise_add_op.py) py_test(test_default_scope_funcs SRCS test_default_scope_funcs.py) py_test(test_operator SRCS test_operator.py) -# py_test(test_gaussian_random_op SRCS test_gaussian_random_op.py) +py_test(test_gaussian_random_op SRCS test_gaussian_random_op.py) py_test(test_uniform_random_op SRCS test_uniform_random_op.py) py_test(test_recurrent_op SRCS test_recurrent_op.py) py_test(test_sgd_op SRCS test_sgd_op.py) diff --git a/python/paddle/v2/framework/tests/test_gaussian_random_op.py b/python/paddle/v2/framework/tests/test_gaussian_random_op.py index f95ed70b58..367d21b301 100644 --- a/python/paddle/v2/framework/tests/test_gaussian_random_op.py +++ b/python/paddle/v2/framework/tests/test_gaussian_random_op.py @@ -17,12 +17,7 @@ class GaussianRandomTest(unittest.TestCase): scope.new_var("Out").get_tensor() op = Operator( - "gaussian_random", - Out="Out", - dims=[1000, 784], - mean=.0, - std=1., - seed=10) + "gaussian_random", Out="Out", dims=[1000, 784], mean=.0, std=1.) op.infer_shape(scope) context = core.DeviceContext.create(place) diff --git a/python/paddle/v2/framework/tests/test_uniform_random_op.py b/python/paddle/v2/framework/tests/test_uniform_random_op.py index c3d2bb44da..95c36a27cf 100644 --- a/python/paddle/v2/framework/tests/test_uniform_random_op.py +++ b/python/paddle/v2/framework/tests/test_uniform_random_op.py @@ -17,12 +17,7 @@ class UniformRandomTest(unittest.TestCase): scope.new_var("X").get_tensor() op = Operator( - "uniform_random", - Out="X", - dims=[1000, 784], - min=-5.0, - max=10.0, - seed=10) + "uniform_random", Out="X", dims=[1000, 784], min=-5.0, max=10.0) op.infer_shape(scope) ctx = core.DeviceContext.create(place) From 29d8825caf921f5349551a18344503345c7b9969 Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Mon, 21 Aug 2017 13:43:51 +0800 Subject: [PATCH 04/18] tune relative precision for unit test img_conv2 in test_NetworkCompare.cpp. 1. It's no problem with relative precision 1e-3 when testing several times in my local machine. 2. But the testing failed with 1e-2 in the TeamCity, and only one value's relative precision is over 1e-2. So tune it to 4e-2 --- paddle/gserver/tests/test_NetworkCompare.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/gserver/tests/test_NetworkCompare.cpp b/paddle/gserver/tests/test_NetworkCompare.cpp index f930c72fde..d36f72360f 100644 --- a/paddle/gserver/tests/test_NetworkCompare.cpp +++ b/paddle/gserver/tests/test_NetworkCompare.cpp @@ -269,7 +269,8 @@ TEST(Compare, img_conv2) { bool useGpu = FLAGS_use_gpu; double eps = FLAGS_checkgrad_eps; FLAGS_use_gpu = true; - FLAGS_checkgrad_eps = 1e-2; + // Sometimes, this unit test will fail with 1e-2 + FLAGS_checkgrad_eps = 4e-2; compareNetwork(config_file_a, config_file_b); FLAGS_use_gpu = useGpu; FLAGS_checkgrad_eps = eps; From 83abbce8eb750f7e7c844b0959851e901806aa91 Mon Sep 17 00:00:00 2001 From: guosheng Date: Mon, 21 Aug 2017 14:05:56 +0800 Subject: [PATCH 05/18] Follow comments and refine ScaleShiftLayer --- paddle/gserver/layers/ScaleShiftLayer.cpp | 5 +++-- paddle/gserver/tests/test_LayerGrad.cpp | 4 ++-- python/paddle/trainer_config_helpers/layers.py | 5 +++-- .../protostr/test_scale_shift_layer.protostr | 14 +++++++------- .../tests/configs/test_scale_shift_layer.py | 6 ++---- 5 files changed, 17 insertions(+), 17 deletions(-) diff --git a/paddle/gserver/layers/ScaleShiftLayer.cpp b/paddle/gserver/layers/ScaleShiftLayer.cpp index 4f5b1c6225..06dcb409f8 100644 --- a/paddle/gserver/layers/ScaleShiftLayer.cpp +++ b/paddle/gserver/layers/ScaleShiftLayer.cpp @@ -17,8 +17,9 @@ limitations under the License. */ namespace paddle { /** - * A layer does scaling and shifting to the input by appling a slope and - * an intercept which are trainable to the input element-wise. + * A layer applies a slope and an intercept to the input element-wise for + * scaling and shifting. Noting that this layer is trainable which differs + * from the SlopeInterceptLayer. * * \f[ * y = wx + b diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index 65429ebada..dd2c955e6a 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -2008,8 +2008,8 @@ TEST(Layer, RowL2NormLayer) { } TEST(Layer, ScaleShiftLayer) { - const size_t batchSize = 128; - const size_t size = 512; + const size_t batchSize = 16; + const size_t size = 32; TestConfig config; config.layerConfig.set_type("scale_shift"); config.layerConfig.set_size(size); diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index 4c7217024a..ec3a87aa36 100755 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -6219,8 +6219,9 @@ def kmax_sequence_score_layer(input, name=None, beam_size=1): @wrap_bias_attr_default() def scale_shift_layer(input, name=None, param_attr=None, bias_attr=None): """ - A layer does scaling and shifting to the input by appling a slope and - an intercept which are trainable to the input element-wise. + A layer applies a slope and an intercept to the input element-wise for + scaling and shifting. Noting that this layer is trainable which differs + from the slope_intercept_layer. .. math:: y = w * x + b diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_scale_shift_layer.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_scale_shift_layer.protostr index efaf20f8a7..35ade126a2 100644 --- a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_scale_shift_layer.protostr +++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_scale_shift_layer.protostr @@ -14,7 +14,6 @@ layers { input_layer_name: "data" input_parameter_name: "___scale_shift_0__.w0" } - bias_parameter_name: "___scale_shift_0__.wbias" } layers { name: "__scale_shift_1__" @@ -25,6 +24,7 @@ layers { input_layer_name: "data" input_parameter_name: "___scale_shift_1__.w0" } + bias_parameter_name: "___scale_shift_1__.wbias" } parameters { name: "___scale_shift_0__.w0" @@ -37,24 +37,24 @@ parameters { initial_smart: true } parameters { - name: "___scale_shift_0__.wbias" + name: "___scale_shift_1__.w0" size: 1 initial_mean: 0.0 - initial_std: 0.0 + initial_std: 1.0 dims: 1 dims: 1 initial_strategy: 0 - initial_smart: false + initial_smart: true } parameters { - name: "___scale_shift_1__.w0" + name: "___scale_shift_1__.wbias" size: 1 initial_mean: 0.0 - initial_std: 1.0 + initial_std: 0.0 dims: 1 dims: 1 initial_strategy: 0 - initial_smart: true + initial_smart: false } input_layer_names: "data" output_layer_names: "__scale_shift_0__" diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_scale_shift_layer.py b/python/paddle/trainer_config_helpers/tests/configs/test_scale_shift_layer.py index 818d71f15d..dd589116fa 100644 --- a/python/paddle/trainer_config_helpers/tests/configs/test_scale_shift_layer.py +++ b/python/paddle/trainer_config_helpers/tests/configs/test_scale_shift_layer.py @@ -1,11 +1,9 @@ from paddle.trainer_config_helpers import * -settings(batch_size=1000, learning_rate=1e-5) - data = data_layer(name='data', size=100) -scale = scale_shift_layer(input=data) +scale = scale_shift_layer(input=data, bias_attr=False) -scale_shift = scale_shift_layer(input=data, bias_attr=False) +scale_shift = scale_shift_layer(input=data) outputs(scale, scale_shift) From 0af1c4a9feed5a38f34e1ea5a44e3887f702059f Mon Sep 17 00:00:00 2001 From: guosheng Date: Mon, 21 Aug 2017 14:39:05 +0800 Subject: [PATCH 06/18] Follow comments and refine annotations on ScaleShiftLayer --- paddle/gserver/layers/ScaleShiftLayer.cpp | 8 ++++---- python/paddle/trainer_config_helpers/layers.py | 10 +++++++--- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/paddle/gserver/layers/ScaleShiftLayer.cpp b/paddle/gserver/layers/ScaleShiftLayer.cpp index 06dcb409f8..35fd038ab4 100644 --- a/paddle/gserver/layers/ScaleShiftLayer.cpp +++ b/paddle/gserver/layers/ScaleShiftLayer.cpp @@ -17,15 +17,15 @@ limitations under the License. */ namespace paddle { /** - * A layer applies a slope and an intercept to the input element-wise for - * scaling and shifting. Noting that this layer is trainable which differs - * from the SlopeInterceptLayer. + * A layer applies a linear transformation to each element in each row of + * the input matrix. For each element, the layer first re-scale it and then + * adds a bias to it. * * \f[ * y = wx + b * \f] * - * Here, w is scale and b is offset, which are scalars and trainable. + * Here, w is the scale and b is the bias. Both w and b are trainable scalars. * */ diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index ec3a87aa36..c9e3ded65c 100755 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -6219,9 +6219,13 @@ def kmax_sequence_score_layer(input, name=None, beam_size=1): @wrap_bias_attr_default() def scale_shift_layer(input, name=None, param_attr=None, bias_attr=None): """ - A layer applies a slope and an intercept to the input element-wise for - scaling and shifting. Noting that this layer is trainable which differs - from the slope_intercept_layer. + A layer applies a linear transformation to each element in each row of + the input matrix. For each element, the layer first re-scale it and then + adds a bias to it. + + This layer is very like the SlopeInterceptLayer, except the scale and + bias are trainable. + .. math:: y = w * x + b From 7c274dc0a16b77fae0faf527ef02a1f72abad593 Mon Sep 17 00:00:00 2001 From: qijun Date: Mon, 21 Aug 2017 16:41:22 +0800 Subject: [PATCH 07/18] use curand --- paddle/operators/math/math_function.cc | 9 +++++ paddle/operators/math/math_function.cu | 56 ++++++++++++++++++-------- paddle/operators/math/math_function.h | 8 ++++ paddle/platform/device_context.cc | 15 ++++--- paddle/platform/device_context.h | 6 +-- 5 files changed, 70 insertions(+), 24 deletions(-) diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc index da59044899..d0b1f8ee48 100644 --- a/paddle/operators/math/math_function.cc +++ b/paddle/operators/math/math_function.cc @@ -109,6 +109,15 @@ void matmul(const framework::Tensor& matrix_a, matrix_b.data(), beta, matrix_out->data(), context); } +template <> +void Set(const int n, const float alpha, + float* output, + platform::DeviceContext* context) { + auto* cpu_context = reinterpret_cast(context); + framework::EigenVector::Type out(output, n); + out.device(*(cpu_context->eigen_device())) = t.constant(T(alpha)); +} + template <> void RandUniform(const int n, const float min, const float max, float* output, diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index 5a400d4445..76bbf790db 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -126,20 +126,48 @@ void matmul(const framework::Tensor& matrix_a, matrix_b.data(), beta, matrix_out->data(), context); } +template <> +void Set(const int n, const float alpha, + float* output, + platform::DeviceContext* context) { + auto* cuda_context = reinterpret_cast(context); + framework::EigenVector::Type out(output, n); + out.device(*(cuda_context->eigen_device())) = t.constant(T(alpha)); +} + +template +__global__ void UniformShift(const int n, const T min, const T max, T* x) { + float scale = max - min; + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; + i += blockDim.x * gridDim.x) { + x[i] = x[i] * scale + min; + } +} + template <> void RandUniform(const int n, const float min, const float max, float* output, platform::DeviceContext* context) { auto* cuda_context = reinterpret_cast(context); - thrust::uniform_real_distribution distribution(min, max); - thrust::minstd_rand engine = cuda_context->rand_enigne(); - engine->discard(n); - - thrust::counting_iterator index_sequence_begin(0); + PADDLE_ENFORCE( + curandGenerateUniform(cuda_context->curand_generator(), output, n)); + int block = 512; + int grid = (n + block - 1) / block; + UniformShift<<stream()>>>(n, min, max, + output); +} - thrust::transform(thrust::cuda::par.on(cuda_context->stream()), - index_sequence_begin, index_sequence_begin + n, - thrust::device_ptr(output), distribution(engine)); +template +int HandleOddLengthRandGaussian(const int n, const T mean, const T std, + T* output, CUDADeviceContext* context) { + if (n % 2 == 1) { + std::default_random_engine generator; + std::normal_distribution distribution(mean, std); + const T random_value = distribution(generator); + Set(1, random_value, output + (n - 1), context); + return n - 1; + } + return n; } template <> @@ -147,15 +175,11 @@ void RandGaussian(const int n, const float mean, const float std, float* output, platform::DeviceContext* context) { auto* cuda_context = reinterpret_cast(context); - thrust::normal_distribution distribution(mean, std); - thrust::minstd_rand engine = cuda_context->rand_enigne(); - engine->discard(n); - - thrust::counting_iterator index_sequence_begin(0); - thrust::transform(thrust::cuda::par.on(cuda_context->stream()), - index_sequence_begin, index_sequence_begin + n, - thrust::device_ptr(output), distribution(engine)); + const int even_n = + HandleOddLengthRandGaussian(n, mean, std, output, cuda_context); + PADDLE_ENFORCE(curandGenerateNormal(cuda_context->curand_generator(), output, + even_n, mean, std)); } } // namespace math diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h index ea15e8fd2b..afe6de7483 100644 --- a/paddle/operators/math/math_function.h +++ b/paddle/operators/math/math_function.h @@ -54,6 +54,7 @@ int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda, #include "paddle/framework/tensor.h" #include "paddle/platform/device_context.h" +#include "paddle/platform/eigen.h" #include "paddle/platform/enforce.h" namespace paddle { @@ -77,6 +78,13 @@ void matmul(const framework::Tensor& matrix_a, bool trans_a, framework::Tensor* matrix_out, T beta, platform::DeviceContext* context); +template +void Set(const int n, const T alpha, T* output, + platform::DeviceContext* context) { + framework::EigenVector::Type out(output, n); + out.device(*(context->eigen_device())) = t.constant(T(alpha)); +} + template void RandUniform(const int n, const T min, const T max, T* output, platform::DeviceContext* context); diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc index fabbb55443..5fd93555a5 100644 --- a/paddle/platform/device_context.cc +++ b/paddle/platform/device_context.cc @@ -157,12 +157,17 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() { return cudnn_handle_; } -thrust::minstd_rand& CPUDeviceContext::rand_engine() { - if (!rand_engine_) { - rand_engine_.reset(new thrust::minstd_rand()); - rand_engine_->seed(rand_seed_); +curandGenerator_t CUDADeviceContext::curand_generator() { + if (!curand_generator_) { + SetDeviceId(place_.device); + PADDLE_ENFORCE(dynload::curandCreateGenerator(&curand_generator_, + CURAND_RNG_PSEUDO_DEFAULT)); + PADDLE_ENFORCE( + dynload::curandSetPseudoRandomGeneratorSeed(curand_generator_, seed_)); + + PADDLE_ENFORCE(dynload::curandSetStream(curand_generator_, stream_)); } - return *(rand_engine_.get()); + return curand_generator_; } cudaStream_t CUDADeviceContext::stream() { return stream_; } diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index e4de3807cd..7013343a8d 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -15,10 +15,9 @@ limitations under the License. */ #include "paddle/platform/place.h" #ifndef PADDLE_ONLY_CPU -#include -#include #include "paddle/platform/dynload/cublas.h" #include "paddle/platform/dynload/cudnn.h" +#include "paddle/platform/dynload/curand.h" #include "paddle/platform/gpu_info.h" #define EIGEN_USE_GPU #endif @@ -80,7 +79,8 @@ class CUDADeviceContext : public DeviceContext { /*! \brief Return cudnn handle in the device context. */ cudnnHandle_t cudnn_handle(); - thrust::minstd_rand& CPUDeviceContext::rand_engine(); + /*! \brief Return curand handle in the device context. */ + curandGenerator_t curand_generator(); /*! \brief Return cuda stream in the device context. */ cudaStream_t stream(); From 2f47f35b3efec36189a4c6757490b897130d3028 Mon Sep 17 00:00:00 2001 From: qijun Date: Mon, 21 Aug 2017 09:12:25 +0000 Subject: [PATCH 08/18] fix gpu build error --- paddle/operators/math/CMakeLists.txt | 4 ++-- paddle/operators/math/math_function.cc | 10 +++++----- paddle/operators/math/math_function.cu | 15 ++++++++------- paddle/operators/math/math_function.h | 7 ++----- paddle/operators/uniform_random_op.cu | 9 +++------ paddle/platform/device_context.cc | 10 +++++----- paddle/platform/device_context.h | 6 +++--- 7 files changed, 28 insertions(+), 33 deletions(-) diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index ed51d416ed..228f463f2b 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -1,8 +1,8 @@ if(WITH_GPU) - nv_library(math_function SRCS math_function.cc math_function.cu DEPS cblas device_context) + nv_library(math_function SRCS math_function.cc math_function.cu DEPS cblas device_context eigen3) else() - cc_library(math_function SRCS math_function.cc DEPS cblas device_context) + cc_library(math_function SRCS math_function.cc DEPS cblas device_context eigen3) endif() nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc index d0b1f8ee48..a098e02f95 100644 --- a/paddle/operators/math/math_function.cc +++ b/paddle/operators/math/math_function.cc @@ -110,12 +110,12 @@ void matmul(const framework::Tensor& matrix_a, } template <> -void Set(const int n, const float alpha, - float* output, - platform::DeviceContext* context) { +void Set(const int n, const float alpha, + float* output, + platform::DeviceContext* context) { auto* cpu_context = reinterpret_cast(context); - framework::EigenVector::Type out(output, n); - out.device(*(cpu_context->eigen_device())) = t.constant(T(alpha)); + framework::EigenVector::Type out(output, n); + out.device(*(cpu_context->eigen_device())) = out.constant(float(alpha)); } template <> diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index 76bbf790db..3ff622f308 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -127,12 +127,12 @@ void matmul(const framework::Tensor& matrix_a, } template <> -void Set(const int n, const float alpha, - float* output, - platform::DeviceContext* context) { +void Set(const int n, const float alpha, + float* output, + platform::DeviceContext* context) { auto* cuda_context = reinterpret_cast(context); - framework::EigenVector::Type out(output, n); - out.device(*(cuda_context->eigen_device())) = t.constant(T(alpha)); + framework::EigenVector::Type out(output, n); + out.device(*(cuda_context->eigen_device())) = out.constant(float(alpha)); } template @@ -159,12 +159,13 @@ void RandUniform(const int n, const float min, template int HandleOddLengthRandGaussian(const int n, const T mean, const T std, - T* output, CUDADeviceContext* context) { + T* output, + platform::CUDADeviceContext* context) { if (n % 2 == 1) { std::default_random_engine generator; std::normal_distribution distribution(mean, std); const T random_value = distribution(generator); - Set(1, random_value, output + (n - 1), context); + Set(1, random_value, output + (n - 1), context); return n - 1; } return n; diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h index afe6de7483..6543a1b515 100644 --- a/paddle/operators/math/math_function.h +++ b/paddle/operators/math/math_function.h @@ -52,9 +52,9 @@ int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda, #include +#include "paddle/framework/eigen.h" #include "paddle/framework/tensor.h" #include "paddle/platform/device_context.h" -#include "paddle/platform/eigen.h" #include "paddle/platform/enforce.h" namespace paddle { @@ -80,10 +80,7 @@ void matmul(const framework::Tensor& matrix_a, bool trans_a, template void Set(const int n, const T alpha, T* output, - platform::DeviceContext* context) { - framework::EigenVector::Type out(output, n); - out.device(*(context->eigen_device())) = t.constant(T(alpha)); -} + platform::DeviceContext* context); template void RandUniform(const int n, const T min, const T max, T* output, diff --git a/paddle/operators/uniform_random_op.cu b/paddle/operators/uniform_random_op.cu index 91368fa73e..1bfffc4778 100644 --- a/paddle/operators/uniform_random_op.cu +++ b/paddle/operators/uniform_random_op.cu @@ -14,9 +14,6 @@ #include "paddle/operators/uniform_random_op.h" -namespace paddle { -namespace operators { - -REGISTER_OP_GPU_KERNEL(uniform_random, - paddle::operators::GPUUniformRandomKernel< - paddle::platform::GPUPlace, float>); +REGISTER_OP_GPU_KERNEL( + uniform_random, + paddle::operators::UniformRandomKernel); diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc index 5fd93555a5..ad9b4e42f3 100644 --- a/paddle/platform/device_context.cc +++ b/paddle/platform/device_context.cc @@ -25,9 +25,9 @@ CPUDeviceContext::CPUDeviceContext() { eigen_device_.reset(new Eigen::DefaultDevice()); } -CPUDeviceContext::CPUDeviceContext(CPUPlace place, int rand_seed) { +CPUDeviceContext::CPUDeviceContext(CPUPlace place, int seed) { eigen_device_.reset(new Eigen::DefaultDevice()); - rand_seed_ = rand_seed; + rand_seed_ = seed; } std::minstd_rand& CPUDeviceContext::rand_engine() { @@ -105,7 +105,7 @@ Eigen::GpuDevice* DeviceContext::get_eigen_device() const { } CUDADeviceContext::CUDADeviceContext(GPUPlace place, uint64_t seed) - : place_(place), seed_(seed) { + : place_(place), rand_seed_(seed) { SetDeviceId(place_.device); PADDLE_ENFORCE(cudaStreamCreate(&stream_)); eigen_stream_.reset(new EigenCudaStreamDevice()); @@ -162,8 +162,8 @@ curandGenerator_t CUDADeviceContext::curand_generator() { SetDeviceId(place_.device); PADDLE_ENFORCE(dynload::curandCreateGenerator(&curand_generator_, CURAND_RNG_PSEUDO_DEFAULT)); - PADDLE_ENFORCE( - dynload::curandSetPseudoRandomGeneratorSeed(curand_generator_, seed_)); + PADDLE_ENFORCE(dynload::curandSetPseudoRandomGeneratorSeed( + curand_generator_, rand_seed_)); PADDLE_ENFORCE(dynload::curandSetStream(curand_generator_, stream_)); } diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index 7013343a8d..e18f48fef5 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -40,7 +40,7 @@ class DeviceContext { class CPUDeviceContext : public DeviceContext { public: CPUDeviceContext(); - explicit CPUDeviceContext(CPUPlace place, int rand_seed = 0); + explicit CPUDeviceContext(CPUPlace place, int seed = 0); virtual ~CPUDeviceContext() {} Eigen::DefaultDevice* eigen_device() const; @@ -60,7 +60,7 @@ class EigenCudaStreamDevice; class CUDADeviceContext : public DeviceContext { public: - explicit CUDADeviceContext(GPUPlace place, uint64_t rand_seed = 0); + explicit CUDADeviceContext(GPUPlace place, uint64_t seed = 0); virtual ~CUDADeviceContext(); /*! \brief Wait for all operations completion in the stream. */ @@ -93,12 +93,12 @@ class CUDADeviceContext : public DeviceContext { std::unique_ptr eigen_stream_; uint64_t rand_seed_; - std::unique_ptr rand_engine_; // clang-format off cudaStream_t stream_{nullptr}; cudnnHandle_t cudnn_handle_{nullptr}; cublasHandle_t cublas_handle_{nullptr}; + curandGenerator_t curand_generator_{nullptr}; // clang-format on }; From 08c987d7c086e4176a27f2685712bbb9226e635e Mon Sep 17 00:00:00 2001 From: qijun Date: Mon, 21 Aug 2017 17:23:15 +0800 Subject: [PATCH 09/18] use dynload curand --- paddle/operators/gaussian_random_op.h | 4 ++-- paddle/operators/math/math_function.cu | 8 ++++---- paddle/operators/uniform_random_op.h | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/paddle/operators/gaussian_random_op.h b/paddle/operators/gaussian_random_op.h index 041390e954..c90b665fe0 100644 --- a/paddle/operators/gaussian_random_op.h +++ b/paddle/operators/gaussian_random_op.h @@ -34,5 +34,5 @@ class GaussianRandomKernel : public framework::OpKernel { math::RandGaussian(n, mean, std, data, device_context); } }; -} -} +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index 3ff622f308..908efe9e0f 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -149,8 +149,8 @@ void RandUniform(const int n, const float min, const float max, float* output, platform::DeviceContext* context) { auto* cuda_context = reinterpret_cast(context); - PADDLE_ENFORCE( - curandGenerateUniform(cuda_context->curand_generator(), output, n)); + PADDLE_ENFORCE(platform::dynload::curandGenerateUniform( + cuda_context->curand_generator(), output, n)); int block = 512; int grid = (n + block - 1) / block; UniformShift<<stream()>>>(n, min, max, @@ -179,8 +179,8 @@ void RandGaussian(const int n, const float mean, const int even_n = HandleOddLengthRandGaussian(n, mean, std, output, cuda_context); - PADDLE_ENFORCE(curandGenerateNormal(cuda_context->curand_generator(), output, - even_n, mean, std)); + PADDLE_ENFORCE(platform::dynload::curandGenerateNormal( + cuda_context->curand_generator(), output, even_n, mean, std)); } } // namespace math diff --git a/paddle/operators/uniform_random_op.h b/paddle/operators/uniform_random_op.h index ec009b025e..dffa640f84 100644 --- a/paddle/operators/uniform_random_op.h +++ b/paddle/operators/uniform_random_op.h @@ -34,5 +34,5 @@ class UniformRandomKernel : public framework::OpKernel { math::RandUniform(n, min, max, data, device_context); } }; -} -} +} // namespace operators +} // namespace paddle From b054392e2abebb2a55dabeeb2f12e414bbc2c5af Mon Sep 17 00:00:00 2001 From: qijun Date: Mon, 21 Aug 2017 17:46:46 +0800 Subject: [PATCH 10/18] fix gaussion op bug --- paddle/operators/gaussian_random_op.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/operators/gaussian_random_op.cc b/paddle/operators/gaussian_random_op.cc index aba8c6e5cd..899f05fa47 100644 --- a/paddle/operators/gaussian_random_op.cc +++ b/paddle/operators/gaussian_random_op.cc @@ -23,7 +23,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext& context) const override { - auto* tensor = context.Output(0); + auto* tensor = context.Output("Out"); auto dims = GetAttr>("dims"); PADDLE_ENFORCE(dims.size() > 0UL, "dims can be one int or array. dims must be set."); From 117ce4cbc1a16da1ba8489aaab754aa0ebe5d3ab Mon Sep 17 00:00:00 2001 From: guosheng Date: Mon, 21 Aug 2017 19:23:42 +0800 Subject: [PATCH 11/18] Change class to struct in GemmFunctor to avoid errors on special compilers --- paddle/function/GemmFunctor.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/function/GemmFunctor.cpp b/paddle/function/GemmFunctor.cpp index dc83278d8e..9e25ee58a1 100644 --- a/paddle/function/GemmFunctor.cpp +++ b/paddle/function/GemmFunctor.cpp @@ -84,7 +84,7 @@ struct BlasGemm { } }; -template class BlasGemm; -template class BlasGemm; +template struct BlasGemm; +template struct BlasGemm; } // namespace paddle From 950dbde56c989f79bace3d53ae38bfae26e84c53 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Mon, 21 Aug 2017 08:41:35 -0700 Subject: [PATCH 12/18] fix rowwise add grad op --- paddle/operators/rowwise_add_op.h | 2 +- python/paddle/v2/framework/tests/test_rowwise_add_op.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/operators/rowwise_add_op.h b/paddle/operators/rowwise_add_op.h index 232135c38d..771c5d7c0a 100644 --- a/paddle/operators/rowwise_add_op.h +++ b/paddle/operators/rowwise_add_op.h @@ -63,7 +63,7 @@ class RowwiseAddGradKernel : public framework::OpKernel { // https://eigen.tuxfamily.org/dox/unsupported/TensorBase_8h_source.html // colwise add - Eigen::array dims{{1}}; /* dimension to reduce */ + Eigen::array dims{{0}}; /* dimension to reduce */ EigenVector::Flatten(*db).device(place) = OutGrad.sum(dims); } }; diff --git a/python/paddle/v2/framework/tests/test_rowwise_add_op.py b/python/paddle/v2/framework/tests/test_rowwise_add_op.py index 29d72e8500..45d569da29 100644 --- a/python/paddle/v2/framework/tests/test_rowwise_add_op.py +++ b/python/paddle/v2/framework/tests/test_rowwise_add_op.py @@ -20,7 +20,7 @@ class RowwiseAddGradOpTest(GradientChecker): def test_rowwise_add(self): op = create_op("rowwise_add") inputs = { - "X": np.random.uniform(0.1, 1, [10, 10]).astype("float32"), + "X": np.random.uniform(0.1, 1, [5, 10]).astype("float32"), "b": np.random.uniform(0.1, 1, [10]).astype("float32") } self.check_grad(op, inputs, set(["X", "b"]), "Out") From a75a638fb16ac5b08509c3f185d25ec670d3cb12 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Mon, 21 Aug 2017 09:13:19 -0700 Subject: [PATCH 13/18] format Copyright --- paddle/operators/rowwise_add_op.h | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/paddle/operators/rowwise_add_op.h b/paddle/operators/rowwise_add_op.h index 771c5d7c0a..1cbd8bb31a 100644 --- a/paddle/operators/rowwise_add_op.h +++ b/paddle/operators/rowwise_add_op.h @@ -1,16 +1,16 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ #pragma once #include "paddle/framework/eigen.h" From 93539093f4727d4028ca7e592f5fa4f7abdb8bc3 Mon Sep 17 00:00:00 2001 From: xuwei06 Date: Wed, 2 Aug 2017 11:28:25 -0700 Subject: [PATCH 14/18] Allow boot_bias for recurrent group to be static --- paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp b/paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp index f98bf95064..157b1ab451 100644 --- a/paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp +++ b/paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp @@ -184,7 +184,7 @@ public: } void backward(const UpdateCallback& callback) override { - if (biases_) { + if (biases_ && biases_->getWGrad()) { backwardActivation(); biases_->getWGrad()->collectBias(*getOutputGrad(), 1); biases_->getParameterPtr()->incUpdate(callback); From 36e8e725669a20b272f9ace1cf7c9df646c840a3 Mon Sep 17 00:00:00 2001 From: qijun Date: Tue, 22 Aug 2017 11:40:57 +0800 Subject: [PATCH 15/18] expose random seed to users --- paddle/operators/CMakeLists.txt | 4 +- paddle/operators/gaussian_random_op.cc | 42 ++++++++++--- paddle/operators/gaussian_random_op.cu | 61 +++++++++++++++--- paddle/operators/gaussian_random_op.h | 38 ----------- paddle/operators/math/math_function.cc | 22 ------- paddle/operators/math/math_function.cu | 48 -------------- paddle/operators/math/math_function.h | 8 --- paddle/operators/uniform_random_op.cc | 44 ++++++++++--- paddle/operators/uniform_random_op.cu | 63 ++++++++++++++++--- paddle/operators/uniform_random_op.h | 38 ----------- paddle/platform/device_context.cc | 27 +------- paddle/platform/device_context.h | 15 +---- .../tests/test_gaussian_random_op.py | 7 ++- .../framework/tests/test_uniform_random_op.py | 7 ++- 14 files changed, 196 insertions(+), 228 deletions(-) delete mode 100644 paddle/operators/gaussian_random_op.h delete mode 100644 paddle/operators/uniform_random_op.h diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 8f22a5fbc3..a7c89787e4 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -58,7 +58,7 @@ op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc) op_library(sigmoid_op SRCS sigmoid_op.cc sigmoid_op.cu) op_library(softmax_op SRCS softmax_op.cc softmax_op.cu) -op_library(gaussian_random_op SRCS gaussian_random_op.cc gaussian_random_op.cu DEPS math_function) +op_library(gaussian_random_op SRCS gaussian_random_op.cc gaussian_random_op.cu) op_library(cross_entropy_op SRCS cross_entropy_op.cc cross_entropy_op.cu) op_library(fill_zeros_like_op SRCS fill_zeros_like_op.cc fill_zeros_like_op.cu) @@ -67,4 +67,4 @@ op_library(sgd_op SRCS sgd_op.cc sgd_op.cu) op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc DEPS framework_proto tensor op_registry operator net_op) op_library(uniform_random_op - SRCS uniform_random_op.cc uniform_random_op.cu DEPS math_function) + SRCS uniform_random_op.cc uniform_random_op.cu) diff --git a/paddle/operators/gaussian_random_op.cc b/paddle/operators/gaussian_random_op.cc index 899f05fa47..dcd2237459 100644 --- a/paddle/operators/gaussian_random_op.cc +++ b/paddle/operators/gaussian_random_op.cc @@ -1,22 +1,44 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/operators/gaussian_random_op.h" +#include +#include "paddle/framework/op_registry.h" namespace paddle { namespace operators { +template +class CPUGaussianRandomKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + float mean = context.op_.GetAttr("mean"); + float std = context.op_.GetAttr("std"); + auto* tensor = context.Output("Out"); + T* data = tensor->mutable_data(context.GetPlace()); + + unsigned int seed = + static_cast(context.op_.GetAttr("seed")); + std::minstd_rand engine; + if (seed == 0) { + seed = std::random_device()(); + } + engine.seed(seed); + std::normal_distribution dist(mean, std); + ssize_t size = framework::product(tensor->dims()); + for (ssize_t i = 0; i < size; ++i) { + data[i] = dist(engine); + } + } +}; + class GaussianRandomOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -43,8 +65,12 @@ Use to initialize tensor with gaussian random generator. )DOC"); AddAttr>("dims", "The dimension of random tensor."); - AddAttr("mean", "mean value of random.").SetDefault(.0f); - AddAttr("std", "minimum value of random value.").SetDefault(1.0f); + AddAttr("mean", "mean of random tensor.").SetDefault(.0f); + AddAttr("std", "std of random tensor.").SetDefault(1.0f); + AddAttr("seed", + "Random seed of generator." + "0 means use system wide seed") + .SetDefault(0); } }; @@ -54,6 +80,4 @@ Use to initialize tensor with gaussian random generator. namespace ops = paddle::operators; REGISTER_OP_WITHOUT_GRADIENT(gaussian_random, ops::GaussianRandomOp, ops::GaussianRandomOpMaker); -REGISTER_OP_CPU_KERNEL( - gaussian_random, - ops::GaussianRandomKernel); +REGISTER_OP_CPU_KERNEL(gaussian_random, ops::CPUGaussianRandomKernel); \ No newline at end of file diff --git a/paddle/operators/gaussian_random_op.cu b/paddle/operators/gaussian_random_op.cu index 31be16fdc8..1d312e7b5d 100644 --- a/paddle/operators/gaussian_random_op.cu +++ b/paddle/operators/gaussian_random_op.cu @@ -1,20 +1,65 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/operators/gaussian_random_op.h" +#include +#include +#include +#include +#include "paddle/framework/op_registry.h" +#include "paddle/framework/operator.h" + +namespace paddle { +namespace operators { + +template +struct GaussianGenerator { + T mean_, std_; + unsigned int seed_; + + __host__ __device__ GaussianGenerator(T mean, T std, int seed) + : mean_(mean), std_(std), seed_(seed) {} + + __host__ __device__ T operator()(const unsigned int n) const { + thrust::minstd_rand rng; + rng.seed(seed_); + thrust::normal_distribution dist(min_, max_); + rng.discard(n); + return dist(rng); + } +}; + +template +class GPUGaussianRandomKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* tensor = context.Output("Out"); + T* data = tensor->mutable_data(context.GetPlace()); + unsigned int seed = + static_cast(context.op_.GetAttr("seed")); + if (seed == 0) { + std::random_device rd; + seed = rd(); + } + T mean = static_cast(context.op_.GetAttr("mean")); + T std = static_cast(context.op_.GetAttr("std")); + thrust::counting_iterator index_sequence_begin(0); + ssize_t N = framework::product(tensor->dims()); + thrust::transform(index_sequence_begin, index_sequence_begin + N, + thrust::device_ptr(data), + GaussianGenerator(mean, std, seed)); + } +}; + +} // namespace operators +} // namespace paddle -namespace ops = paddle::operators; -REGISTER_OP_GPU_KERNEL( - gaussian_random, - ops::GaussianRandomKernel); +REGISTER_OP_GPU_KERNEL(gaussian_random, + paddle::operators::GPUGaussianRandomKernel); \ No newline at end of file diff --git a/paddle/operators/gaussian_random_op.h b/paddle/operators/gaussian_random_op.h deleted file mode 100644 index c90b665fe0..0000000000 --- a/paddle/operators/gaussian_random_op.h +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/framework/op_registry.h" -#include "paddle/operators/math/math_function.h" - -namespace paddle { -namespace operators { -template -class GaussianRandomKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* tensor = context.Output("Out"); - T* data = tensor->mutable_data(context.GetPlace()); - T mean = static_cast(context.op_.GetAttr("mean")); - T std = static_cast(context.op_.GetAttr("std")); - auto n = framework::product(tensor->dims()); - - auto* device_context = - const_cast(context.device_context_); - math::RandGaussian(n, mean, std, data, device_context); - } -}; -} // namespace operators -} // namespace paddle diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc index a098e02f95..d9824e5f96 100644 --- a/paddle/operators/math/math_function.cc +++ b/paddle/operators/math/math_function.cc @@ -118,28 +118,6 @@ void Set(const int n, const float alpha, out.device(*(cpu_context->eigen_device())) = out.constant(float(alpha)); } -template <> -void RandUniform(const int n, const float min, - const float max, float* output, - platform::DeviceContext* context) { - auto* cpu_context = reinterpret_cast(context); - std::uniform_real_distribution distribution(min, max); - for (int i = 0; i < n; i++) { - output[i] = distribution(cpu_context->rand_engine()); - } -} - -template <> -void RandGaussian(const int n, const float mean, - const float std, float* output, - platform::DeviceContext* context) { - auto* cpu_context = reinterpret_cast(context); - std::normal_distribution distribution(mean, std); - for (int i = 0; i < n; i++) { - output[i] = distribution(cpu_context->rand_engine()); - } -} - } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index 908efe9e0f..9dff6f05fb 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -135,54 +135,6 @@ void Set(const int n, const float alpha, out.device(*(cuda_context->eigen_device())) = out.constant(float(alpha)); } -template -__global__ void UniformShift(const int n, const T min, const T max, T* x) { - float scale = max - min; - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; - i += blockDim.x * gridDim.x) { - x[i] = x[i] * scale + min; - } -} - -template <> -void RandUniform(const int n, const float min, - const float max, float* output, - platform::DeviceContext* context) { - auto* cuda_context = reinterpret_cast(context); - PADDLE_ENFORCE(platform::dynload::curandGenerateUniform( - cuda_context->curand_generator(), output, n)); - int block = 512; - int grid = (n + block - 1) / block; - UniformShift<<stream()>>>(n, min, max, - output); -} - -template -int HandleOddLengthRandGaussian(const int n, const T mean, const T std, - T* output, - platform::CUDADeviceContext* context) { - if (n % 2 == 1) { - std::default_random_engine generator; - std::normal_distribution distribution(mean, std); - const T random_value = distribution(generator); - Set(1, random_value, output + (n - 1), context); - return n - 1; - } - return n; -} - -template <> -void RandGaussian(const int n, const float mean, - const float std, float* output, - platform::DeviceContext* context) { - auto* cuda_context = reinterpret_cast(context); - - const int even_n = - HandleOddLengthRandGaussian(n, mean, std, output, cuda_context); - PADDLE_ENFORCE(platform::dynload::curandGenerateNormal( - cuda_context->curand_generator(), output, even_n, mean, std)); -} - } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h index 6543a1b515..a0e9660564 100644 --- a/paddle/operators/math/math_function.h +++ b/paddle/operators/math/math_function.h @@ -82,14 +82,6 @@ template void Set(const int n, const T alpha, T* output, platform::DeviceContext* context); -template -void RandUniform(const int n, const T min, const T max, T* output, - platform::DeviceContext* context); - -template -void RandGaussian(const int n, const T mean, const T std, T* output, - platform::DeviceContext* context); - } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/uniform_random_op.cc b/paddle/operators/uniform_random_op.cc index 81487a6bd8..876b3ef557 100644 --- a/paddle/operators/uniform_random_op.cc +++ b/paddle/operators/uniform_random_op.cc @@ -1,22 +1,48 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/operators/uniform_random_op.h" +#include +#include +#include "paddle/framework/op_registry.h" +#include "paddle/framework/operator.h" namespace paddle { namespace operators { +// It seems that Eigen::Tensor::random in GPU will SEGFAULT. +// Use std::random and thrust::random(thrust is a std library in CUDA) to +// implement uniform random. +template +class CPUUniformRandomKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* tensor = context.Output("Out"); + T* data = tensor->mutable_data(context.GetPlace()); + unsigned int seed = + static_cast(context.op_.GetAttr("seed")); + std::minstd_rand engine; + if (seed == 0) { + seed = std::random_device()(); + } + engine.seed(seed); + std::uniform_real_distribution dist( + static_cast(context.op_.GetAttr("min")), + static_cast(context.op_.GetAttr("max"))); + ssize_t size = framework::product(tensor->dims()); + for (ssize_t i = 0; i < size; ++i) { + data[i] = dist(engine); + } + } +}; + class UniformRandomOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -38,12 +64,15 @@ class UniformRandomOpMaker : public framework::OpProtoAndCheckerMaker { : framework::OpProtoAndCheckerMaker(proto, op_checker) { AddOutput("Out", "The output tensor of uniform random op"); AddComment(R"DOC(Uniform random operator. - Used to initialize tensor with uniform random generator. )DOC"); AddAttr>("dims", "the dimension of random tensor"); AddAttr("min", "Minimum value of uniform random").SetDefault(-1.0f); AddAttr("max", "Maximun value of uniform random").SetDefault(1.0f); + AddAttr("seed", + "Random seed of uniform random. " + "0 means generate a seed by system") + .SetDefault(0); } }; } // namespace operators @@ -51,6 +80,5 @@ Used to initialize tensor with uniform random generator. REGISTER_OP_WITHOUT_GRADIENT(uniform_random, paddle::operators::UniformRandomOp, paddle::operators::UniformRandomOpMaker); -REGISTER_OP_CPU_KERNEL( - uniform_random, - paddle::operators::UniformRandomKernel); +REGISTER_OP_CPU_KERNEL(uniform_random, + paddle::operators::CPUUniformRandomKernel); \ No newline at end of file diff --git a/paddle/operators/uniform_random_op.cu b/paddle/operators/uniform_random_op.cu index 1bfffc4778..6716b7c7f2 100644 --- a/paddle/operators/uniform_random_op.cu +++ b/paddle/operators/uniform_random_op.cu @@ -1,19 +1,68 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/operators/uniform_random_op.h" +#include +#include +#include +#include +#include "paddle/framework/op_registry.h" +#include "paddle/framework/operator.h" + +namespace paddle { +namespace operators { + +template +struct UniformGenerator { + T min_, max_; + unsigned int seed_; + + __host__ __device__ UniformGenerator(T min, T max, int seed) + : min_(min), max_(max), seed_(seed) {} + + __host__ __device__ T operator()(const unsigned int n) const { + thrust::minstd_rand rng; + rng.seed(seed_); + thrust::uniform_real_distribution dist(min_, max_); + rng.discard(n); + return dist(rng); + } +}; + +// It seems that Eigen::Tensor::random in GPU will SEGFAULT. +// Use std::random and thrust::random(thrust is a std library in CUDA) to +// implement uniform random. +template +class GPUUniformRandomKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* tensor = context.Output("Out"); + T* data = tensor->mutable_data(context.GetPlace()); + unsigned int seed = + static_cast(context.op_.GetAttr("seed")); + if (seed == 0) { + std::random_device rd; + seed = rd(); + } + T min = static_cast(context.op_.GetAttr("min")); + T max = static_cast(context.op_.GetAttr("max")); + thrust::counting_iterator index_sequence_begin(0); + ssize_t N = framework::product(tensor->dims()); + thrust::transform(index_sequence_begin, index_sequence_begin + N, + thrust::device_ptr(data), + UniformGenerator(min, max, seed)); + } +}; + +} // namespace operators +} // namespace paddle -REGISTER_OP_GPU_KERNEL( - uniform_random, - paddle::operators::UniformRandomKernel); +REGISTER_OP_GPU_KERNEL(uniform_random, + paddle::operators::GPUUniformRandomKernel); \ No newline at end of file diff --git a/paddle/operators/uniform_random_op.h b/paddle/operators/uniform_random_op.h deleted file mode 100644 index dffa640f84..0000000000 --- a/paddle/operators/uniform_random_op.h +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/framework/op_registry.h" -#include "paddle/operators/math/math_function.h" - -namespace paddle { -namespace operators { -template -class UniformRandomKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* tensor = context.Output("Out"); - T* data = tensor->mutable_data(context.GetPlace()); - T min = static_cast(context.op_.GetAttr("min")); - T max = static_cast(context.op_.GetAttr("max")); - auto n = framework::product(tensor->dims()); - - auto* device_context = - const_cast(context.device_context_); - math::RandUniform(n, min, max, data, device_context); - } -}; -} // namespace operators -} // namespace paddle diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc index ad9b4e42f3..ad212c5b2c 100644 --- a/paddle/platform/device_context.cc +++ b/paddle/platform/device_context.cc @@ -25,17 +25,8 @@ CPUDeviceContext::CPUDeviceContext() { eigen_device_.reset(new Eigen::DefaultDevice()); } -CPUDeviceContext::CPUDeviceContext(CPUPlace place, int seed) { +CPUDeviceContext::CPUDeviceContext(CPUPlace place) { eigen_device_.reset(new Eigen::DefaultDevice()); - rand_seed_ = seed; -} - -std::minstd_rand& CPUDeviceContext::rand_engine() { - if (!rand_engine_) { - rand_engine_.reset(new std::minstd_rand()); - rand_engine_->seed(rand_seed_); - } - return *(rand_engine_.get()); } Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const { @@ -104,8 +95,7 @@ Eigen::GpuDevice* DeviceContext::get_eigen_device() const { return reinterpret_cast(this)->eigen_device(); } -CUDADeviceContext::CUDADeviceContext(GPUPlace place, uint64_t seed) - : place_(place), rand_seed_(seed) { +CUDADeviceContext::CUDADeviceContext(GPUPlace place) : place_(place) { SetDeviceId(place_.device); PADDLE_ENFORCE(cudaStreamCreate(&stream_)); eigen_stream_.reset(new EigenCudaStreamDevice()); @@ -157,19 +147,6 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() { return cudnn_handle_; } -curandGenerator_t CUDADeviceContext::curand_generator() { - if (!curand_generator_) { - SetDeviceId(place_.device); - PADDLE_ENFORCE(dynload::curandCreateGenerator(&curand_generator_, - CURAND_RNG_PSEUDO_DEFAULT)); - PADDLE_ENFORCE(dynload::curandSetPseudoRandomGeneratorSeed( - curand_generator_, rand_seed_)); - - PADDLE_ENFORCE(dynload::curandSetStream(curand_generator_, stream_)); - } - return curand_generator_; -} - cudaStream_t CUDADeviceContext::stream() { return stream_; } #endif // PADDLE_ONLY_CPU diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index e18f48fef5..11528e1194 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -17,7 +17,6 @@ limitations under the License. */ #ifndef PADDLE_ONLY_CPU #include "paddle/platform/dynload/cublas.h" #include "paddle/platform/dynload/cudnn.h" -#include "paddle/platform/dynload/curand.h" #include "paddle/platform/gpu_info.h" #define EIGEN_USE_GPU #endif @@ -40,18 +39,14 @@ class DeviceContext { class CPUDeviceContext : public DeviceContext { public: CPUDeviceContext(); - explicit CPUDeviceContext(CPUPlace place, int seed = 0); + explicit CPUDeviceContext(CPUPlace place); virtual ~CPUDeviceContext() {} Eigen::DefaultDevice* eigen_device() const; - std::minstd_rand& rand_engine(); - Place GetPlace() const override; private: - int rand_seed_; - std::unique_ptr rand_engine_; std::unique_ptr eigen_device_; }; @@ -60,7 +55,7 @@ class EigenCudaStreamDevice; class CUDADeviceContext : public DeviceContext { public: - explicit CUDADeviceContext(GPUPlace place, uint64_t seed = 0); + explicit CUDADeviceContext(GPUPlace place); virtual ~CUDADeviceContext(); /*! \brief Wait for all operations completion in the stream. */ @@ -79,9 +74,6 @@ class CUDADeviceContext : public DeviceContext { /*! \brief Return cudnn handle in the device context. */ cudnnHandle_t cudnn_handle(); - /*! \brief Return curand handle in the device context. */ - curandGenerator_t curand_generator(); - /*! \brief Return cuda stream in the device context. */ cudaStream_t stream(); // clang-format on @@ -92,13 +84,10 @@ class CUDADeviceContext : public DeviceContext { std::unique_ptr eigen_device_; std::unique_ptr eigen_stream_; - uint64_t rand_seed_; - // clang-format off cudaStream_t stream_{nullptr}; cudnnHandle_t cudnn_handle_{nullptr}; cublasHandle_t cublas_handle_{nullptr}; - curandGenerator_t curand_generator_{nullptr}; // clang-format on }; diff --git a/python/paddle/v2/framework/tests/test_gaussian_random_op.py b/python/paddle/v2/framework/tests/test_gaussian_random_op.py index 367d21b301..f95ed70b58 100644 --- a/python/paddle/v2/framework/tests/test_gaussian_random_op.py +++ b/python/paddle/v2/framework/tests/test_gaussian_random_op.py @@ -17,7 +17,12 @@ class GaussianRandomTest(unittest.TestCase): scope.new_var("Out").get_tensor() op = Operator( - "gaussian_random", Out="Out", dims=[1000, 784], mean=.0, std=1.) + "gaussian_random", + Out="Out", + dims=[1000, 784], + mean=.0, + std=1., + seed=10) op.infer_shape(scope) context = core.DeviceContext.create(place) diff --git a/python/paddle/v2/framework/tests/test_uniform_random_op.py b/python/paddle/v2/framework/tests/test_uniform_random_op.py index 95c36a27cf..c3d2bb44da 100644 --- a/python/paddle/v2/framework/tests/test_uniform_random_op.py +++ b/python/paddle/v2/framework/tests/test_uniform_random_op.py @@ -17,7 +17,12 @@ class UniformRandomTest(unittest.TestCase): scope.new_var("X").get_tensor() op = Operator( - "uniform_random", Out="X", dims=[1000, 784], min=-5.0, max=10.0) + "uniform_random", + Out="X", + dims=[1000, 784], + min=-5.0, + max=10.0, + seed=10) op.infer_shape(scope) ctx = core.DeviceContext.create(place) From 1918ad875980a7b5fb54c207e56d86b4376e2505 Mon Sep 17 00:00:00 2001 From: qijun Date: Tue, 22 Aug 2017 12:16:30 +0800 Subject: [PATCH 16/18] fix gpu build error --- paddle/operators/math/CMakeLists.txt | 4 ++-- paddle/operators/math/math_function.cc | 9 --------- paddle/operators/math/math_function.cu | 13 ------------- paddle/operators/math/math_function.h | 5 ----- paddle/platform/device_context_test.cc | 2 -- 5 files changed, 2 insertions(+), 31 deletions(-) diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index 228f463f2b..ed51d416ed 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -1,8 +1,8 @@ if(WITH_GPU) - nv_library(math_function SRCS math_function.cc math_function.cu DEPS cblas device_context eigen3) + nv_library(math_function SRCS math_function.cc math_function.cu DEPS cblas device_context) else() - cc_library(math_function SRCS math_function.cc DEPS cblas device_context eigen3) + cc_library(math_function SRCS math_function.cc DEPS cblas device_context) endif() nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc index d9824e5f96..1e86fc3d16 100644 --- a/paddle/operators/math/math_function.cc +++ b/paddle/operators/math/math_function.cc @@ -109,15 +109,6 @@ void matmul(const framework::Tensor& matrix_a, matrix_b.data(), beta, matrix_out->data(), context); } -template <> -void Set(const int n, const float alpha, - float* output, - platform::DeviceContext* context) { - auto* cpu_context = reinterpret_cast(context); - framework::EigenVector::Type out(output, n); - out.device(*(cpu_context->eigen_device())) = out.constant(float(alpha)); -} - } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index 9dff6f05fb..da40b27c94 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -12,10 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include -#include -#include -#include #include "paddle/operators/math/math_function.h" namespace paddle { @@ -126,15 +122,6 @@ void matmul(const framework::Tensor& matrix_a, matrix_b.data(), beta, matrix_out->data(), context); } -template <> -void Set(const int n, const float alpha, - float* output, - platform::DeviceContext* context) { - auto* cuda_context = reinterpret_cast(context); - framework::EigenVector::Type out(output, n); - out.device(*(cuda_context->eigen_device())) = out.constant(float(alpha)); -} - } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h index a0e9660564..155589fadb 100644 --- a/paddle/operators/math/math_function.h +++ b/paddle/operators/math/math_function.h @@ -52,7 +52,6 @@ int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda, #include -#include "paddle/framework/eigen.h" #include "paddle/framework/tensor.h" #include "paddle/platform/device_context.h" #include "paddle/platform/enforce.h" @@ -78,10 +77,6 @@ void matmul(const framework::Tensor& matrix_a, bool trans_a, framework::Tensor* matrix_out, T beta, platform::DeviceContext* context); -template -void Set(const int n, const T alpha, T* output, - platform::DeviceContext* context); - } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/platform/device_context_test.cc b/paddle/platform/device_context_test.cc index 8b764bdcd9..5883a55272 100644 --- a/paddle/platform/device_context_test.cc +++ b/paddle/platform/device_context_test.cc @@ -43,8 +43,6 @@ TEST(Device, CUDADeviceContext) { ASSERT_NE(nullptr, cudnn_handle); cublasHandle_t cublas_handle = device_context->cublas_handle(); ASSERT_NE(nullptr, cublas_handle); - curandGenerator_t curand_handle = device_context->curand_generator(); - ASSERT_NE(nullptr, curand_handle); ASSERT_NE(nullptr, device_context->stream()); delete device_context; } From aff90d8ee78be398b2984d63f2eb985f15f430d1 Mon Sep 17 00:00:00 2001 From: qijun Date: Tue, 22 Aug 2017 04:34:35 +0000 Subject: [PATCH 17/18] fix gpu build error --- paddle/operators/gaussian_random_op.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/operators/gaussian_random_op.cu b/paddle/operators/gaussian_random_op.cu index 1d312e7b5d..018a4bfcb2 100644 --- a/paddle/operators/gaussian_random_op.cu +++ b/paddle/operators/gaussian_random_op.cu @@ -30,7 +30,7 @@ struct GaussianGenerator { __host__ __device__ T operator()(const unsigned int n) const { thrust::minstd_rand rng; rng.seed(seed_); - thrust::normal_distribution dist(min_, max_); + thrust::normal_distribution dist(mean_, std_); rng.discard(n); return dist(rng); } @@ -62,4 +62,4 @@ class GPUGaussianRandomKernel : public framework::OpKernel { } // namespace paddle REGISTER_OP_GPU_KERNEL(gaussian_random, - paddle::operators::GPUGaussianRandomKernel); \ No newline at end of file + paddle::operators::GPUGaussianRandomKernel); From 6eab5638f03f49ab1ff3d3a4fc30d870f42a6153 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=AD=A6=E6=AF=85?= Date: Tue, 22 Aug 2017 13:28:51 +0800 Subject: [PATCH 18/18] Fix remote large update core (#3518) * fix remote large update core * wip * working version * fix style check * fix style check * update style check --- .../gserver/gradientmachines/NeuralNetwork.cpp | 2 +- paddle/parameter/Parameter.h | 5 ++++- paddle/pserver/ParameterClient2.cpp | 16 ++++++++++++++-- paddle/pserver/ParameterClient2.h | 1 + 4 files changed, 20 insertions(+), 4 deletions(-) diff --git a/paddle/gserver/gradientmachines/NeuralNetwork.cpp b/paddle/gserver/gradientmachines/NeuralNetwork.cpp index cfa80a8936..26cff3e677 100644 --- a/paddle/gserver/gradientmachines/NeuralNetwork.cpp +++ b/paddle/gserver/gradientmachines/NeuralNetwork.cpp @@ -202,7 +202,7 @@ void NeuralNetwork::prefetch(const std::vector& inArgs) { auto mat = dynamic_cast( para->getMat(PARAMETER_VALUE).get()); para->clearGradient(); - mat->clearIndices(); + if (mat) mat->clearIndices(); } } } diff --git a/paddle/parameter/Parameter.h b/paddle/parameter/Parameter.h index e31cbc3dee..321f4275d8 100644 --- a/paddle/parameter/Parameter.h +++ b/paddle/parameter/Parameter.h @@ -65,7 +65,10 @@ public: size_t getSize() const { return config_.size(); } bool isFullSize() const { - return this->getSize() == bufs_[PARAMETER_VALUE]->getSize(); + if (bufs_[PARAMETER_VALUE]) { + return this->getSize() == bufs_[PARAMETER_VALUE]->getSize(); + } + return false; } inline bool useGpu() const { return useGpu_; } diff --git a/paddle/pserver/ParameterClient2.cpp b/paddle/pserver/ParameterClient2.cpp index f7e391f763..54063a809a 100644 --- a/paddle/pserver/ParameterClient2.cpp +++ b/paddle/pserver/ParameterClient2.cpp @@ -65,7 +65,6 @@ void ParameterClient2::initThreads() { LOG(INFO) << "parallel_thread_num dosent need to set"; } syncThreadPool_.reset(new SyncThreadPool(threadNum_)); - startThreads(); } @@ -224,6 +223,14 @@ void ParameterClient2::prepareSendData( request.set_cost(cost); request.set_batch_status(batchStatus); CHECK_EQ(request.blocks_size(), 0); + VLOG(10) << "request: trainer_id: " << request.trainer_id() + << " update_mode" << request.update_mode() + << " send_back_parameter: " << request.send_back_parameter() + << " send_back_parameter_type: " + << request.send_back_parameter_type() + << " num_samples: " << request.num_samples() + << " cost: " << request.cost() + << " batch_status: " << request.batch_status(); } for (const auto& segments : parameterSegments) { const auto it = parameterMap_.find(segments.id); @@ -251,11 +258,17 @@ void ParameterClient2::prepareSendData( CHECK(sendMat != nullptr) << "sendMat is nullptr"; syncThreadPool_->exec([&](int tid, size_t numThreads) { + std::lock_guard guard(sparseAutoGrowthMutex_); const auto& localIndices = prefetchMat->getLocalIndices(); /// num of sparse rows size_t nLocalBlocks = localIndices.size(); uint64_t beginDim = 0; uint64_t endDim = 0; + + // FIXME(typhoonzero): let it resize first + prefetchMat->getLocalRow(nLocalBlocks + 1); + sendMat->getLocalRow(nLocalBlocks + 1); + for (size_t row = 0; row < nLocalBlocks; ++row) { int64_t blockId = localIndices[row]; // local row -> sparse row int serverId = std::abs((blockId + nameHash) % serviceNum_); @@ -275,7 +288,6 @@ void ParameterClient2::prepareSendData( block->set_begin_pos(row * blockSize); /// block len block->set_block_size(endDim - beginDim); - if (sendingPara) { sendJob->parallelInputIovs[serverId].push_back( {sendMat->getLocalRow(row), sizeof(real) * (size_t)blockSize}); diff --git a/paddle/pserver/ParameterClient2.h b/paddle/pserver/ParameterClient2.h index 89b3ddd502..29b9eeacdd 100644 --- a/paddle/pserver/ParameterClient2.h +++ b/paddle/pserver/ParameterClient2.h @@ -583,6 +583,7 @@ protected: #ifndef PADDLE_DISABLE_TIMER uint64_t forwardbackwordTime_; #endif + std::mutex sparseAutoGrowthMutex_; /// map id to parameter used for decoding protobuf data std::unordered_map parameterMap_;