From c48c586acacc31274cd3a50965d8e65e8213e6f1 Mon Sep 17 00:00:00 2001 From: yuyang18 Date: Sun, 8 Jul 2018 12:19:05 +0800 Subject: [PATCH] Use weak_ptr to implement DecoratedReaderChain --- paddle/fluid/framework/reader.cc | 19 +++++------ paddle/fluid/framework/reader.h | 34 ++++++++++++++----- paddle/fluid/framework/reader_test.cc | 13 ++++--- .../reader/create_batch_reader_op.cc | 4 +-- .../reader/create_custom_reader_op.cc | 8 ++--- .../reader/create_double_buffer_reader_op.cc | 3 +- .../reader/create_multi_pass_reader_op.cc | 3 +- .../operators/reader/create_py_reader_op.cc | 2 +- .../reader/create_random_data_generator_op.cc | 4 +-- .../reader/create_recordio_file_reader_op.cc | 2 +- .../reader/create_shuffle_reader_op.cc | 5 ++- .../reader/create_threaded_reader_op.cc | 3 +- .../fluid/operators/reader/open_files_op.cc | 6 ++-- 13 files changed, 62 insertions(+), 44 deletions(-) diff --git a/paddle/fluid/framework/reader.cc b/paddle/fluid/framework/reader.cc index 2e2aa1cba1..e1d2ac79cf 100644 --- a/paddle/fluid/framework/reader.cc +++ b/paddle/fluid/framework/reader.cc @@ -19,15 +19,11 @@ namespace paddle { namespace framework { ReaderBase::~ReaderBase() {} -void ReaderBase::InsertDecoratedReader(ReaderBase *decorated_reader) { - decorated_readers_.emplace(decorated_reader); -} -void ReaderBase::EraseDecoratedReader(ReaderBase *decorated_reader) { - auto it = decorated_readers_.find(decorated_reader); - PADDLE_ENFORCE(it != decorated_readers_.end(), - "Cannot find the decorated reader to erase"); - decorated_readers_.erase(it); +void ReaderBase::InsertDecoratedReader( + const std::shared_ptr &decorated_reader) { + decorated_readers_.emplace_back(decorated_reader); } + std::unordered_set ReaderBase::GetEndPoints() { std::unordered_set result; std::deque queue; @@ -38,8 +34,10 @@ std::unordered_set ReaderBase::GetEndPoints() { if (front->decorated_readers_.empty()) { result.emplace(front); } else { - for (ReaderBase *reader : front->decorated_readers_) { - queue.emplace_back(reader); + for (auto &reader : front->decorated_readers_) { + if (auto *reader_ptr = reader.lock().get()) { + queue.emplace_back(reader_ptr); + } } } } @@ -66,6 +64,5 @@ void FileReader::ReadNext(std::vector *out) { } } } -DecoratedReader::~DecoratedReader() { reader_->EraseDecoratedReader(this); } } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/reader.h b/paddle/fluid/framework/reader.h index 2a65c58e3f..730e3faace 100644 --- a/paddle/fluid/framework/reader.h +++ b/paddle/fluid/framework/reader.h @@ -41,24 +41,26 @@ class ReaderBase { friend class DecoratedReader; // These methods can be only invoked inside DecoratedReader to record the // decorating chain. - void InsertDecoratedReader(ReaderBase* decorated_reader); - void EraseDecoratedReader(ReaderBase* decorated_reader); + void InsertDecoratedReader( + const std::shared_ptr& decorated_reader); // A set of which readers that decorated this reader. - std::unordered_set decorated_readers_; + std::vector> decorated_readers_; }; -class DecoratedReader : public ReaderBase { +class DecoratedReader : public ReaderBase, + public std::enable_shared_from_this { public: explicit DecoratedReader(const std::shared_ptr& reader) : ReaderBase(), reader_(reader) { PADDLE_ENFORCE_NOT_NULL(reader_); - reader_->InsertDecoratedReader(this); } - ~DecoratedReader(); - void ReInit() override { reader_->ReInit(); } + void RegisterDecorateChain() { + reader_->InsertDecoratedReader(shared_from_this()); + } + protected: std::shared_ptr reader_; }; @@ -80,9 +82,14 @@ class FileReader : public ReaderBase { // making it easier to access different type reader in Variables. class ReaderHolder { public: - void Reset(ReaderBase* reader) { reader_.reset(reader); } + template + void Reset(const std::shared_ptr& reader) { + auto reader_base = std::dynamic_pointer_cast(reader); + PADDLE_ENFORCE_NOT_NULL(reader_base); + reader_ = reader_base; + } - std::shared_ptr Get() const { return reader_; } + const std::shared_ptr& Get() const { return reader_; } void ReadNext(std::vector* out) { PADDLE_ENFORCE_NOT_NULL(reader_); @@ -93,9 +100,18 @@ class ReaderHolder { reader_->ReInit(); } + operator const std::shared_ptr&() const { return this->reader_; } + private: std::shared_ptr reader_; }; +template +inline std::shared_ptr MakeDecoratedReader(ARGS&&... args) { + std::shared_ptr reader(new T(std::forward(args)...)); + reader->RegisterDecorateChain(); + return reader; +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/reader_test.cc b/paddle/fluid/framework/reader_test.cc index c763fe18d6..c05be86706 100644 --- a/paddle/fluid/framework/reader_test.cc +++ b/paddle/fluid/framework/reader_test.cc @@ -32,18 +32,21 @@ class StubRootReader : public paddle::framework::ReaderBase { TEST(READER, decorate_chain) { auto root = std::make_shared(); - auto end_point1 = StubDecoratedReader(root); - auto end_point2 = StubDecoratedReader(root); + auto end_point1 = + paddle::framework::MakeDecoratedReader(root); + auto end_point2 = + paddle::framework::MakeDecoratedReader(root); { auto endpoints = root->GetEndPoints(); ASSERT_EQ(endpoints.size(), 2U); - ASSERT_NE(endpoints.count(&end_point1), 0); - ASSERT_NE(endpoints.count(&end_point2), 0); + ASSERT_NE(endpoints.count(end_point1.get()), 0); + ASSERT_NE(endpoints.count(end_point2.get()), 0); } { - auto end_point3 = StubDecoratedReader(root); + auto end_point3 = + paddle::framework::MakeDecoratedReader(root); ASSERT_EQ(root->GetEndPoints().size(), 3U); } { ASSERT_EQ(root->GetEndPoints().size(), 2U); } diff --git a/paddle/fluid/operators/reader/create_batch_reader_op.cc b/paddle/fluid/operators/reader/create_batch_reader_op.cc index ecbae3894d..41c3d37903 100644 --- a/paddle/fluid/operators/reader/create_batch_reader_op.cc +++ b/paddle/fluid/operators/reader/create_batch_reader_op.cc @@ -46,8 +46,8 @@ class CreateBatchReaderOp : public framework::OperatorBase { } const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) ->Get(); - out->Reset( - new BatchReader(underlying_reader.Get(), Attr("batch_size"))); + out->Reset(framework::MakeDecoratedReader( + underlying_reader, Attr("batch_size"))); } }; diff --git a/paddle/fluid/operators/reader/create_custom_reader_op.cc b/paddle/fluid/operators/reader/create_custom_reader_op.cc index a75c6d4c56..81a1aa7f9c 100644 --- a/paddle/fluid/operators/reader/create_custom_reader_op.cc +++ b/paddle/fluid/operators/reader/create_custom_reader_op.cc @@ -60,10 +60,10 @@ class CreateCustomReaderOp : public framework::OperatorBase { } const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) ->Get(); - out->Reset( - new CustomReader(underlying_reader.Get(), *sub_block, - Attr>("source_var_names"), - Attr>("sink_var_names"))); + out->Reset(framework::MakeDecoratedReader( + underlying_reader, *sub_block, + Attr>("source_var_names"), + Attr>("sink_var_names"))); } }; diff --git a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc index 5f734489a8..9382046954 100644 --- a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc +++ b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc @@ -109,7 +109,8 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase { place = platform::CUDAPlace(static_cast(num)); } - out->Reset(new DoubleBufferReader(underlying_reader.Get(), place)); + out->Reset(framework::MakeDecoratedReader( + underlying_reader, place)); } }; diff --git a/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc b/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc index 19b54110b9..69b3400a84 100644 --- a/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc +++ b/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc @@ -60,7 +60,8 @@ class CreateMultiPassReaderOp : public framework::OperatorBase { const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) ->Get(); int pass_num = Attr("pass_num"); - out->Reset(new MultiPassReader(underlying_reader.Get(), pass_num)); + out->Reset(framework::MakeDecoratedReader( + underlying_reader, pass_num)); } }; diff --git a/paddle/fluid/operators/reader/create_py_reader_op.cc b/paddle/fluid/operators/reader/create_py_reader_op.cc index 36587360f7..0b3578570b 100644 --- a/paddle/fluid/operators/reader/create_py_reader_op.cc +++ b/paddle/fluid/operators/reader/create_py_reader_op.cc @@ -58,7 +58,7 @@ class CreatePyReaderOp : public framework::OperatorBase { auto* queue_holder = queue_holder_var->template GetMutable(); - out->Reset(new PyReader(queue_holder->GetQueue())); + out->Reset(std::make_shared(queue_holder->GetQueue())); } }; diff --git a/paddle/fluid/operators/reader/create_random_data_generator_op.cc b/paddle/fluid/operators/reader/create_random_data_generator_op.cc index 5b7e8a063a..1c3de3feab 100644 --- a/paddle/fluid/operators/reader/create_random_data_generator_op.cc +++ b/paddle/fluid/operators/reader/create_random_data_generator_op.cc @@ -79,8 +79,8 @@ class CreateRandomDataGeneratorOp : public framework::OperatorBase { std::vector shapes = RestoreShapes(shape_concat, ranks); auto* out = scope.FindVar(Output("Out")) ->template GetMutable(); - out->Reset(new RandomDataGenerator(shapes, Attr("low"), - Attr("high"))); + out->Reset(std::make_shared>( + shapes, Attr("low"), Attr("high"))); } }; diff --git a/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc b/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc index 559827f084..c457cb3fb4 100644 --- a/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc +++ b/paddle/fluid/operators/reader/create_recordio_file_reader_op.cc @@ -70,7 +70,7 @@ class CreateRecordIOReaderOp : public framework::OperatorBase { auto* out = scope.FindVar(Output("Out")) ->template GetMutable(); - out->Reset(new RecordIOFileReader( + out->Reset(std::make_shared>( filename, RestoreShapes(shape_concat, ranks))); } }; diff --git a/paddle/fluid/operators/reader/create_shuffle_reader_op.cc b/paddle/fluid/operators/reader/create_shuffle_reader_op.cc index 57e8e21214..75adabdaa9 100644 --- a/paddle/fluid/operators/reader/create_shuffle_reader_op.cc +++ b/paddle/fluid/operators/reader/create_shuffle_reader_op.cc @@ -86,9 +86,8 @@ class CreateShuffleReaderOp : public framework::OperatorBase { } const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) ->Get(); - out->Reset( - new ShuffleReader(underlying_reader.Get(), - static_cast(Attr("buffer_size")))); + out->Reset(framework::MakeDecoratedReader( + underlying_reader, static_cast(Attr("buffer_size")))); } }; diff --git a/paddle/fluid/operators/reader/create_threaded_reader_op.cc b/paddle/fluid/operators/reader/create_threaded_reader_op.cc index 3798015146..81d75cdd33 100644 --- a/paddle/fluid/operators/reader/create_threaded_reader_op.cc +++ b/paddle/fluid/operators/reader/create_threaded_reader_op.cc @@ -49,7 +49,8 @@ class CreateThreadedReaderOp : public framework::OperatorBase { } const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) ->Get(); - out->Reset(new ThreadedReader(underlying_reader.Get())); + out->Reset( + framework::MakeDecoratedReader(underlying_reader)); } }; diff --git a/paddle/fluid/operators/reader/open_files_op.cc b/paddle/fluid/operators/reader/open_files_op.cc index 31e5d81e55..e382066be5 100644 --- a/paddle/fluid/operators/reader/open_files_op.cc +++ b/paddle/fluid/operators/reader/open_files_op.cc @@ -180,9 +180,9 @@ class OpenFilesOp : public framework::OperatorBase { auto* out = scope.FindVar(Output("Out")) ->template GetMutable(); - out->Reset(new MultiFileReader(file_names, - RestoreShapes(shape_concat, ranks), - thread_num, buffer_size)); + out->Reset(std::make_shared( + file_names, RestoreShapes(shape_concat, ranks), thread_num, + buffer_size)); } };