|
|
|
@ -14,7 +14,7 @@
|
|
|
|
|
|
|
|
|
|
#include <thread> // NOLINT
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/channel.h"
|
|
|
|
|
#include "paddle/fluid/operators/reader/blocking_queue.h"
|
|
|
|
|
#include "paddle/fluid/operators/reader/reader_op_registry.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
@ -48,9 +48,9 @@ class MultiFileReader : public framework::ReaderBase {
|
|
|
|
|
std::thread scheduler_;
|
|
|
|
|
std::vector<std::thread> prefetchers_;
|
|
|
|
|
size_t buffer_size_;
|
|
|
|
|
framework::Channel<size_t>* waiting_file_idx_;
|
|
|
|
|
framework::Channel<size_t>* available_thread_idx_;
|
|
|
|
|
framework::Channel<std::vector<framework::LoDTensor>>* buffer_;
|
|
|
|
|
reader::BlockingQueue<size_t>* waiting_file_idx_;
|
|
|
|
|
reader::BlockingQueue<size_t>* available_thread_idx_;
|
|
|
|
|
reader::BlockingQueue<std::vector<framework::LoDTensor>>* buffer_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
void MultiFileReader::ReadNext(std::vector<framework::LoDTensor>* out) {
|
|
|
|
@ -73,17 +73,17 @@ bool MultiFileReader::HasNext() {
|
|
|
|
|
|
|
|
|
|
void MultiFileReader::StartNewScheduler() {
|
|
|
|
|
size_t thread_num = prefetchers_.size();
|
|
|
|
|
waiting_file_idx_ = framework::MakeChannel<size_t>(file_names_.size());
|
|
|
|
|
available_thread_idx_ = framework::MakeChannel<size_t>(thread_num);
|
|
|
|
|
buffer_ =
|
|
|
|
|
framework::MakeChannel<std::vector<framework::LoDTensor>>(buffer_size_);
|
|
|
|
|
waiting_file_idx_ = new reader::BlockingQueue<size_t>(file_names_.size());
|
|
|
|
|
available_thread_idx_ = new reader::BlockingQueue<size_t>(thread_num);
|
|
|
|
|
buffer_ = new reader::BlockingQueue<std::vector<framework::LoDTensor>>(
|
|
|
|
|
buffer_size_);
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < file_names_.size(); ++i) {
|
|
|
|
|
waiting_file_idx_->Send(&i);
|
|
|
|
|
waiting_file_idx_->Send(i);
|
|
|
|
|
}
|
|
|
|
|
waiting_file_idx_->Close();
|
|
|
|
|
for (size_t i = 0; i < thread_num; ++i) {
|
|
|
|
|
available_thread_idx_->Send(&i);
|
|
|
|
|
available_thread_idx_->Send(i);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
scheduler_ = std::thread([this] { ScheduleThreadFunc(); });
|
|
|
|
@ -149,7 +149,7 @@ void MultiFileReader::PrefetchThreadFunc(std::string file_name,
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
try {
|
|
|
|
|
buffer_->Send(&ins);
|
|
|
|
|
buffer_->Send(std::move(ins));
|
|
|
|
|
} catch (paddle::platform::EnforceNotMet e) {
|
|
|
|
|
VLOG(5) << "WARNING: The buffer channel has been closed. The prefetch "
|
|
|
|
|
"thread of file '"
|
|
|
|
@ -158,9 +158,7 @@ void MultiFileReader::PrefetchThreadFunc(std::string file_name,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
try {
|
|
|
|
|
available_thread_idx_->Send(&thread_idx);
|
|
|
|
|
} catch (paddle::platform::EnforceNotMet e) {
|
|
|
|
|
if (!available_thread_idx_->Send(thread_idx)) {
|
|
|
|
|
VLOG(5) << "WARNING: The available_thread_idx_ channel has been closed. "
|
|
|
|
|
"Fail to send thread_idx.";
|
|
|
|
|
}
|
|
|
|
|