From 15306ffdc39f552a27a3d3d0588ee35701d38f74 Mon Sep 17 00:00:00 2001 From: zhouhanqing <1051910017@qq.com> Date: Wed, 7 Mar 2018 19:11:22 +0800 Subject: [PATCH 01/15] add product reduction for reduce_op --- paddle/fluid/operators/reduce_op.cc | 12 +++++ paddle/fluid/operators/reduce_op.h | 19 +++++++- python/paddle/fluid/layers/nn.py | 47 +++++++++++++++++++ .../fluid/tests/unittests/test_reduce_op.py | 13 +++++ 4 files changed, 90 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/reduce_op.cc b/paddle/fluid/operators/reduce_op.cc index 69e8f8081e..4266636b2a 100644 --- a/paddle/fluid/operators/reduce_op.cc +++ b/paddle/fluid/operators/reduce_op.cc @@ -173,6 +173,15 @@ class ReduceMinOpMaker : public ReduceOpMaker { } }; +class ReduceProdOpMaker : public ReduceOpMaker { + public: + ReduceProdOpMaker(OpProto *proto, OpAttrChecker *op_checker) + : ReduceOpMaker(proto, op_checker) { + SetComment("ReduceProd", "prod"); + AddComment(comment_); + } +}; + } // namespace operators } // namespace paddle @@ -190,6 +199,9 @@ REGISTER_OP(reduce_max, ops::ReduceOp, ops::ReduceMaxOpMaker, reduce_max_grad, REGISTER_OP(reduce_min, ops::ReduceOp, ops::ReduceMinOpMaker, reduce_min_grad, ops::ReduceGradOp); +REGISTER_OP(reduce_prod, ops::ReduceOp, ops::ReduceProdOpMaker, + reduce_prod_grad, ops::ReduceGradOp); + #define REGISTER_REDUCE_CPU_KERNEL(reduce_type, functor, grad_functor) \ REGISTER_OP_CPU_KERNEL(reduce_type, \ ops::ReduceKernel + void operator()(const DeviceContext& place, X& x, Y& y, const Dim& dim) { + y.device(place) = x.prod(dim); + } +}; + +struct ProdGradFunctor { + template + void operator()(const DeviceContext& place, X& x, Y& y, DX& dx, DY& dy, + const Dim& dim, int size) { + dx.device(place) = dy.broadcast(dim) * y.broadcast(dim) * x.inverse(); + } +}; + template class ReduceKernel : public framework::OpKernel { public: @@ -254,4 +270,5 @@ class ReduceGradKernel : public framework::OpKernel { __macro(reduce_sum, SumFunctor, SumGradFunctor); \ __macro(reduce_mean, MeanFunctor, MeanGradFunctor); \ __macro(reduce_max, MaxFunctor, MaxOrMinGradFunctor); \ - __macro(reduce_min, MinFunctor, MaxOrMinGradFunctor); + __macro(reduce_min, MinFunctor, MaxOrMinGradFunctor); \ + __macro(reduce_prod, ProdFunctor, ProdGradFunctor); diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index b4fa530aa6..0d9c0df854 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -49,6 +49,7 @@ __all__ = [ 'reduce_mean', 'reduce_max', 'reduce_min', + 'reduce_prod', 'sequence_first_step', 'sequence_last_step', 'dropout', @@ -2200,6 +2201,52 @@ def reduce_min(input, dim=None, keep_dim=False, name=None): return out +def reduce_prod(input, dim=None, keep_dim=False, name=None): + """ + Computes the product of tensor elements over the given dimension. + + Args: + input (Variable): The input variable which is a Tensor or LoDTensor. + dim (int|None): The dimension along which the product is performed. If + :attr:`None`, multipy all elements of :attr:`input` and return a + Tensor variable with a single element, otherwise must be in the + range :math:`[-rank(input), rank(input))`. If :math:`dim < 0`, + the dimension to reduce is :math:`rank + dim`. + keep_dim (bool|False): Whether to reserve the reduced dimension in the + output Tensor. The result tensor will have one fewer dimension + than the :attr:`input` unless :attr:`keep_dim` is true. + name(str|None): A name for this layer(optional). If set None, the layer + will be named automatically. + + Returns: + Variable: The reduced Tensor variable. + + Examples: + .. code-block:: python + + # x is a Tensor variable with following elements: + # [[0.2, 0.3, 0.5, 0.9] + # [0.1, 0.2, 0.6, 0.7]] + # Each example is followed by the correspending output tensor. + fluid.layers.reduce_prod(x) # [0.0002268] + fluid.layers.reduce_prod(x, dim=0) # [0.02, 0.06, 0.3, 0.63] + fluid.layers.reduce_prod(x, dim=-1) # [0.027, 0.0084] + fluid.layers.reduce_prod(x, dim=1, keep_dim=True) # [[0.027], [0.0084]] + """ + helper = LayerHelper('reduce_prod', **locals()) + out = helper.create_tmp_variable(dtype=helper.input_dtype()) + helper.append_op( + type='reduce_prod', + inputs={'X': input}, + outputs={'Out': out}, + attrs={ + 'dim': dim if dim != None else 0, + 'keep_dim': keep_dim, + 'reduce_all': True if dim == None else False + }) + return out + + def split(input, num_or_sections, dim=-1, name=None): """ Split the input tensor into multiple sub-tensors. diff --git a/python/paddle/fluid/tests/unittests/test_reduce_op.py b/python/paddle/fluid/tests/unittests/test_reduce_op.py index 5e656bddb7..9b0cc3534d 100644 --- a/python/paddle/fluid/tests/unittests/test_reduce_op.py +++ b/python/paddle/fluid/tests/unittests/test_reduce_op.py @@ -70,6 +70,19 @@ class TestMinOp(OpTest): self.check_output() +class TestProdOp(OpTest): + def setUp(self): + self.op_type = "reduce_prod" + self.inputs = {'X': np.random.random((5, 6, 10)).astype("float64")} + self.outputs = {'Out': self.inputs['X'].prod(axis=0)} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + class TestKeepDimReduce(OpTest): def setUp(self): self.op_type = "reduce_sum" From 9d78971d8bc05fd844448f01ebbc5a8a3d0112a1 Mon Sep 17 00:00:00 2001 From: zhouhanqing <1051910017@qq.com> Date: Fri, 9 Mar 2018 14:07:25 +0800 Subject: [PATCH 02/15] Some comments have been modified. --- paddle/fluid/operators/reduce_op.cc | 2 +- python/paddle/fluid/layers/nn.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/reduce_op.cc b/paddle/fluid/operators/reduce_op.cc index 4266636b2a..7879367830 100644 --- a/paddle/fluid/operators/reduce_op.cc +++ b/paddle/fluid/operators/reduce_op.cc @@ -177,7 +177,7 @@ class ReduceProdOpMaker : public ReduceOpMaker { public: ReduceProdOpMaker(OpProto *proto, OpAttrChecker *op_checker) : ReduceOpMaker(proto, op_checker) { - SetComment("ReduceProd", "prod"); + SetComment("ReduceProd", "production"); AddComment(comment_); } }; diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 0d9c0df854..a5957304e2 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -2215,8 +2215,8 @@ def reduce_prod(input, dim=None, keep_dim=False, name=None): keep_dim (bool|False): Whether to reserve the reduced dimension in the output Tensor. The result tensor will have one fewer dimension than the :attr:`input` unless :attr:`keep_dim` is true. - name(str|None): A name for this layer(optional). If set None, the layer - will be named automatically. + name(str|None): A name for this layer(optional). If set None, the + layer will be named automatically. Returns: Variable: The reduced Tensor variable. @@ -2231,7 +2231,8 @@ def reduce_prod(input, dim=None, keep_dim=False, name=None): fluid.layers.reduce_prod(x) # [0.0002268] fluid.layers.reduce_prod(x, dim=0) # [0.02, 0.06, 0.3, 0.63] fluid.layers.reduce_prod(x, dim=-1) # [0.027, 0.0084] - fluid.layers.reduce_prod(x, dim=1, keep_dim=True) # [[0.027], [0.0084]] + fluid.layers.reduce_prod(x, dim=1, + keep_dim=True) # [[0.027], [0.0084]] """ helper = LayerHelper('reduce_prod', **locals()) out = helper.create_tmp_variable(dtype=helper.input_dtype()) From 74523c41f1da2e1ab001ab886ef19275f0a39623 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Fri, 9 Mar 2018 20:02:33 +0800 Subject: [PATCH 03/15] enhance regularizer.py --- python/paddle/fluid/regularizer.py | 40 ++++++++++++++++++++++++++---- 1 file changed, 35 insertions(+), 5 deletions(-) diff --git a/python/paddle/fluid/regularizer.py b/python/paddle/fluid/regularizer.py index a29f9a208e..dc641cdd1a 100644 --- a/python/paddle/fluid/regularizer.py +++ b/python/paddle/fluid/regularizer.py @@ -13,6 +13,7 @@ # limitations under the License. import framework +from . import core __all__ = [ 'append_regularization_ops', @@ -46,9 +47,9 @@ def append_regularization_ops(parameters_and_grads, regularization=None): regularization_term = None if param.regularizer is not None: # Add variable for regularization term in grad block - regularization_term = param.regularizer(param, grad.block) + regularization_term = param.regularizer(param, grad, grad.block) elif regularization is not None: - regularization_term = regularization(param, grad.block) + regularization_term = regularization(param, grad, grad.block) # If no gradient or no regularization specified, # then we don't need to do anything @@ -82,7 +83,7 @@ class WeightDecayRegularizer(object): def __init__(self): pass - def __call__(self, param, block): + def __call__(self, param, grad, block): """Add corresponding weight decay operations to the network """ raise NotImplementedError() @@ -102,7 +103,7 @@ class L2DecayRegularizer(WeightDecayRegularizer): super(L2DecayRegularizer, self).__init__() self._regularization_coeff = regularization_coeff - def __call__(self, param, block): + def __call__(self, param, grad, block): """Add L2 weight decay ops to network Adds L2 weight decay ops. @@ -117,8 +118,23 @@ class L2DecayRegularizer(WeightDecayRegularizer): """ assert isinstance(param, framework.Parameter) assert isinstance(block, framework.Block) + decay = block.create_var( dtype="float32", shape=param.shape, lod_level=param.lod_level) + + if grad.type == core.VarDesc.VarType.SELECTED_ROWS: + decay = block.create_var( + dtype="float32", + shape=param.shape, + type=core.VarDesc.VarType.SELECTED_ROWS) + block.append_op( + type='lookup_table', + inputs={'W': param, + 'Ids': grad}, + outputs={'Out': decay}, + attrs={'is_sparse': True}) + param = decay + # Append Op to calculate decay block.append_op( type='scale', @@ -141,7 +157,7 @@ class L1DecayRegularizer(WeightDecayRegularizer): super(L1DecayRegularizer, self).__init__() self._regularization_coeff = regularization_coeff - def __call__(self, param, block): + def __call__(self, param, grad, block): """Add L1 weight decay ops to network Adds L1 weight decay ops. @@ -158,6 +174,20 @@ class L1DecayRegularizer(WeightDecayRegularizer): assert isinstance(block, framework.Block) decay = block.create_var( dtype="float32", shape=param.shape, lod_level=param.lod_level) + + if grad.type == core.VarDesc.VarType.SELECTED_ROWS: + # add concat_rows + decay = block.create_var( + dtype="float32", + shape=param.shape, + type=core.VarDesc.VarType.SELECTED_ROWS) + block.append_op( + type='lookup_table', + inputs={'W': param, + 'Ids': grad}, + outputs={'Out': decay}, + attrs={'is_sparse': True}) + # Append sign op block.append_op( type='sign', inputs={"X": param}, outputs={"Out": decay}) From 46ae4075eec45241ddf69b830b7f724f30e63fc7 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 12 Mar 2018 14:55:31 +0800 Subject: [PATCH 04/15] Polish ShuffleReader and test --- .../reader/create_shuffle_reader_op.cc | 75 +++++++++++-------- python/paddle/fluid/layers/io.py | 23 +++++- python/paddle/fluid/recordio_writer.py | 3 + .../tests/unittests/test_recordio_reader.py | 13 +++- 4 files changed, 79 insertions(+), 35 deletions(-) diff --git a/paddle/fluid/operators/reader/create_shuffle_reader_op.cc b/paddle/fluid/operators/reader/create_shuffle_reader_op.cc index 4dac383110..70e2f587dc 100644 --- a/paddle/fluid/operators/reader/create_shuffle_reader_op.cc +++ b/paddle/fluid/operators/reader/create_shuffle_reader_op.cc @@ -12,6 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include +#include "glog/logging.h" +#include "paddle/fluid/operators/detail/safe_ref.h" #include "paddle/fluid/operators/reader/reader_op_registry.h" namespace paddle { @@ -20,43 +23,53 @@ namespace reader { class ShuffleReader : public framework::DecoratedReader { public: - ShuffleReader(ReaderBase* reader, int buffer_size) - : DecoratedReader(reader), buffer_size_(buffer_size), iteration_pos_(0) { - buffer_.reserve(buffer_size); + ShuffleReader(ReaderBase* reader, size_t buffer_size, size_t seed = 0) + : DecoratedReader(reader), buffer_size_(buffer_size), seed_(seed) { + VLOG(10) << "Create shuffle reader of " << reader_; + if (seed_ == 0) { + std::random_device device; + seed_ = device(); + } + ReadIntoBuffers(); } - void ReadNext(std::vector* out) override; + void ReadNext(std::vector* out) override { + if (iteration_pos_ >= buffer_.size()) { + VLOG(10) << "Resetting shuffle buffer"; + ReadIntoBuffers(); + } + *out = buffer_[iteration_pos_++]; + } - private: - int buffer_size_; - std::vector> buffer_; - size_t iteration_pos_; -}; + bool HasNext() const override { + return iteration_pos_ < buffer_.size() || reader_->HasNext(); + } -void ShuffleReader::ReadNext(std::vector* out) { - if (iteration_pos_ >= buffer_.size()) { - // Reload buffer with new data + private: + void ReadIntoBuffers() { buffer_.clear(); buffer_.reserve(buffer_size_); - for (int i = 0; i < buffer_size_; ++i) { - buffer_.push_back(std::vector()); - reader_->ReadNext(&buffer_.back()); - if (buffer_.back().empty()) { - buffer_.pop_back(); + iteration_pos_ = 0; + PADDLE_ENFORCE(reader_->HasNext()); + for (size_t i = 0; i < buffer_size_; ++i) { + if (!reader_->HasNext()) { break; } + buffer_.emplace_back(); + reader_->ReadNext(&buffer_.back()); } - // TODO(fengjiayi): 'std::random_shuffle' can be very slow. It needs to be - // optimize. - std::random_shuffle(buffer_.begin(), buffer_.end()); - iteration_pos_ = 0; + std::mt19937 g(seed_); + std::shuffle(buffer_.begin(), buffer_.end(), g); + seed_ = g(); // update seed_; + VLOG(10) << "random buffer size = " << buffer_.size(); } - out->clear(); - if (!buffer_.empty()) { - std::swap(*out, buffer_[iteration_pos_++]); - } - // if buffer_ is empty, the 'out' will return as an empty vector. -} + + size_t buffer_size_; + std::vector> buffer_; + + size_t iteration_pos_; + size_t seed_; +}; class CreateShuffleReaderOp : public framework::OperatorBase { public: @@ -67,10 +80,10 @@ class CreateShuffleReaderOp : public framework::OperatorBase { const platform::Place& dev_place) const override { const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) ->Get(); - auto* out = scope.FindVar(Output("Out")) - ->template GetMutable(); - out->Reset( - new ShuffleReader(underlying_reader.Get(), Attr("buffer_size"))); + auto& var = detail::Ref(scope.FindVar(Output("Out"))); + var.GetMutable()->Reset( + new ShuffleReader(underlying_reader.Get(), + static_cast(Attr("buffer_size")))); } }; diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py index f1b2af7020..81dd978949 100644 --- a/python/paddle/fluid/layers/io.py +++ b/python/paddle/fluid/layers/io.py @@ -21,7 +21,7 @@ from ..executor import global_scope __all__ = [ 'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file', - 'read_file' + 'read_file', 'create_shuffle_reader' ] @@ -245,6 +245,8 @@ def monkey_patch_reader_methods(reader): reader.eof = eof reader.reset = reset + reader.stop_gradient = True + reader.persistable = True return reader @@ -285,6 +287,25 @@ def open_recordio_file(filename, shapes, lod_levels, dtypes): startup_var) +def __create_decorated_reader__(op_type, reader, attrs): + var_name = unique_name(op_type) + startup_blk = default_startup_program().current_block() + startup_var = startup_blk.create_var(name=var_name) + startup_blk.append_op( + type=op_type, + inputs={'UnderlyingReader': reader}, + outputs={'Out': [startup_var]}, + attrs=attrs) + startup_var.persistable = True + return _copy_reader_var_(default_main_program().current_block(), + startup_var) + + +def create_shuffle_reader(reader, buffer_size): + return __create_decorated_reader__('create_shuffle_reader', reader, + {'buffer_size': int(buffer_size)}) + + def read_file(file_obj): helper = LayerHelper('read_file') out = [ diff --git a/python/paddle/fluid/recordio_writer.py b/python/paddle/fluid/recordio_writer.py index 9735df8c06..5accaacd53 100644 --- a/python/paddle/fluid/recordio_writer.py +++ b/python/paddle/fluid/recordio_writer.py @@ -36,6 +36,7 @@ def convert_reader_to_recordio_file( feed_order=None): if feed_order is None: feed_order = feeder.feed_names + counter = 0 with create_recordio_writer(filename, compressor, max_num_records) as writer: for batch in reader_creator(): @@ -43,3 +44,5 @@ def convert_reader_to_recordio_file( for each in feed_order: writer.append_tensor(res[each]) writer.complete_append_tensor() + counter += 1 + return counter diff --git a/python/paddle/fluid/tests/unittests/test_recordio_reader.py b/python/paddle/fluid/tests/unittests/test_recordio_reader.py index d249742bd3..cdebda5b7d 100644 --- a/python/paddle/fluid/tests/unittests/test_recordio_reader.py +++ b/python/paddle/fluid/tests/unittests/test_recordio_reader.py @@ -31,10 +31,10 @@ class TestRecordIO(unittest.TestCase): name='label', shape=[1], dtype='int64'), ], place=fluid.CPUPlace()) - fluid.recordio_writer.convert_reader_to_recordio_file( + self.num_batches = fluid.recordio_writer.convert_reader_to_recordio_file( './mnist.recordio', reader, feeder) - def test_main(self): + def test_main(self, decorator_callback=None): # use new program with fluid.program_guard(fluid.Program(), fluid.Program()): data_file = fluid.layers.open_recordio_file( @@ -42,6 +42,8 @@ class TestRecordIO(unittest.TestCase): shapes=[[-1, 784], [-1, 1]], lod_levels=[0, 0], dtypes=['float32', 'int64']) + if decorator_callback is not None: + data_file = decorator_callback(data_file) img, label = fluid.layers.read_file(data_file) hidden = fluid.layers.fc(input=img, size=100, act='tanh') @@ -56,9 +58,14 @@ class TestRecordIO(unittest.TestCase): avg_loss_np = [] # train a pass + batch_id = 0 while not data_file.eof(): tmp, = exe.run(fetch_list=[avg_loss]) avg_loss_np.append(tmp) + batch_id += 1 data_file.reset() - + self.assertEqual(batch_id, self.num_batches) self.assertLess(avg_loss_np[-1], avg_loss_np[0]) + + def test_shuffle_reader(self): + self.test_main(decorator_callback=lambda reader: fluid.layers.create_shuffle_reader(reader, buffer_size=200)) From 2ea4a5d96c0d134c84651e691510f90c8b19f0fa Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 12 Mar 2018 15:39:31 +0800 Subject: [PATCH 05/15] Polish double buffer reader --- .../reader/create_double_buffer_reader_op.cc | 79 ++++++++++++++----- python/paddle/fluid/layers/io.py | 10 ++- .../tests/unittests/test_recordio_reader.py | 14 +++- 3 files changed, 81 insertions(+), 22 deletions(-) diff --git a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc index ba08ea12e2..ca947fff43 100644 --- a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc +++ b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc @@ -24,11 +24,16 @@ static constexpr size_t kDoubleBufferSize = 2; class DoubleBufferReader : public framework::DecoratedReader { public: - explicit DoubleBufferReader(ReaderBase* reader) - : DecoratedReader(reader), - buffer_(framework::MakeChannel>( - kDoubleBufferSize)) { - std::thread prefetch(&DoubleBufferReader::PrefetchThreadFunc, this); + explicit DoubleBufferReader( + ReaderBase* reader, platform::Place target_place = platform::CPUPlace()) + : DecoratedReader(reader), place_(target_place) { + start_thread(); + } + + void start_thread() { + buffer_ = framework::MakeChannel>( + kDoubleBufferSize); + std::thread prefetch([this] { PrefetchThreadFunc(); }); prefetch.detach(); } @@ -43,6 +48,8 @@ class DoubleBufferReader : public framework::DecoratedReader { void PrefetchThreadFunc(); framework::Channel>* buffer_; + platform::Place place_; + mutable std::vector local_buffer_; }; class CreateDoubleBufferReaderOp : public framework::OperatorBase { @@ -56,7 +63,20 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase { ->Get(); auto* out = scope.FindVar(Output("Out")) ->template GetMutable(); - out->Reset(new DoubleBufferReader(underlying_reader.Get())); + + auto place_str = Attr("place"); + platform::Place place; + if (place_str == "CPU") { + place = platform::CPUPlace(); + } else { + std::istringstream sin(place_str); + sin.seekg(std::string("CUDA:").size(), std::ios::beg); + size_t num; + sin >> num; + place = platform::CUDAPlace(static_cast(num)); + } + + out->Reset(new DoubleBufferReader(underlying_reader.Get(), place)); } }; @@ -71,44 +91,65 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase { It launches another thread to execute the 'underlying reader' asynchronously, which prevents reading process from blocking subsequent training. )DOC"); + std::unordered_set enum_range; + constexpr size_t kMaxCUDADevs = 128; + for (size_t i = 0; i < kMaxCUDADevs; ++i) { + enum_range.insert(string::Sprintf("CUDA:%d", i)); + } + enum_range.insert("CPU"); + AddAttr("place", "The double buffer place, default is CPU") + .SetDefault("CPU") + .InEnum({enum_range}); } }; void DoubleBufferReader::ReadNext(std::vector* out) { out->clear(); - buffer_->Receive(out); + if (local_buffer_.empty()) { + buffer_->Receive(out); + } else { + *out = local_buffer_; + local_buffer_.clear(); + } } void DoubleBufferReader::ReInit() { reader_->ReInit(); buffer_->Close(); - // The existing prefetch thread will terminate for the buffer_ is closed. - buffer_ = framework::MakeChannel>( - kDoubleBufferSize); - std::thread prefetch(&DoubleBufferReader::PrefetchThreadFunc, this); - prefetch.detach(); + start_thread(); } void DoubleBufferReader::PrefetchThreadFunc() { VLOG(5) << "A new prefetch thread starts."; - while (true) { + while (reader_->HasNext()) { std::vector batch; reader_->ReadNext(&batch); - if (batch.empty()) { - // EOF - buffer_->Close(); - VLOG(5) << "Reached the end of the file. The prefetch thread terminates."; - break; + if (platform::is_gpu_place(place_)) { + std::vector gpu_batch; + gpu_batch.resize(batch.size()); + for (size_t i = 0; i < batch.size(); ++i) { + framework::TensorCopy(batch[i], place_, &gpu_batch[i]); + gpu_batch[i].set_lod(batch[i].lod()); + } } + if (!buffer_->Send(&batch)) { VLOG(5) << "WARNING: The double buffer channel has been closed. The " "prefetch thread terminates."; break; } } + buffer_->Close(); } -bool DoubleBufferReader::HasNext() const { PADDLE_THROW("Not Implemented"); } +bool DoubleBufferReader::HasNext() const { + if (local_buffer_.empty()) { + bool ok = buffer_->Receive(&local_buffer_); + return ok; + } else { + return true; + } +} } // namespace reader } // namespace operators diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py index 81dd978949..9c91f395e7 100644 --- a/python/paddle/fluid/layers/io.py +++ b/python/paddle/fluid/layers/io.py @@ -21,7 +21,7 @@ from ..executor import global_scope __all__ = [ 'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file', - 'read_file', 'create_shuffle_reader' + 'read_file', 'create_shuffle_reader', 'create_double_buffer_reader' ] @@ -306,6 +306,14 @@ def create_shuffle_reader(reader, buffer_size): {'buffer_size': int(buffer_size)}) +def create_double_buffer_reader(reader, place=None): + attrs = dict() + if place is not None: + attrs['place'] = str(place).upper() + return __create_decorated_reader__('create_double_buffer_reader', reader, + attrs) + + def read_file(file_obj): helper = LayerHelper('read_file') out = [ diff --git a/python/paddle/fluid/tests/unittests/test_recordio_reader.py b/python/paddle/fluid/tests/unittests/test_recordio_reader.py index cdebda5b7d..24a0074d9b 100644 --- a/python/paddle/fluid/tests/unittests/test_recordio_reader.py +++ b/python/paddle/fluid/tests/unittests/test_recordio_reader.py @@ -13,9 +13,10 @@ # limitations under the License. import unittest + import paddle.fluid as fluid -import paddle.v2.dataset.mnist as mnist import paddle.v2 as paddle +import paddle.v2.dataset.mnist as mnist class TestRecordIO(unittest.TestCase): @@ -53,7 +54,12 @@ class TestRecordIO(unittest.TestCase): fluid.optimizer.Adam(learning_rate=1e-3).minimize(avg_loss) - exe = fluid.Executor(fluid.CPUPlace()) + if fluid.core.is_compiled_with_cuda(): + place = fluid.CUDAPlace(0) + else: + place = fluid.CPUPlace() + + exe = fluid.Executor(place) exe.run(fluid.default_startup_program()) avg_loss_np = [] @@ -69,3 +75,7 @@ class TestRecordIO(unittest.TestCase): def test_shuffle_reader(self): self.test_main(decorator_callback=lambda reader: fluid.layers.create_shuffle_reader(reader, buffer_size=200)) + + def test_double_buffer_reader(self): + self.test_main(decorator_callback=lambda reader: fluid.layers.create_double_buffer_reader(reader, + place='cuda:0' if fluid.core.is_compiled_with_cuda() else 'cpu')) From 225efa671fd1b234e67752ad9a1cd4aecdffe58b Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 12 Mar 2018 16:10:19 +0800 Subject: [PATCH 06/15] Remove dims in base class --- paddle/fluid/framework/operator.cc | 20 ++------------ paddle/fluid/framework/reader.cc | 10 +------ paddle/fluid/framework/reader.h | 26 ++----------------- .../reader/create_random_data_generator_op.cc | 5 ++-- .../reader/create_recordio_file_reader_op.cc | 10 +++---- 5 files changed, 12 insertions(+), 59 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index ac6289c5ab..49f8cd5f90 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -442,15 +442,7 @@ class RuntimeInferShapeContext : public InferShapeContext { } std::vector GetRepeatedDims(const std::string& name) const override { - Variable* var = scope_.FindVar(name); - if (var->IsType()) { - return var->Get().shapes(); - } else { - PADDLE_THROW( - "Only ReaderHolder support 'GetRepeatedDims', but Variable %s's " - "type_id is %s.", - name, var->Type().name()); - } + PADDLE_THROW("Only compile time support this method"); } void SetDim(const std::string& name, const DDim& dim) override { @@ -467,15 +459,7 @@ class RuntimeInferShapeContext : public InferShapeContext { void SetRepeatedDims(const std::string& name, const std::vector& dims) override { - Variable* var = scope_.FindVar(name); - if (var->IsType()) { - var->GetMutable()->set_shapes(dims); - } else { - PADDLE_THROW( - "Only ReaderHolder support 'SetRepeatedDims', but Variable %s's " - "type_id is %s.", - name, var->Type().name()); - } + PADDLE_THROW("Only compile time support this method"); } proto::VarType::Type GetVarType(const std::string& name) const override { diff --git a/paddle/fluid/framework/reader.cc b/paddle/fluid/framework/reader.cc index 91879d6d45..31f686151e 100644 --- a/paddle/fluid/framework/reader.cc +++ b/paddle/fluid/framework/reader.cc @@ -16,14 +16,6 @@ namespace paddle { namespace framework { - -DDim ReaderBase::shape(size_t idx) const { - PADDLE_ENFORCE_LT( - idx, shapes_.size(), - "Cannot get the %d'th shape, 'shapes_' only has %d elements.", idx, - shapes_.size()); - return shapes_[idx]; -} - +ReaderBase::~ReaderBase() {} } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/reader.h b/paddle/fluid/framework/reader.h index e281c9b13f..2d8d30fc66 100644 --- a/paddle/fluid/framework/reader.h +++ b/paddle/fluid/framework/reader.h @@ -22,34 +22,18 @@ namespace framework { class ReaderBase { public: - explicit ReaderBase(const std::vector& shapes) : shapes_(shapes) { - PADDLE_ENFORCE(!shapes_.empty()); - } virtual void ReadNext(std::vector* out) = 0; virtual void ReInit() = 0; - DDim shape(size_t idx) const; - std::vector shapes() const { return shapes_; } - void set_shapes(const std::vector& shapes) { shapes_ = shapes; } - virtual bool HasNext() const = 0; - virtual ~ReaderBase() {} - - protected: - std::vector shapes_; -}; - -class FileReader : public ReaderBase { - public: - explicit FileReader(const std::vector& shapes) : ReaderBase(shapes) {} + virtual ~ReaderBase(); }; class DecoratedReader : public ReaderBase { public: - explicit DecoratedReader(ReaderBase* reader) - : ReaderBase(reader->shapes()), reader_(reader) { + explicit DecoratedReader(ReaderBase* reader) : ReaderBase(), reader_(reader) { PADDLE_ENFORCE_NOT_NULL(reader_); } @@ -72,12 +56,6 @@ class ReaderHolder { void ReadNext(std::vector* out) { reader_->ReadNext(out); } void ReInit() { reader_->ReInit(); } - DDim shape(size_t idx) const { return reader_->shape(idx); } - std::vector shapes() const { return reader_->shapes(); } - void set_shapes(const std::vector& shapes) { - reader_->set_shapes(shapes); - } - bool HasNext() const { return reader_->HasNext(); } private: diff --git a/paddle/fluid/operators/reader/create_random_data_generator_op.cc b/paddle/fluid/operators/reader/create_random_data_generator_op.cc index e62f952d0e..95d8674c08 100644 --- a/paddle/fluid/operators/reader/create_random_data_generator_op.cc +++ b/paddle/fluid/operators/reader/create_random_data_generator_op.cc @@ -19,11 +19,11 @@ namespace operators { namespace reader { template -class RandomDataGenerator : public framework::FileReader { +class RandomDataGenerator : public framework::ReaderBase { public: RandomDataGenerator(const std::vector& shapes, float min, float max) - : FileReader(shapes), min_(min), max_(max) { + : framework::ReaderBase(), min_(min), max_(max), shapes_(shapes) { PADDLE_ENFORCE_LE( min, max, "'min' shouldn't be greater than 'max'.(%f vs %f)", min, max); unsigned int seed = std::random_device()(); @@ -59,6 +59,7 @@ class RandomDataGenerator : public framework::FileReader { float max_; std::minstd_rand engine_; std::uniform_real_distribution dist_; + std::vector shapes_; }; template diff --git a/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc b/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc index c3eb247bbe..4992eb8617 100644 --- a/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc +++ b/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc @@ -18,11 +18,10 @@ namespace paddle { namespace operators { namespace reader { -class RecordIOFileReader : public framework::FileReader { +class RecordIOFileReader : public framework::ReaderBase { public: - RecordIOFileReader(const std::string& filename, - const std::vector& shapes) - : FileReader(shapes), + explicit RecordIOFileReader(const std::string& filename) + : ReaderBase(), scanner_(filename), dev_ctx_(*platform::DeviceContextPool::Instance().Get( platform::CPUPlace())) {} @@ -54,12 +53,11 @@ class CreateRecordIOReaderOp : public framework::OperatorBase { int(shape_concat.size()), "The accumulate of all ranks should be equal to the " "shape concat's length."); - std::vector shapes = RestoreShapes(shape_concat, ranks); std::string filename = Attr("filename"); auto* out = scope.FindVar(Output("Out")) ->template GetMutable(); - out->Reset(new RecordIOFileReader(filename, shapes)); + out->Reset(new RecordIOFileReader(filename)); } }; From f9974a4a12de337559cb1d6494c4d1f7656d52e9 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Tue, 13 Mar 2018 14:44:19 +0800 Subject: [PATCH 07/15] Make double_buffer reader async --- .../reader/create_double_buffer_reader_op.cc | 59 +++++++++++++------ 1 file changed, 42 insertions(+), 17 deletions(-) diff --git a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc index ca947fff43..706f6fd592 100644 --- a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc +++ b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc @@ -24,15 +24,31 @@ static constexpr size_t kDoubleBufferSize = 2; class DoubleBufferReader : public framework::DecoratedReader { public: + struct Item { + Item() : ctx_(nullptr) {} + + std::vector payloads_; + platform::DeviceContext* ctx_; + }; + explicit DoubleBufferReader( ReaderBase* reader, platform::Place target_place = platform::CPUPlace()) : DecoratedReader(reader), place_(target_place) { + for (size_t i = 0; i < kDoubleBufferSize; ++i) { + if (platform::is_gpu_place(place_)) { +#ifdef PADDLE_WITH_CUDA + ctxs_.emplace_back(new platform::CUDADeviceContext( + boost::get(place_))); +#else +#endif + } + } + start_thread(); } void start_thread() { - buffer_ = framework::MakeChannel>( - kDoubleBufferSize); + buffer_ = framework::MakeChannel(kDoubleBufferSize); std::thread prefetch([this] { PrefetchThreadFunc(); }); prefetch.detach(); } @@ -47,9 +63,10 @@ class DoubleBufferReader : public framework::DecoratedReader { private: void PrefetchThreadFunc(); - framework::Channel>* buffer_; + framework::Channel* buffer_; platform::Place place_; - mutable std::vector local_buffer_; + std::vector> ctxs_; + mutable Item local_buffer_; }; class CreateDoubleBufferReaderOp : public framework::OperatorBase { @@ -104,12 +121,14 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase { }; void DoubleBufferReader::ReadNext(std::vector* out) { - out->clear(); - if (local_buffer_.empty()) { - buffer_->Receive(out); - } else { - *out = local_buffer_; - local_buffer_.clear(); + if (local_buffer_.payloads_.empty()) { + buffer_->Receive(&local_buffer_); + } + + *out = local_buffer_.payloads_; + local_buffer_.payloads_.clear(); + if (local_buffer_.ctx_) { + local_buffer_.ctx_->Wait(); } } @@ -121,16 +140,22 @@ void DoubleBufferReader::ReInit() { void DoubleBufferReader::PrefetchThreadFunc() { VLOG(5) << "A new prefetch thread starts."; + size_t gpu_ctx_offset = 0; while (reader_->HasNext()) { - std::vector batch; - reader_->ReadNext(&batch); + Item batch; + reader_->ReadNext(&batch.payloads_); if (platform::is_gpu_place(place_)) { std::vector gpu_batch; - gpu_batch.resize(batch.size()); - for (size_t i = 0; i < batch.size(); ++i) { - framework::TensorCopy(batch[i], place_, &gpu_batch[i]); - gpu_batch[i].set_lod(batch[i].lod()); + auto& gpu_ctx = this->ctxs_[gpu_ctx_offset++]; + gpu_ctx_offset %= this->ctxs_.size(); + gpu_batch.resize(batch.payloads_.size()); + for (size_t i = 0; i < batch.payloads_.size(); ++i) { + framework::TensorCopy(batch.payloads_[i], place_, *gpu_ctx, + &gpu_batch[i]); + gpu_batch[i].set_lod(batch.payloads_[i].lod()); } + batch.ctx_ = gpu_ctx.get(); + std::swap(gpu_batch, batch.payloads_); } if (!buffer_->Send(&batch)) { @@ -143,7 +168,7 @@ void DoubleBufferReader::PrefetchThreadFunc() { } bool DoubleBufferReader::HasNext() const { - if (local_buffer_.empty()) { + if (local_buffer_.payloads_.empty()) { bool ok = buffer_->Receive(&local_buffer_); return ok; } else { From 164f2382afe6ded95c95f4fb731a1d932d578026 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Tue, 13 Mar 2018 17:56:53 +0800 Subject: [PATCH 08/15] Polish code --- paddle/fluid/framework/reader.cc | 40 +------------------ paddle/fluid/framework/reader.h | 25 +----------- .../reader/create_double_buffer_reader_op.cc | 1 - .../reader/create_recordio_file_reader_op.cc | 4 +- 4 files changed, 6 insertions(+), 64 deletions(-) diff --git a/paddle/fluid/framework/reader.cc b/paddle/fluid/framework/reader.cc index c3fb657a3a..fa00c08e0d 100644 --- a/paddle/fluid/framework/reader.cc +++ b/paddle/fluid/framework/reader.cc @@ -18,45 +18,9 @@ namespace paddle { namespace framework { ReaderBase::~ReaderBase() {} -std::vector> ReaderBase::SplitReader( - const platform::PlaceList &places) { - std::vector> readers; +FileReader::FileReader(const std::vector &dims) : dims_(dims) {} - auto mutex = std::make_shared(); - for (size_t i = 0; i < places.size(); ++i) { - readers.emplace_back(new ThreadSafeReader(this, mutex)); - } - - return readers; -} - -void ThreadSafeReader::ReadNext(std::vector *out) { - std::lock_guard guard(*mutex_); - reader_->ReadNext(out); -} - -void ThreadSafeReader::ReInit() { - std::lock_guard guard(*mutex_); - reader_->ReInit(); -} - -bool ThreadSafeReader::HasNext() const { - std::lock_guard guard(*mutex_); - return reader_->HasNext(); -} - -std::vector> ThreadSafeReader::SplitReader( - const platform::PlaceList &places) { - std::vector> readers; - for (size_t i = 0; i < places.size(); ++i) { - readers.emplace_back(new ThreadSafeReader(reader_, mutex_)); - } - return readers; -} - -FileReaderBase::FileReaderBase(const std::vector &dims) : dims_(dims) {} - -void FileReaderBase::ReadNext(std::vector *out) { +void FileReader::ReadNext(std::vector *out) { ReadNextImpl(out); PADDLE_ENFORCE_EQ(out->size(), dims_.size()); for (size_t i = 0; i < dims_.size(); ++i) { diff --git a/paddle/fluid/framework/reader.h b/paddle/fluid/framework/reader.h index 8989bddd10..3573b99bec 100644 --- a/paddle/fluid/framework/reader.h +++ b/paddle/fluid/framework/reader.h @@ -33,9 +33,6 @@ class ReaderBase { virtual bool HasNext() const = 0; - virtual std::vector> SplitReader( - const platform::PlaceList& places); - virtual ~ReaderBase(); }; @@ -53,27 +50,9 @@ class DecoratedReader : public ReaderBase { ReaderBase* reader_; }; -class ThreadSafeReader : public DecoratedReader { - public: - ThreadSafeReader(ReaderBase* reader, const std::shared_ptr& mutex) - : DecoratedReader(reader), mutex_(mutex) {} - - void ReadNext(std::vector* out) override; - - void ReInit() override; - - bool HasNext() const override; - - std::vector> SplitReader( - const platform::PlaceList& places) override; - - private: - std::shared_ptr mutex_; -}; - -class FileReaderBase : public ReaderBase { +class FileReader : public ReaderBase { public: - explicit FileReaderBase(const std::vector& dims); + explicit FileReader(const std::vector& dims); void ReadNext(std::vector* out) override; diff --git a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc index 706f6fd592..d0de092947 100644 --- a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc +++ b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc @@ -39,7 +39,6 @@ class DoubleBufferReader : public framework::DecoratedReader { #ifdef PADDLE_WITH_CUDA ctxs_.emplace_back(new platform::CUDADeviceContext( boost::get(place_))); -#else #endif } } diff --git a/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc b/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc index 819e09a369..c4aa29c720 100644 --- a/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc +++ b/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc @@ -18,11 +18,11 @@ namespace paddle { namespace operators { namespace reader { -class RecordIOFileReader : public framework::FileReaderBase { +class RecordIOFileReader : public framework::FileReader { public: explicit RecordIOFileReader(const std::string& filename, const std::vector& dims) - : FileReaderBase(dims), + : FileReader(dims), scanner_(filename), dev_ctx_(*platform::DeviceContextPool::Instance().Get( platform::CPUPlace())) {} From 686a3ad6014239c776cc9f8ba0a07161fa8d115b Mon Sep 17 00:00:00 2001 From: ranqiu Date: Fri, 9 Mar 2018 17:00:28 +0800 Subject: [PATCH 09/15] Add api doc std --- doc/fluid/dev/api_doc_std_cn.md | 219 ++++++++++++++++++++++++++++++++ doc/fluid/dev/src/fc.py | 80 ++++++++++++ 2 files changed, 299 insertions(+) create mode 100644 doc/fluid/dev/api_doc_std_cn.md create mode 100644 doc/fluid/dev/src/fc.py diff --git a/doc/fluid/dev/api_doc_std_cn.md b/doc/fluid/dev/api_doc_std_cn.md new file mode 100644 index 0000000000..1d57550f64 --- /dev/null +++ b/doc/fluid/dev/api_doc_std_cn.md @@ -0,0 +1,219 @@ +# API注释撰写标准 + +- [API注释模块](#API注释模块) +- [格式及示例](#格式及示例) +- [完整示例](#完整示例) + + +## API注释模块 + +API文档须包含以下几个模块(排列顺序为文档撰写顺序): + +- Python API Definition + + API的代码定义。 + +- Function Description + + API的功能描述。描述该API的含义、作用或对输入所做的操作,及参考文献和对应链接(如果有),必要时给出公式,并解释公式中关键变量的含义。 + +- Args Description + + API参数介绍。按代码定义中的参数顺序逐个介绍,介绍内容包含数据类型、默认值(如果有)、含义等。 + +- Returns + + API返回值介绍。介绍返回值含义,必要时给出对应的形状。若返回值为包含多个参数的tuple,则按顺序逐个介绍各参数。 + +- Raises(如果有) + + 可能抛出的异常或错误及可能的产生原因,当可能抛出多种异常或错误时应分条列出。 + +- Note(如果有) + + 注意事项。当有多条注意事项时,应分条列出。 + +- Examples + + API的使用示例。 + + +## 格式及示例 + +API文档各模块格式及示例如下(以下以fc为例进行说明): + +- Python API Definition + + - 格式: + + [Python API Definition] + + - 示例 + + ``` + fc(input, + size, + num_flatten_dims=1, + param_attr=None, + bias_attr=None, + act=None, + name=None, + main_program=None, + startup_program=None) + ``` + +- Function Description + + - 格式 + + 本模块应包含以下内容(排列顺序为文档撰写顺序): + + [Function Description] + + [Formula] + + [Symbols' Descriptions if necessary] + + [References if necessary] + + - 示例 + + [Function Description] + + ``` + **Fully Connected Layer** + + The fully connected layer can take multiple tensors as its inputs. It + creates a variable called weights for each input tensor, which represents + a fully connected weight matrix from each input unit to each output unit. + The fully connected layer multiplies each input tensor with its coresponding + weight to produce an output Tensor. If multiple input tensors are given, + the results of multiple multiplications will be sumed up. If bias_attr is + not None, a bias variable will be created and added to the output. Finally, + if activation is not None, it will be applied to the output as well. + ``` + + [Formula] + + ``` + This process can be formulated as follows: + + .. math:: + + Out = Act({\sum_{i=0}^{N-1}X_iW_i + b}) + ``` + + [Symbols' Descriptions if necessary] + + ``` + In the above equation: + + * :math:`N`: Number of the input. + * :math:`X_i`: The input tensor. + * :math:`W`: The weights created by this layer. + * :math:`b`: The bias parameter created by this layer (if needed). + * :math:`Act`: The activation function. + * :math:`Out`: The output tensor. + ``` + + [References if necessary] + + 因fc没有必要列出的参考文献,故该内容省略。其他情况下需明确给出对应的参考文献和对应连接,以 layer_norm 为例: + + ``` + Refer to `Layer Normalization `_ for more details. + ``` + + +- Args Description + + - 格式 + + \[Arg's Name\][(Data Type, Default Value)][Description] + + - 示例 + + fc的部分参数注释如下: + + ``` + Args: + input (Tensor): The input tensor(s) of the layer. + param_attr (ParamAttr|list of ParamAttr, default None): The parameter attribute for learnable + parameters/weights of this layer. + name (str, default None): The name of this layer. + ``` + +- Returns + + - 格式 + + [Name][Shape] + + - 示例 + + ``` + Returns: + A tensor variable storing the transformation result. + ``` + + 当返回值为包含多个参数的tuple时,应按顺序逐个介绍各参数,以dynamic_lstm为例: + + ``` + Returns: + A tuple containing: + The hidden state of LSTM whose shape is (T X D). + The cell state of LSTM whose shape is (T X D). + ``` + +- Raises + + - 格式 + + [Exception Type][Condition] + + - 示例 + + ``` + Raises: + ValueError: If the rank of the input is less than 2. + ``` + +- Note + + - 格式 + + [Note] + + - 示例 + + fc没有注意事项,故该模块省略不写。其他情况应明确给出,若有多条注意事项,须分条列出,以scaled\_dot\_product\_attention为例: + + ``` + Note: + 1. When num_heads > 1, three linear projections are learned respectively + to map input queries, keys and values into queries', keys' and values'. + queries', keys' and values' have the same shapes with queries, keys + and values. + 2. When num_heads == 1, scaled_dot_product_attention has no learnable + parameters. + ``` + +- Examples + + - 格式 + + \[Python Code Snipper] + + - 示例 + + ``` + Examples: + .. code-block:: python + + data = fluid.layers.data(name="data", shape=[32, 32], dtype="float32") + fc = fluid.layers.fc(input=data, size=1000, act="tanh") + ``` + +## 完整示例 + +fc 的完整注释见[示例](https://github.com/PaddlePaddle/Paddle/tree/develop/doc/fluid/dev/src/fc.py)。 diff --git a/doc/fluid/dev/src/fc.py b/doc/fluid/dev/src/fc.py new file mode 100644 index 0000000000..40f3c7fd3e --- /dev/null +++ b/doc/fluid/dev/src/fc.py @@ -0,0 +1,80 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def fc(input, + size, + num_flatten_dims=1, + param_attr=None, + bias_attr=None, + act=None, + name=None): + """ + **Fully Connected Layer** + + The fully connected layer can take multiple tensors as its inputs. It + creates a variable called weights for each input tensor, which represents + a fully connected weight matrix from each input unit to each output unit. + The fully connected layer multiplies each input tensor with its coresponding + weight to produce an output Tensor. If multiple input tensors are given, + the results of multiple multiplications will be sumed up. If bias_attr is + not None, a bias variable will be created and added to the output. Finally, + if activation is not None, it will be applied to the output as well. + + This process can be formulated as follows: + + .. math:: + + Out = Act({\sum_{i=0}^{N-1}X_iW_i + b}) + + In the above equation: + + * :math:`N`: Number of the input. + * :math:`X_i`: The input tensor. + * :math:`W`: The weights created by this layer. + * :math:`b`: The bias parameter created by this layer (if needed). + * :math:`Act`: The activation function. + * :math:`Out`: The output tensor. + + Args: + input (Tensor|list of Tensor): The input tensor(s) to this layer. + size(int): The number of output units in the fully connected layer. + num_flatten_dims (int, default 1): The fc layer can accept an input tensor with more than + two dimensions. If this happens, the multidimensional tensor will first be flattened + into a 2-dimensional matrix. The parameter `num_flatten_dims` determines how the input + tensor is flattened: the first `num_flatten_dims` (inclusive, index starts from 1) + dimensions will be flatten to form the first dimension of the final matrix (height of + the matrix), and the rest `rank(X) - num_flatten_dims` dimensions are flattened to + form the second dimension of the final matrix (width of the matrix). For example, suppose + `X` is a 6-dimensional tensor with a shape [2, 3, 4, 5, 6], and `num_flatten_dims` = 3. + Then, the flattened matrix will have a shape [2 x 3 x 4, 5 x 6] = [24, 30]. + param_attr (ParamAttr|list of ParamAttr, default None): The parameter attribute for learnable + parameters/weights of this layer. + bias_attr (ParamAttr|list of ParamAttr, default None): The parameter attribute for the bias + parameter of this layer. If set None, no bias will be added to the output units. + act (str, default None): Activation to be applied to the output of this layer. + name (str, default None): The name of this layer. + + Returns: + A tensor variable storing the transformation result. + + Raises: + ValueError: If rank of the input tensor is less than 2. + + Examples: + .. code-block:: python + + data = fluid.layers.data(name="data", shape=[32, 32], dtype="float32") + fc = fluid.layers.fc(input=data, size=1000, act="tanh") + """ From 14fe40aaa6e19009f6f0836826e367f2ae5c1dee Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Wed, 14 Mar 2018 10:29:39 +0800 Subject: [PATCH 10/15] Refine/nccl (#9009) * "Refine nccl op" * "refine code " * "refine nccl code" --- paddle/fluid/operators/nccl_op.cc | 92 +++++++++--------- paddle/fluid/operators/nccl_op.cu.cc | 139 +++++++++------------------ 2 files changed, 89 insertions(+), 142 deletions(-) diff --git a/paddle/fluid/operators/nccl_op.cc b/paddle/fluid/operators/nccl_op.cc index 329656d26d..5e4ed886b1 100644 --- a/paddle/fluid/operators/nccl_op.cc +++ b/paddle/fluid/operators/nccl_op.cc @@ -104,19 +104,38 @@ class NCCLAllReduceOp : public framework::OperatorWithKernel { " Input(Communicator) of AllReduce op input should not be NULL"); PADDLE_ENFORCE(ctx->HasOutput("Out"), " Output(Out) of AllReduce op output should not be NULL"); - - auto x_dims = ctx->GetInputsDim("X"); - std::string reduction = ctx->Attrs().Get("reduction"); PADDLE_ENFORCE((reduction == "ncclSum" || reduction == "ncclProd" || reduction == "ncclMin" || reduction == "ncclMax"), "invalid reduction."); + auto x_dims = ctx->GetInputsDim("X"); ctx->SetOutputsDim("Out", x_dims); ctx->ShareLoD("X", /*->*/ "Out"); } }; +// AllReduceOp +class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker { + public: + NCCLAllReduceOpMaker(OpProto *proto, OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "The input of AllReduce op"); + AddInput("Communicator", "Communicator for communicating between gpus"); + AddOutput("Out", "The output of AllReduce op"); + AddAttr("reduction", + "(string, default 'ncclSum') " + "{'ncclMin', 'ncclMax', 'ncclProd', 'ncclSum'}.") + .SetDefault("ncclSum"); + AddComment(R"DOC( +NCCLAllReduce Operator. + +AllReduce the input tensors. + +)DOC"); + } +}; + // ReduceOp class NCCLReduceOp : public framework::OperatorWithKernel { public: @@ -143,50 +162,6 @@ class NCCLReduceOp : public framework::OperatorWithKernel { } }; -// BcastOp -class NCCLBcastOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), - " Input(X) of Bcast op input should not be NULL"); - PADDLE_ENFORCE(ctx->HasInput("Communicator"), - " Input(Communicator) of Bcast op input should not be NULL"); - PADDLE_ENFORCE(ctx->HasOutput("Out"), - " Output(Out) of Bcast op output should not be NULL"); - - int root = ctx->Attrs().Get("root"); - PADDLE_ENFORCE(root != platform::kInvalidGPUId, "Bcast root must be set."); - - auto x_dims = ctx->GetInputsDim("X"); - ctx->SetOutputsDim("Out", x_dims); - ctx->ShareLoD("X", /*->*/ "Out"); - } -}; - -// AllreduceOp -class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker { - public: - NCCLAllReduceOpMaker(OpProto *proto, OpAttrChecker *op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "The input of AllReduce op"); - AddInput("Communicator", "Communicator for communicating between gpus"); - AddOutput("Out", "The output of AllReduce op"); - AddAttr("reduction", - "(string, default 'ncclSum') " - "{'ncclMin', 'ncclMax', 'ncclProd', 'ncclSum'}.") - .SetDefault("ncclSum"); - AddComment(R"DOC( -NCCLAllReduce Operator. - -AllReduce the input tensors. - -)DOC"); - } -}; - // ReduceOp class NCCLReduceOpMaker : public framework::OpProtoAndCheckerMaker { public: @@ -213,6 +188,29 @@ Reduce the tensors. } }; +// BcastOp +class NCCLBcastOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + " Input(X) of Bcast op input should not be NULL"); + PADDLE_ENFORCE(ctx->HasInput("Communicator"), + " Input(Communicator) of Bcast op input should not be NULL"); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + " Output(Out) of Bcast op output should not be NULL"); + + int root = ctx->Attrs().Get("root"); + PADDLE_ENFORCE(root != platform::kInvalidGPUId, "Bcast root must be set."); + + auto x_dims = ctx->GetInputsDim("X"); + ctx->SetOutputsDim("Out", x_dims); + ctx->ShareLoD("X", /*->*/ "Out"); + } +}; + // BcastOp class NCCLBcastOpMaker : public framework::OpProtoAndCheckerMaker { public: diff --git a/paddle/fluid/operators/nccl_op.cu.cc b/paddle/fluid/operators/nccl_op.cu.cc index 683a520e99..4d83a70e73 100644 --- a/paddle/fluid/operators/nccl_op.cu.cc +++ b/paddle/fluid/operators/nccl_op.cu.cc @@ -43,13 +43,12 @@ class NCCLAllReduceKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), "This kernel only runs on GPU device."); - - auto ins = ctx.MultiInput("X"); - auto outs = ctx.MultiOutput("Out"); - + auto* x = ctx.Input("X"); + auto* out = ctx.Output("Out"); + auto* comm = ctx.Input("Communicator"); std::string reduction = ctx.Attr("reduction"); - ncclRedOp_t reduction_op_ = ncclSum; + ncclRedOp_t reduction_op_ = ncclSum; if (reduction == "ncclMin") { reduction_op_ = ncclMin; } else if (reduction == "ncclMax") { @@ -61,30 +60,19 @@ class NCCLAllReduceKernel : public framework::OpKernel { } else { PADDLE_THROW("Invalid reduction. default ncclSum."); } - - auto* comm = ctx.Input("Communicator"); - - auto stream = ctx.cuda_device_context().stream(); - // device id int gpu_id = boost::get(ctx.GetPlace()).GetDeviceId(); int idx = comm->GetCommId(gpu_id); - - for (size_t i = 0; i < ins.size(); ++i) { - VLOG(1) << "gpu : " - << " invoke allreduce. send " << ins[i]->numel() << " recv " - << outs[i]->numel(); - - PADDLE_ENFORCE(platform::dynload::ncclAllReduce( - ins[i]->data(), outs[i]->mutable_data(ctx.GetPlace()), - outs[i]->numel(), NCCLTypeWrapper::type, reduction_op_, - comm->comms().at(idx), stream)); - PADDLE_ENFORCE(cudaStreamSynchronize(stream)); - - VLOG(1) << "gpu : " - << " finished allreduce. send " << ins[i]->numel() << " recv " - << outs[i]->numel(); - } + VLOG(3) << "gpu : " + << " invoke allreduce. send " << x->numel() << " recv " + << out->numel(); + PADDLE_ENFORCE(platform::dynload::ncclAllReduce( + x->data(), out->mutable_data(ctx.GetPlace()), out->numel(), + NCCLTypeWrapper::type, reduction_op_, comm->comms().at(idx), + ctx.cuda_device_context().stream())); + VLOG(3) << "gpu : " + << " finished allreduce. send " << x->numel() << " recv " + << out->numel(); } }; @@ -94,13 +82,13 @@ class NCCLReduceKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), "This kernel only runs on GPU device."); - - auto ins = ctx.MultiInput("X"); // x0, x1, x2 - auto outs = ctx.MultiOutput("Out"); - + auto x = ctx.Input("X"); // x0, x1, x2 + auto out = ctx.Output("Out"); + auto* comm = ctx.Input("Communicator"); + int root = ctx.Attr("root"); std::string reduction = ctx.Attr("reduction"); - ncclRedOp_t reduction_op_ = ncclSum; + ncclRedOp_t reduction_op_ = ncclSum; if (reduction == "ncclMin") { reduction_op_ = ncclMin; } else if (reduction == "ncclMax") { @@ -112,40 +100,21 @@ class NCCLReduceKernel : public framework::OpKernel { } else { PADDLE_THROW("Invalid reduction. default ncclSum."); } - - int root = ctx.Attr("root"); - auto* comm = ctx.Input("Communicator"); - - auto stream = reinterpret_cast( - ctx.device_context()) - .stream(); // device id int gpu_id = boost::get(ctx.GetPlace()).GetDeviceId(); int idx = comm->GetCommId(gpu_id); - - auto ins_names = ctx.Inputs("X"); - std::hash hasher; - for (size_t i = 0; i < ins.size(); ++i) { - if (root == platform::kInvalidGPUId) { - root = hasher(ins_names[i]) % comm->comms().size(); - } - T* recvbuffer = nullptr; - if (root == gpu_id) { - recvbuffer = outs[i]->mutable_data(ctx.GetPlace()); - } - - VLOG(1) << "gpu : " << gpu_id << " invoke reduce. send " - << ins[i]->numel() << " recv " << outs[i]->numel(); - - PADDLE_ENFORCE(platform::dynload::ncclReduce( - ins[i]->data(), recvbuffer, ins[i]->numel(), - NCCLTypeWrapper::type, reduction_op_, root, comm->comms().at(idx), - stream)); - PADDLE_ENFORCE(cudaStreamSynchronize(stream)); - - VLOG(1) << "gpu : " << gpu_id << " finished reduce. send " - << ins[i]->numel() << " recv " << outs[i]->numel(); + T* recvbuffer = nullptr; + if (root == gpu_id) { + recvbuffer = out->mutable_data(ctx.GetPlace()); } + VLOG(3) << "gpu : " << gpu_id << " invoke reduce. send " << x->numel() + << " recv " << out->numel(); + PADDLE_ENFORCE(platform::dynload::ncclReduce( + x->data(), recvbuffer, x->numel(), NCCLTypeWrapper::type, + reduction_op_, root, comm->comms().at(idx), + ctx.cuda_device_context().stream())); + VLOG(3) << "gpu : " << gpu_id << " finished reduce. send " << x->numel() + << " recv " << out->numel(); } }; @@ -155,47 +124,27 @@ class NCCLBcastKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), "This kernel only runs on GPU device."); - int root = ctx.Attr("root"); - auto* comm = ctx.Input("Communicator"); - - auto stream = reinterpret_cast( - ctx.device_context()) - .stream(); // device id int gpu_id = boost::get(ctx.GetPlace()).GetDeviceId(); int idx = comm->GetCommId(gpu_id); - if (idx == root) { - auto ins = ctx.MultiInput("X"); - for (size_t i = 0; i < ins.size(); ++i) { - VLOG(1) << "gpu : " << gpu_id << " invoke Bcast. send " - << ins[i]->numel(); - - VLOG(1) << " before ncclBcast"; - PADDLE_ENFORCE(platform::dynload::ncclBcast( - (void*)ins[i]->data(), ins[i]->numel(), NCCLTypeWrapper::type, - root, comm->comms().at(idx), stream)); - VLOG(1) << " after ncclBcast"; - PADDLE_ENFORCE(cudaStreamSynchronize(stream)); - - VLOG(1) << "gpu : " << gpu_id << " finished Bcast."; - } + auto* x = ctx.Input("X"); + VLOG(3) << "gpu : " << gpu_id << " invoke Bcast. send " << x->numel(); + PADDLE_ENFORCE(platform::dynload::ncclBcast( + (void*)x->data(), x->numel(), NCCLTypeWrapper::type, root, + comm->comms().at(idx), ctx.cuda_device_context().stream())); + VLOG(3) << "gpu : " << gpu_id << " finished Bcast."; } else { - auto outs = ctx.MultiOutput("Out"); - for (size_t i = 0; i < outs.size(); ++i) { - VLOG(1) << "gpu : " << gpu_id << " invoke Bcast. recv buffer " - << framework::product(outs[i]->dims()); - - PADDLE_ENFORCE(platform::dynload::ncclBcast( - outs[i]->mutable_data(ctx.GetPlace()), outs[i]->numel(), - NCCLTypeWrapper::type, root, comm->comms().at(idx), stream)); - PADDLE_ENFORCE(cudaStreamSynchronize(stream)); - - VLOG(1) << "gpu : " << gpu_id << " finished Bcast. recv " - << outs[i]->numel(); - } + auto* out = ctx.Output("Out"); + VLOG(3) << "gpu : " << gpu_id << " invoke Bcast. recv buffer " + << framework::product(out->dims()); + PADDLE_ENFORCE(platform::dynload::ncclBcast( + out->mutable_data(ctx.GetPlace()), out->numel(), + NCCLTypeWrapper::type, root, comm->comms().at(idx), + ctx.cuda_device_context().stream())); + VLOG(3) << "gpu : " << gpu_id << " finished Bcast. recv " << out->numel(); } } }; From d13ce3587559c5553f05d75789269a0dff49734f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=AD=A6=E6=AF=85?= Date: Wed, 14 Mar 2018 10:38:01 +0800 Subject: [PATCH 11/15] Feature/send recv can now retry (#9027) --- paddle/fluid/operators/detail/grpc_client.cc | 18 ++++++++-- paddle/fluid/operators/detail/grpc_client.h | 36 +++++++++++++------ paddle/fluid/operators/detail/grpc_server.cc | 21 +++++++---- paddle/fluid/operators/detail/grpc_server.h | 2 +- .../fluid/operators/detail/sendrecvop_utils.h | 1 + paddle/fluid/operators/listen_and_serv_op.cc | 4 +-- paddle/fluid/operators/send_op.cc | 6 ++++ python/paddle/fluid/distribute_transpiler.py | 20 +++++++++-- 8 files changed, 83 insertions(+), 25 deletions(-) diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc index 7266f32764..ddeeebec58 100644 --- a/paddle/fluid/operators/detail/grpc_client.cc +++ b/paddle/fluid/operators/detail/grpc_client.cc @@ -97,7 +97,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep, return true; } -bool RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) { +void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) { const auto ch = GetChannel(ep); BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); @@ -108,8 +108,18 @@ bool RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) { auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_); rpc->Finish(&s->reply_, &s->status_, (void*)s); req_count_++; +} - return true; +void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) { + const auto ch = GetChannel(ep); + FetchBarrierProcessor* s = new FetchBarrierProcessor(ch); + s->Prepare(time_out); + + sendrecv::VariableMessage req; + req.set_varname(FETCH_BARRIER_MESSAGE); + auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_); + rpc->Finish(&s->reply_, &s->status_, (void*)s); + req_count_++; } bool RPCClient::Wait() { @@ -154,7 +164,7 @@ bool RPCClient::Proceed() { PADDLE_ENFORCE(tag); // TODO(gongwb): add more retries. - ClientBase* c = static_cast(tag); + BaseProcessor* c = static_cast(tag); if (!c->status_.ok()) { LOG(ERROR) << "proc param error:" << c->var_h_.String() << " grpc error:" << c->status_.error_message(); @@ -174,6 +184,8 @@ std::shared_ptr RPCClient::GetChannel(const std::string& ep) { } grpc::ChannelArguments args; + args.SetInt("grpc.testing.fixed_reconnect_backoff_ms", 5000); + args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE); args.SetMaxSendMessageSize(std::numeric_limits::max()); args.SetMaxReceiveMessageSize(std::numeric_limits::max()); diff --git a/paddle/fluid/operators/detail/grpc_client.h b/paddle/fluid/operators/detail/grpc_client.h index 669838810d..f520367dd9 100644 --- a/paddle/fluid/operators/detail/grpc_client.h +++ b/paddle/fluid/operators/detail/grpc_client.h @@ -52,14 +52,14 @@ struct VarHandle { void ProcGetResponse(const VarHandle& var_h, const sendrecv::VariableMessage& msg); -class ClientBase { +class BaseProcessor { public: - explicit ClientBase(std::shared_ptr ch) { + explicit BaseProcessor(std::shared_ptr ch) { stub_ = sendrecv::SendRecvService::NewStub(ch); context_ = NULL; } - virtual ~ClientBase() {} + virtual ~BaseProcessor() {} virtual void Prepare(const VarHandle& var_info, int64_t time_out) { context_.reset(new grpc::ClientContext()); @@ -91,9 +91,10 @@ class ClientBase { typedef std::function RequestSendCallBack; -class SendProcessor : public ClientBase { +class SendProcessor : public BaseProcessor { public: - explicit SendProcessor(std::shared_ptr ch) : ClientBase(ch) {} + explicit SendProcessor(std::shared_ptr ch) + : BaseProcessor(ch) {} virtual ~SendProcessor() {} @@ -110,9 +111,10 @@ class SendProcessor : public ClientBase { typedef std::function RequestGetCallBack; -class GetProcessor : public ClientBase { +class GetProcessor : public BaseProcessor { public: - explicit GetProcessor(std::shared_ptr ch) : ClientBase(ch) {} + explicit GetProcessor(std::shared_ptr ch) + : BaseProcessor(ch) {} virtual ~GetProcessor() {} @@ -126,10 +128,10 @@ class GetProcessor : public ClientBase { RequestGetCallBack response_call_back_ = ProcGetResponse; }; -class BatchBarrierProcessor : public ClientBase { +class BatchBarrierProcessor : public BaseProcessor { public: explicit BatchBarrierProcessor(std::shared_ptr ch) - : ClientBase(ch) {} + : BaseProcessor(ch) {} virtual ~BatchBarrierProcessor() {} @@ -137,6 +139,17 @@ class BatchBarrierProcessor : public ClientBase { sendrecv::VoidMessage reply_; }; +class FetchBarrierProcessor : public BaseProcessor { + public: + explicit FetchBarrierProcessor(std::shared_ptr ch) + : BaseProcessor(ch) {} + + virtual ~FetchBarrierProcessor() {} + + virtual void Process() {} + sendrecv::VariableMessage reply_; +}; + class RPCClient { public: bool AsyncSendVariable(const std::string& ep, @@ -151,7 +164,10 @@ class RPCClient { const std::string& var_name, int64_t time_out = 600 * 1000); - bool AsyncSendBatchBarrier(const std::string& ep, + void AsyncSendBatchBarrier(const std::string& ep, + int64_t time_out = 600 * 1000); + + void AsyncSendFetchBarrier(const std::string& ep, int64_t time_out = 600 * 1000); bool Wait(); diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index 2a56751661..8fff430cc4 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -84,7 +84,7 @@ class RequestGet final : public RequestBase { explicit RequestGet(sendrecv::SendRecvService::AsyncService* service, grpc::ServerCompletionQueue* cq, framework::Scope* scope, const platform::DeviceContext* dev_ctx, - SimpleBlockQueue* queue) + SimpleBlockQueue* queue) : RequestBase(service, cq), responder_(&ctx_), scope_(scope), @@ -101,11 +101,16 @@ class RequestGet final : public RequestBase { // proc request. std::string var_name = request_.varname(); auto* var = scope_->FindVar(var_name); - SerializeToMessage(var_name, var, *dev_ctx_, &reply_); + if (var_name != FETCH_BARRIER_MESSAGE) { + SerializeToMessage(var_name, var, *dev_ctx_, &reply_); + } // TODO(gongwb): check var's info. responder_.Finish(reply_, grpc::Status::OK, this); status_ = FINISH; - queue_->Push('c'); + MessageWithName msg_with_name = + // request name reply + std::make_pair(var_name, std::move(reply_)); + queue_->Push(msg_with_name); } protected: @@ -114,12 +119,16 @@ class RequestGet final : public RequestBase { ServerAsyncResponseWriter responder_; framework::Scope* scope_; const platform::DeviceContext* dev_ctx_; - SimpleBlockQueue* queue_; + SimpleBlockQueue* queue_; }; void AsyncGRPCServer::WaitClientGet(int count) { - for (int i = 0; i < count; ++i) { - var_get_queue_.Pop(); + int fetch_barriers = 0; + while (fetch_barriers < count) { + auto msg = var_get_queue_.Pop(); + if (msg.first == FETCH_BARRIER_MESSAGE) { + fetch_barriers++; + } } } diff --git a/paddle/fluid/operators/detail/grpc_server.h b/paddle/fluid/operators/detail/grpc_server.h index e9402ff6aa..b6666bcf96 100644 --- a/paddle/fluid/operators/detail/grpc_server.h +++ b/paddle/fluid/operators/detail/grpc_server.h @@ -77,7 +77,7 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { const platform::DeviceContext *dev_ctx_; // received variable from RPC, operators fetch variable from this queue. SimpleBlockQueue var_recv_queue_; - SimpleBlockQueue var_get_queue_; + SimpleBlockQueue var_get_queue_; // condition of the sub program std::mutex barrier_mutex_; diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.h b/paddle/fluid/operators/detail/sendrecvop_utils.h index 5208091e54..4fa6aefd3e 100644 --- a/paddle/fluid/operators/detail/sendrecvop_utils.h +++ b/paddle/fluid/operators/detail/sendrecvop_utils.h @@ -32,6 +32,7 @@ namespace detail { #define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV" #define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV" +#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV" typedef void (*DestroyCallback)(void*); diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index 8e9923c87c..4253300788 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -128,8 +128,8 @@ class ListenAndServOp : public framework::OperatorBase { } } if (exit_flag) { - rpc_service_->ShutDown(); rpc_service_->SetCond(1); + rpc_service_->ShutDown(); break; } try { @@ -148,7 +148,7 @@ class ListenAndServOp : public framework::OperatorBase { } rpc_service_->SetCond(1); // FIXME(typhoonzero): use another condition to sync wait clients get. - rpc_service_->WaitClientGet(ins.size()); + rpc_service_->WaitClientGet(fan_in); sparse_vars.clear(); } // while(true) } diff --git a/paddle/fluid/operators/send_op.cc b/paddle/fluid/operators/send_op.cc index 8fdd08eae6..443f40e803 100644 --- a/paddle/fluid/operators/send_op.cc +++ b/paddle/fluid/operators/send_op.cc @@ -88,6 +88,12 @@ class SendOp : public framework::OperatorBase { rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]); } PADDLE_ENFORCE(rpc_client->Wait()); + // tell pservers that current trainer have called fetch + for (auto& ep : endpoints) { + VLOG(3) << "send fetch barrier, ep: " << ep; + rpc_client->AsyncSendFetchBarrier(ep); + } + PADDLE_ENFORCE(rpc_client->Wait()); } } }; diff --git a/python/paddle/fluid/distribute_transpiler.py b/python/paddle/fluid/distribute_transpiler.py index bb2ce4d45d..3d3a6c116e 100644 --- a/python/paddle/fluid/distribute_transpiler.py +++ b/python/paddle/fluid/distribute_transpiler.py @@ -250,6 +250,8 @@ class DistributeTranspiler: def get_trainer_program(self): # remove optimize ops and add a send op to main_program self.program.global_block().delete_ops(self.optimize_ops) + # FIXME(typhoonzero): serialize once will fix error occurs when clone. + self.program.__str__() return self.program def get_pserver_program(self, endpoint): @@ -309,7 +311,8 @@ class DistributeTranspiler: for _, opt_op in enumerate(opt_op_on_pserver): if ufind.is_connected(op, opt_op): if self._is_opt_op(op): - self._append_pserver_ops(optimize_block, op, endpoint) + self._append_pserver_ops(optimize_block, op, endpoint, + default_main_program()) else: self._append_pserver_non_opt_ops(optimize_block, op) break @@ -520,7 +523,8 @@ class DistributeTranspiler: orig_var_name = varname[:suff_idx] return orig_var_name - def _append_pserver_ops(self, optimize_block, opt_op, endpoint): + def _append_pserver_ops(self, optimize_block, opt_op, endpoint, + origin_program): program = optimize_block.program pserver_block = program.global_block() new_inputs = dict() @@ -576,7 +580,17 @@ class DistributeTranspiler: elif key == "LearningRate": # leraning rate variable has already be created by non-optimize op, # don't create it once again. - new_inputs[key] = pserver_block.vars[opt_op.input(key)[0]] + lr_varname = opt_op.input(key)[0] + if pserver_block.vars.has_key(lr_varname): + new_inputs[key] = pserver_block.vars[opt_op.input(key)[0]] + else: + origin_var = origin_program.global_block().vars[lr_varname] + tmpvar = pserver_block.create_var( + name=origin_var.name, + persistable=origin_var.persistable, + dtype=origin_var.dtype, + shape=origin_var.shape) + new_inputs[key] = tmpvar for key in opt_op.input_names: new_shape = None From a78b7602185bf370bf2619b91e6a39afeb1d36e3 Mon Sep 17 00:00:00 2001 From: ranqiu Date: Wed, 14 Mar 2018 10:44:21 +0800 Subject: [PATCH 12/15] Refine api_doc_std_cn --- doc/fluid/dev/api_doc_std_cn.md | 8 ++++---- doc/fluid/dev/src/fc.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/doc/fluid/dev/api_doc_std_cn.md b/doc/fluid/dev/api_doc_std_cn.md index 1d57550f64..9e9e77177f 100644 --- a/doc/fluid/dev/api_doc_std_cn.md +++ b/doc/fluid/dev/api_doc_std_cn.md @@ -40,7 +40,7 @@ API文档须包含以下几个模块(排列顺序为文档撰写顺序): ## 格式及示例 -API文档各模块格式及示例如下(以下以fc为例进行说明): +API文档须使用rst格式撰写,该格式详情请参考[链接](http://sphinx-doc-zh.readthedocs.io/en/latest/rest.html)。API文档各模块的内容格式及示例如下(以下以fc为例进行说明): - Python API Definition @@ -137,7 +137,7 @@ API文档各模块格式及示例如下(以下以fc为例进行说明): ``` Args: - input (Tensor): The input tensor(s) of the layer. + input (Variable|list of Variable): This layer's input tensor(s) which is at least 2-dimensional. param_attr (ParamAttr|list of ParamAttr, default None): The parameter attribute for learnable parameters/weights of this layer. name (str, default None): The name of this layer. @@ -186,7 +186,7 @@ API文档各模块格式及示例如下(以下以fc为例进行说明): - 示例 - fc没有注意事项,故该模块省略不写。其他情况应明确给出,若有多条注意事项,须分条列出,以scaled\_dot\_product\_attention为例: + fc没有注意事项,故该模块省略不写。如有注意事项应明确给出,当有多条注意事项,须分条列出,以scaled\_dot\_product\_attention为例: ``` Note: @@ -216,4 +216,4 @@ API文档各模块格式及示例如下(以下以fc为例进行说明): ## 完整示例 -fc 的完整注释见[示例](https://github.com/PaddlePaddle/Paddle/tree/develop/doc/fluid/dev/src/fc.py)。 +fc 的完整注释见[示例](src/fc.py)。 diff --git a/doc/fluid/dev/src/fc.py b/doc/fluid/dev/src/fc.py index 40f3c7fd3e..14a3c4cd01 100644 --- a/doc/fluid/dev/src/fc.py +++ b/doc/fluid/dev/src/fc.py @@ -48,8 +48,8 @@ def fc(input, * :math:`Out`: The output tensor. Args: - input (Tensor|list of Tensor): The input tensor(s) to this layer. - size(int): The number of output units in the fully connected layer. + input (Variable|list of Variable): This layer's input tensor(s) which is at least 2-dimensional. + size(int): The number of output units in this layer. num_flatten_dims (int, default 1): The fc layer can accept an input tensor with more than two dimensions. If this happens, the multidimensional tensor will first be flattened into a 2-dimensional matrix. The parameter `num_flatten_dims` determines how the input @@ -62,7 +62,7 @@ def fc(input, param_attr (ParamAttr|list of ParamAttr, default None): The parameter attribute for learnable parameters/weights of this layer. bias_attr (ParamAttr|list of ParamAttr, default None): The parameter attribute for the bias - parameter of this layer. If set None, no bias will be added to the output units. + of this layer. If it is set to None, no bias will be added to the output units. act (str, default None): Activation to be applied to the output of this layer. name (str, default None): The name of this layer. From 93107ce138681b78c689c4e28440d4c50ff237d8 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Wed, 14 Mar 2018 10:25:08 +0800 Subject: [PATCH 13/15] add regularization for test_machine_tranlation --- python/paddle/fluid/regularizer.py | 1 - python/paddle/fluid/tests/book/test_machine_translation.py | 5 ++++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/regularizer.py b/python/paddle/fluid/regularizer.py index dc641cdd1a..029db7d2dd 100644 --- a/python/paddle/fluid/regularizer.py +++ b/python/paddle/fluid/regularizer.py @@ -176,7 +176,6 @@ class L1DecayRegularizer(WeightDecayRegularizer): dtype="float32", shape=param.shape, lod_level=param.lod_level) if grad.type == core.VarDesc.VarType.SELECTED_ROWS: - # add concat_rows decay = block.create_var( dtype="float32", shape=param.shape, diff --git a/python/paddle/fluid/tests/book/test_machine_translation.py b/python/paddle/fluid/tests/book/test_machine_translation.py index caa9596a10..fa38bd3762 100644 --- a/python/paddle/fluid/tests/book/test_machine_translation.py +++ b/python/paddle/fluid/tests/book/test_machine_translation.py @@ -181,7 +181,10 @@ def train_main(use_cuda, is_sparse, is_local=True): cost = pd.cross_entropy(input=rnn_out, label=label) avg_cost = pd.mean(cost) - optimizer = fluid.optimizer.Adagrad(learning_rate=1e-4) + optimizer = fluid.optimizer.Adagrad( + learning_rate=1e-4, + regularization=fluid.regularizer.L2DecayRegularizer( + regularization_coeff=0.1)) optimize_ops, params_grads = optimizer.minimize(avg_cost) train_data = paddle.batch( From fc0f92c24f28693da25f716473aa93206578979b Mon Sep 17 00:00:00 2001 From: ranqiu Date: Wed, 14 Mar 2018 11:01:58 +0800 Subject: [PATCH 14/15] Update api doc std and fc doc --- doc/fluid/dev/api_doc_std_cn.md | 5 ++- doc/fluid/dev/src/fc.py | 3 +- python/paddle/fluid/layers/nn.py | 68 ++++++++++++-------------------- 3 files changed, 30 insertions(+), 46 deletions(-) diff --git a/doc/fluid/dev/api_doc_std_cn.md b/doc/fluid/dev/api_doc_std_cn.md index 9e9e77177f..5596b2653a 100644 --- a/doc/fluid/dev/api_doc_std_cn.md +++ b/doc/fluid/dev/api_doc_std_cn.md @@ -40,7 +40,7 @@ API文档须包含以下几个模块(排列顺序为文档撰写顺序): ## 格式及示例 -API文档须使用rst格式撰写,该格式详情请参考[链接](http://sphinx-doc-zh.readthedocs.io/en/latest/rest.html)。API文档各模块的内容格式及示例如下(以下以fc为例进行说明): +API文档须使用reStructuredText格式撰写,该格式详情请参考[链接](http://sphinx-doc-zh.readthedocs.io/en/latest/rest.html)。API文档各模块的内容格式及示例如下(以下以fc为例进行说明): - Python API Definition @@ -137,7 +137,8 @@ API文档须使用rst格式撰写,该格式详情请参考[链接](http://sphi ``` Args: - input (Variable|list of Variable): This layer's input tensor(s) which is at least 2-dimensional. + input (Variable|list of Variable): The input tensor(s) of this layer, and the dimension of + the input tensor(s) is at least 2. param_attr (ParamAttr|list of ParamAttr, default None): The parameter attribute for learnable parameters/weights of this layer. name (str, default None): The name of this layer. diff --git a/doc/fluid/dev/src/fc.py b/doc/fluid/dev/src/fc.py index 14a3c4cd01..3b074821cc 100644 --- a/doc/fluid/dev/src/fc.py +++ b/doc/fluid/dev/src/fc.py @@ -48,7 +48,8 @@ def fc(input, * :math:`Out`: The output tensor. Args: - input (Variable|list of Variable): This layer's input tensor(s) which is at least 2-dimensional. + input (Variable|list of Variable): The input tensor(s) of this layer, and the dimension of + the input tensor(s) is at least 2. size(int): The number of output units in this layer. num_flatten_dims (int, default 1): The fc layer can accept an input tensor with more than two dimensions. If this happens, the multidimensional tensor will first be flattened diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index ffa477ba9b..63e110251a 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -85,13 +85,12 @@ def fc(input, **Fully Connected Layer** The fully connected layer can take multiple tensors as its inputs. It - creates a variable (one for each input tensor) called weights for each - input tensor, which represents a fully connected weight matrix from - each input unit to each output unit. The fully connected layer - multiplies each input tensor with its coresponding weight to produce - an output Tensor. If multiple input tensors are given, the results of - multiple multiplications will be sumed up. If bias_attr is not None, - a biases variable will be created and added to the output. Finally, + creates a variable called weights for each input tensor, which represents + a fully connected weight matrix from each input unit to each output unit. + The fully connected layer multiplies each input tensor with its coresponding + weight to produce an output Tensor. If multiple input tensors are given, + the results of multiple multiplications will be sumed up. If bias_attr is + not None, a bias variable will be created and added to the output. Finally, if activation is not None, it will be applied to the output as well. This process can be formulated as follows: @@ -110,44 +109,27 @@ def fc(input, * :math:`Out`: The output tensor. Args: - input(Variable|list): The input tensor(s) to the fully connected layer. - size(int): The number of output units in the fully connected layer. - num_flatten_dims(int): The fc layer can accept an input tensor with more - than two dimensions. If this happens, the - multidimensional tensor will first be flattened - into a 2-dimensional matrix. The parameter - `num_flatten_dims` determines how the input tensor - is flattened: the first `num_flatten_dims` - (inclusive, index starts from 1) dimensions will - be flatten to form the first dimension of the - final matrix (height of the matrix), and the rest - `rank(X) - num_flatten_dims` dimensions are - flattened to form the second dimension of the - final matrix (width of the matrix). For example, - suppose `X` is a 6-dimensional tensor with a shape - [2, 3, 4, 5, 6], and `num_flatten_dims` = 3. Then, - the flattened matrix will have a shape - [2 x 3 x 4, 5 x 6] = [24, 30]. By default, - `num_flatten_dims` is set to 1. - param_attr(ParamAttr|list): The parameter attribute for learnable - parameters/weights of the fully connected - layer. - param_initializer(ParamAttr|list): The initializer used for the - weight/parameter. If set None, - XavierInitializer() will be used. - bias_attr(ParamAttr|list): The parameter attribute for the bias parameter - for this layer. If set None, no bias will be - added to the output units. - bias_initializer(ParamAttr|list): The initializer used for the bias. - If set None, then ConstantInitializer() - will be used. - act(str): Activation to be applied to the output of the fully connected - layer. - name(str): Name/alias of the fully connected layer. - + input (Variable|list of Variable): The input tensor(s) of this layer, and the dimension of + the input tensor(s) is at least 2. + size(int): The number of output units in this layer. + num_flatten_dims (int, default 1): The fc layer can accept an input tensor with more than + two dimensions. If this happens, the multidimensional tensor will first be flattened + into a 2-dimensional matrix. The parameter `num_flatten_dims` determines how the input + tensor is flattened: the first `num_flatten_dims` (inclusive, index starts from 1) + dimensions will be flatten to form the first dimension of the final matrix (height of + the matrix), and the rest `rank(X) - num_flatten_dims` dimensions are flattened to + form the second dimension of the final matrix (width of the matrix). For example, suppose + `X` is a 6-dimensional tensor with a shape [2, 3, 4, 5, 6], and `num_flatten_dims` = 3. + Then, the flattened matrix will have a shape [2 x 3 x 4, 5 x 6] = [24, 30]. + param_attr (ParamAttr|list of ParamAttr, default None): The parameter attribute for learnable + parameters/weights of this layer. + bias_attr (ParamAttr|list of ParamAttr, default None): The parameter attribute for the bias + of this layer. If it is set to None, no bias will be added to the output units. + act (str, default None): Activation to be applied to the output of this layer. + name (str, default None): The name of this layer. Returns: - Variable: The output tensor variable. + A tensor variable storing the transformation result. Raises: ValueError: If rank of the input tensor is less than 2. From 7957e86cfe17033bb64452f062d6abf2f0c7e3f4 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Wed, 14 Mar 2018 11:11:38 +0800 Subject: [PATCH 15/15] fix deadlink --- doc/fluid/design/dist_train/distributed_architecture.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/fluid/design/dist_train/distributed_architecture.md b/doc/fluid/design/dist_train/distributed_architecture.md index b32b00ec25..a405cb6aaf 100644 --- a/doc/fluid/design/dist_train/distributed_architecture.md +++ b/doc/fluid/design/dist_train/distributed_architecture.md @@ -155,7 +155,7 @@ Cluster environment. `RemoteExecutor.run` sends the `ProgramDesc` and -[TrainingJob](https://github.com/PaddlePaddle/cloud/blob/develop/doc/autoscale/README.md#training-job-resource) +[TrainingJob](https://github.com/PaddlePaddle/cloud/blob/unreleased-tpr/doc/autoscale/README.md#training-job-resource) to a server in the cluster which executes `RemoteExecutor.listen`. This server is responsible to start the final Kubernetes Jobs to run the different role of `ProgramDesc` from `ConfigMap`.