|
|
@ -12,42 +12,35 @@
|
|
|
|
// See the License for the specific language governing permissions and
|
|
|
|
// See the License for the specific language governing permissions and
|
|
|
|
// limitations under the License.
|
|
|
|
// limitations under the License.
|
|
|
|
|
|
|
|
|
|
|
|
#include <condition_variable>
|
|
|
|
|
|
|
|
#include <mutex>
|
|
|
|
|
|
|
|
#include <thread>
|
|
|
|
#include <thread>
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/channel.h"
|
|
|
|
#include "paddle/fluid/operators/reader/reader_op_registry.h"
|
|
|
|
#include "paddle/fluid/operators/reader/reader_op_registry.h"
|
|
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
namespace paddle {
|
|
|
|
namespace operators {
|
|
|
|
namespace operators {
|
|
|
|
namespace reader {
|
|
|
|
namespace reader {
|
|
|
|
|
|
|
|
|
|
|
|
static constexpr size_t kDoubleBufferSize = 3;
|
|
|
|
static constexpr size_t kDoubleBufferSize = 2;
|
|
|
|
|
|
|
|
|
|
|
|
class DoubleBufferReader : public framework::DecoratedReader {
|
|
|
|
class DoubleBufferReader : public framework::DecoratedReader {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
explicit DoubleBufferReader(ReaderBase* reader)
|
|
|
|
explicit DoubleBufferReader(ReaderBase* reader)
|
|
|
|
: DecoratedReader(reader),
|
|
|
|
: DecoratedReader(reader),
|
|
|
|
buffer_(kDoubleBufferSize),
|
|
|
|
buffer_(framework::MakeChannel<std::vector<framework::LoDTensor>>(
|
|
|
|
write_pos_(0),
|
|
|
|
kDoubleBufferSize)) {
|
|
|
|
read_pos_(0) {
|
|
|
|
std::thread prefetch(&DoubleBufferReader::PrefetchThreadFunc, this);
|
|
|
|
std::thread prefetch(
|
|
|
|
|
|
|
|
std::bind(&DoubleBufferReader::PrefetchThreadFunc, this));
|
|
|
|
|
|
|
|
prefetch.detach();
|
|
|
|
prefetch.detach();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void ReadNext(std::vector<framework::LoDTensor>* out) override;
|
|
|
|
void ReadNext(std::vector<framework::LoDTensor>* out) override;
|
|
|
|
bool HasNext() const override;
|
|
|
|
void ReInit() override;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
~DoubleBufferReader() { buffer_->Close(); }
|
|
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
private:
|
|
|
|
void PrefetchThreadFunc();
|
|
|
|
void PrefetchThreadFunc();
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<std::vector<framework::LoDTensor>> buffer_;
|
|
|
|
framework::Channel<std::vector<framework::LoDTensor>>* buffer_;
|
|
|
|
size_t write_pos_;
|
|
|
|
|
|
|
|
size_t read_pos_;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::mutex mtx_;
|
|
|
|
|
|
|
|
std::condition_variable buffer_not_full_;
|
|
|
|
|
|
|
|
std::condition_variable buffer_not_empty_;
|
|
|
|
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
class CreateDoubleBufferReaderOp : public framework::OperatorBase {
|
|
|
|
class CreateDoubleBufferReaderOp : public framework::OperatorBase {
|
|
|
@ -80,44 +73,36 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
|
|
|
|
void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
|
|
|
|
std::unique_lock<std::mutex> lck(mtx_);
|
|
|
|
|
|
|
|
while (write_pos_ == read_pos_) {
|
|
|
|
|
|
|
|
buffer_not_empty_.wait(lck);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
out->clear();
|
|
|
|
out->clear();
|
|
|
|
out->reserve(buffer_[read_pos_].size());
|
|
|
|
buffer_->Receive(out);
|
|
|
|
// TODO(fengjiayi): This copy shall be reduced.
|
|
|
|
|
|
|
|
for (size_t i = 0; i < buffer_[read_pos_].size(); ++i) {
|
|
|
|
|
|
|
|
framework::LoDTensor dst;
|
|
|
|
|
|
|
|
TensorCopy(buffer_[read_pos_][i], platform::CPUPlace(), &dst);
|
|
|
|
|
|
|
|
dst.set_lod(buffer_[read_pos_][i].lod());
|
|
|
|
|
|
|
|
out->push_back(dst);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
++read_pos_;
|
|
|
|
|
|
|
|
if (read_pos_ >= kDoubleBufferSize) {
|
|
|
|
|
|
|
|
read_pos_ = 0;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
buffer_not_full_.notify_all();
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
bool DoubleBufferReader::HasNext() const {
|
|
|
|
void DoubleBufferReader::ReInit() {
|
|
|
|
return reader_->HasNext() || !buffer_.empty();
|
|
|
|
reader_->ReInit();
|
|
|
|
|
|
|
|
buffer_->Close();
|
|
|
|
|
|
|
|
// The existing prefetch thread will terminate for the buffer_ is closed.
|
|
|
|
|
|
|
|
buffer_ = framework::MakeChannel<std::vector<framework::LoDTensor>>(
|
|
|
|
|
|
|
|
kDoubleBufferSize);
|
|
|
|
|
|
|
|
std::thread prefetch(&DoubleBufferReader::PrefetchThreadFunc, this);
|
|
|
|
|
|
|
|
prefetch.detach();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void DoubleBufferReader::PrefetchThreadFunc() {
|
|
|
|
void DoubleBufferReader::PrefetchThreadFunc() {
|
|
|
|
while (reader_->HasNext()) {
|
|
|
|
VLOG(5) << "A new prefetch thread starts.";
|
|
|
|
std::unique_lock<std::mutex> lck(mtx_);
|
|
|
|
while (true) {
|
|
|
|
while (((write_pos_ + 1) % kDoubleBufferSize) == read_pos_) {
|
|
|
|
std::vector<framework::LoDTensor> batch;
|
|
|
|
buffer_not_full_.wait(lck);
|
|
|
|
reader_->ReadNext(&batch);
|
|
|
|
|
|
|
|
if (batch.empty()) {
|
|
|
|
|
|
|
|
// EOF
|
|
|
|
|
|
|
|
buffer_->Close();
|
|
|
|
|
|
|
|
VLOG(5) << "Reached the end of the file. The prefetch thread terminates.";
|
|
|
|
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
reader_->ReadNext(&buffer_[write_pos_]);
|
|
|
|
if (!buffer_->Send(&batch)) {
|
|
|
|
++write_pos_;
|
|
|
|
VLOG(5) << "WARNING: The double buffer channel has been closed. The "
|
|
|
|
if (write_pos_ >= kDoubleBufferSize) {
|
|
|
|
"prefetch thread terminates.";
|
|
|
|
write_pos_ = 0;
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
buffer_not_empty_.notify_all();
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|