|
|
|
@ -26,7 +26,11 @@ class MultiFileReader : public framework::ReaderBase {
|
|
|
|
|
MultiFileReader(const std::vector<std::string>& file_names,
|
|
|
|
|
const std::vector<framework::DDim>& dims, size_t thread_num,
|
|
|
|
|
size_t buffer_size)
|
|
|
|
|
: file_names_(file_names), dims_(dims), buffer_size_(buffer_size) {
|
|
|
|
|
: buffer_size_(buffer_size) {
|
|
|
|
|
readers_.resize(file_names.size());
|
|
|
|
|
for (const std::string& f_name : file_names) {
|
|
|
|
|
readers_.emplace_back(CreateReaderByFileName(f_name, dims));
|
|
|
|
|
}
|
|
|
|
|
prefetchers_.resize(thread_num);
|
|
|
|
|
StartNewScheduler();
|
|
|
|
|
}
|
|
|
|
@ -40,14 +44,13 @@ class MultiFileReader : public framework::ReaderBase {
|
|
|
|
|
void StartNewScheduler();
|
|
|
|
|
void EndScheduler();
|
|
|
|
|
void ScheduleThreadFunc();
|
|
|
|
|
void PrefetchThreadFunc(std::string file_name, size_t thread_idx);
|
|
|
|
|
void PrefetchThreadFunc(size_t reader_idx, size_t thread_idx);
|
|
|
|
|
|
|
|
|
|
std::vector<std::string> file_names_;
|
|
|
|
|
std::vector<framework::DDim> dims_;
|
|
|
|
|
std::vector<std::unique_ptr<framework::ReaderBase>> readers_;
|
|
|
|
|
std::thread scheduler_;
|
|
|
|
|
std::vector<std::thread> prefetchers_;
|
|
|
|
|
size_t buffer_size_;
|
|
|
|
|
reader::BlockingQueue<size_t>* waiting_file_idx_;
|
|
|
|
|
reader::BlockingQueue<size_t>* waiting_reader_idx_;
|
|
|
|
|
reader::BlockingQueue<size_t>* available_thread_idx_;
|
|
|
|
|
reader::BlockingQueue<std::vector<framework::LoDTensor>>* buffer_;
|
|
|
|
|
};
|
|
|
|
@ -60,20 +63,23 @@ void MultiFileReader::ReadNext(std::vector<framework::LoDTensor>* out) {
|
|
|
|
|
|
|
|
|
|
void MultiFileReader::ReInit() {
|
|
|
|
|
EndScheduler();
|
|
|
|
|
for (auto& reader : readers_) {
|
|
|
|
|
reader->ReInit();
|
|
|
|
|
}
|
|
|
|
|
StartNewScheduler();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MultiFileReader::StartNewScheduler() {
|
|
|
|
|
size_t thread_num = prefetchers_.size();
|
|
|
|
|
waiting_file_idx_ = new reader::BlockingQueue<size_t>(file_names_.size());
|
|
|
|
|
waiting_reader_idx_ = new reader::BlockingQueue<size_t>(readers_.size());
|
|
|
|
|
available_thread_idx_ = new reader::BlockingQueue<size_t>(thread_num);
|
|
|
|
|
buffer_ = new reader::BlockingQueue<std::vector<framework::LoDTensor>>(
|
|
|
|
|
buffer_size_);
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < file_names_.size(); ++i) {
|
|
|
|
|
waiting_file_idx_->Send(i);
|
|
|
|
|
for (size_t i = 0; i < readers_.size(); ++i) {
|
|
|
|
|
waiting_reader_idx_->Send(i);
|
|
|
|
|
}
|
|
|
|
|
waiting_file_idx_->Close();
|
|
|
|
|
waiting_reader_idx_->Close();
|
|
|
|
|
for (size_t i = 0; i < thread_num; ++i) {
|
|
|
|
|
available_thread_idx_->Send(i);
|
|
|
|
|
}
|
|
|
|
@ -84,13 +90,13 @@ void MultiFileReader::StartNewScheduler() {
|
|
|
|
|
void MultiFileReader::EndScheduler() {
|
|
|
|
|
available_thread_idx_->Close();
|
|
|
|
|
buffer_->Close();
|
|
|
|
|
waiting_file_idx_->Close();
|
|
|
|
|
waiting_reader_idx_->Close();
|
|
|
|
|
if (scheduler_.joinable()) {
|
|
|
|
|
scheduler_.join();
|
|
|
|
|
}
|
|
|
|
|
delete buffer_;
|
|
|
|
|
delete available_thread_idx_;
|
|
|
|
|
delete waiting_file_idx_;
|
|
|
|
|
delete waiting_reader_idx_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MultiFileReader::ScheduleThreadFunc() {
|
|
|
|
@ -102,12 +108,11 @@ void MultiFileReader::ScheduleThreadFunc() {
|
|
|
|
|
if (prefetcher.joinable()) {
|
|
|
|
|
prefetcher.join();
|
|
|
|
|
}
|
|
|
|
|
size_t file_idx;
|
|
|
|
|
if (waiting_file_idx_->Receive(&file_idx)) {
|
|
|
|
|
size_t reader_idx;
|
|
|
|
|
if (waiting_reader_idx_->Receive(&reader_idx)) {
|
|
|
|
|
// Still have files to read. Start a new prefetch thread.
|
|
|
|
|
std::string file_name = file_names_[file_idx];
|
|
|
|
|
prefetcher = std::thread([this, file_name, thread_idx] {
|
|
|
|
|
PrefetchThreadFunc(file_name, thread_idx);
|
|
|
|
|
prefetcher = std::thread([this, reader_idx, thread_idx] {
|
|
|
|
|
PrefetchThreadFunc(reader_idx, thread_idx);
|
|
|
|
|
});
|
|
|
|
|
} else {
|
|
|
|
|
// No more file to read.
|
|
|
|
@ -129,11 +134,9 @@ void MultiFileReader::ScheduleThreadFunc() {
|
|
|
|
|
VLOG(5) << "MultiFileReader schedule thread terminates.";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MultiFileReader::PrefetchThreadFunc(std::string file_name,
|
|
|
|
|
size_t thread_idx) {
|
|
|
|
|
VLOG(5) << "The prefetch thread of file '" << file_name << "' starts.";
|
|
|
|
|
std::unique_ptr<framework::ReaderBase> reader =
|
|
|
|
|
CreateReaderByFileName(file_name, dims_);
|
|
|
|
|
void MultiFileReader::PrefetchThreadFunc(size_t reader_idx, size_t thread_idx) {
|
|
|
|
|
VLOG(5) << "The prefetch thread of file idx '" << reader_idx << "' starts.";
|
|
|
|
|
std::unique_ptr<framework::ReaderBase>& reader = readers_[reader_idx];
|
|
|
|
|
while (true) {
|
|
|
|
|
std::vector<framework::LoDTensor> ins;
|
|
|
|
|
reader->ReadNext(&ins);
|
|
|
|
@ -144,8 +147,8 @@ void MultiFileReader::PrefetchThreadFunc(std::string file_name,
|
|
|
|
|
buffer_->Send(std::move(ins));
|
|
|
|
|
} catch (paddle::platform::EnforceNotMet e) {
|
|
|
|
|
VLOG(5) << "WARNING: The buffer channel has been closed. The prefetch "
|
|
|
|
|
"thread of file '"
|
|
|
|
|
<< file_name << "' will terminate.";
|
|
|
|
|
"thread of file idx '"
|
|
|
|
|
<< reader_idx << "' will terminate.";
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -154,7 +157,8 @@ void MultiFileReader::PrefetchThreadFunc(std::string file_name,
|
|
|
|
|
VLOG(5) << "WARNING: The available_thread_idx_ channel has been closed. "
|
|
|
|
|
"Fail to send thread_idx.";
|
|
|
|
|
}
|
|
|
|
|
VLOG(5) << "The prefetch thread of file '" << file_name << "' terminates.";
|
|
|
|
|
VLOG(5) << "The prefetch thread of file idx '" << reader_idx
|
|
|
|
|
<< "' terminates.";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
class OpenFilesOp : public framework::OperatorBase {
|
|
|
|
|