|
|
|
@ -21,12 +21,10 @@ namespace reader {
|
|
|
|
|
|
|
|
|
|
class MultipleReader : public framework::ReaderBase {
|
|
|
|
|
public:
|
|
|
|
|
struct Quota {};
|
|
|
|
|
|
|
|
|
|
MultipleReader(const std::vector<std::string>& file_names,
|
|
|
|
|
const std::vector<framework::DDim>& dims, size_t thread_num)
|
|
|
|
|
: file_names_(file_names), dims_(dims), thread_num_(thread_num) {
|
|
|
|
|
PADDLE_ENFORCE_GT(thread_num_, 0);
|
|
|
|
|
: file_names_(file_names), dims_(dims) {
|
|
|
|
|
prefetchers_.resize(thread_num);
|
|
|
|
|
StartNewScheduler();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -34,16 +32,20 @@ class MultipleReader : public framework::ReaderBase {
|
|
|
|
|
bool HasNext() const override;
|
|
|
|
|
void ReInit() override;
|
|
|
|
|
|
|
|
|
|
~MultipleReader() { EndScheduler(); }
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
void StartNewScheduler();
|
|
|
|
|
void EndScheduler();
|
|
|
|
|
void ScheduleThreadFunc();
|
|
|
|
|
void PrefetchThreadFunc(std::string file_name);
|
|
|
|
|
void PrefetchThreadFunc(std::string file_name, size_t thread_idx);
|
|
|
|
|
|
|
|
|
|
std::vector<std::string> file_names_;
|
|
|
|
|
std::vector<framework::DDim> dims_;
|
|
|
|
|
size_t thread_num_;
|
|
|
|
|
std::thread scheduler_;
|
|
|
|
|
std::vector<std::thread> prefetchers_;
|
|
|
|
|
framework::Channel<size_t>* waiting_file_idx_;
|
|
|
|
|
framework::Channel<Quota>* thread_quotas_;
|
|
|
|
|
framework::Channel<size_t>* available_thread_idx_;
|
|
|
|
|
framework::Channel<std::vector<framework::LoDTensor>>* buffer_;
|
|
|
|
|
mutable std::vector<framework::LoDTensor> local_buffer_;
|
|
|
|
|
};
|
|
|
|
@ -65,59 +67,76 @@ bool MultipleReader::HasNext() const {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MultipleReader::ReInit() {
|
|
|
|
|
buffer_->Close();
|
|
|
|
|
thread_quotas_->Close();
|
|
|
|
|
waiting_file_idx_->Close();
|
|
|
|
|
EndScheduler();
|
|
|
|
|
local_buffer_.clear();
|
|
|
|
|
|
|
|
|
|
StartNewScheduler();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MultipleReader::StartNewScheduler() {
|
|
|
|
|
size_t thread_num = prefetchers_.size();
|
|
|
|
|
waiting_file_idx_ = framework::MakeChannel<size_t>(file_names_.size());
|
|
|
|
|
thread_quotas_ = framework::MakeChannel<Quota>(thread_num_);
|
|
|
|
|
available_thread_idx_ = framework::MakeChannel<size_t>(thread_num);
|
|
|
|
|
buffer_ =
|
|
|
|
|
framework::MakeChannel<std::vector<framework::LoDTensor>>(thread_num_);
|
|
|
|
|
framework::MakeChannel<std::vector<framework::LoDTensor>>(thread_num);
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < file_names_.size(); ++i) {
|
|
|
|
|
waiting_file_idx_->Send(&i);
|
|
|
|
|
}
|
|
|
|
|
waiting_file_idx_->Close();
|
|
|
|
|
for (size_t i = 0; i < thread_num_; ++i) {
|
|
|
|
|
Quota quota;
|
|
|
|
|
thread_quotas_->Send("a);
|
|
|
|
|
for (size_t i = 0; i < thread_num; ++i) {
|
|
|
|
|
available_thread_idx_->Send(&i);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::thread scheduler([this] { ScheduleThreadFunc(); });
|
|
|
|
|
scheduler.detach();
|
|
|
|
|
scheduler_ = std::thread([this] { ScheduleThreadFunc(); });
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MultipleReader::EndScheduler() {
|
|
|
|
|
available_thread_idx_->Close();
|
|
|
|
|
buffer_->Close();
|
|
|
|
|
waiting_file_idx_->Close();
|
|
|
|
|
scheduler_.join();
|
|
|
|
|
delete buffer_;
|
|
|
|
|
delete available_thread_idx_;
|
|
|
|
|
delete waiting_file_idx_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MultipleReader::ScheduleThreadFunc() {
|
|
|
|
|
VLOG(5) << "MultipleReader schedule thread starts.";
|
|
|
|
|
size_t completed_thread_num = 0;
|
|
|
|
|
Quota quota;
|
|
|
|
|
while (thread_quotas_->Receive("a)) {
|
|
|
|
|
size_t thread_idx;
|
|
|
|
|
while (available_thread_idx_->Receive(&thread_idx)) {
|
|
|
|
|
std::thread& prefetcher = prefetchers_[thread_idx];
|
|
|
|
|
if (prefetcher.joinable()) {
|
|
|
|
|
prefetcher.join();
|
|
|
|
|
}
|
|
|
|
|
size_t file_idx;
|
|
|
|
|
if (waiting_file_idx_->Receive(&file_idx)) {
|
|
|
|
|
// Still have files to read. Start a new prefetch thread.
|
|
|
|
|
std::string file_name = file_names_[file_idx];
|
|
|
|
|
std::thread prefetcher(
|
|
|
|
|
[this, file_name] { PrefetchThreadFunc(file_name); });
|
|
|
|
|
prefetcher.detach();
|
|
|
|
|
prefetcher = std::thread([this, file_name, thread_idx] {
|
|
|
|
|
PrefetchThreadFunc(file_name, thread_idx);
|
|
|
|
|
});
|
|
|
|
|
} else {
|
|
|
|
|
// No more file to read.
|
|
|
|
|
++completed_thread_num;
|
|
|
|
|
if (completed_thread_num == thread_num_) {
|
|
|
|
|
thread_quotas_->Close();
|
|
|
|
|
buffer_->Close();
|
|
|
|
|
if (completed_thread_num == prefetchers_.size()) {
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// If users invoke ReInit() when scheduler is running, it will close the
|
|
|
|
|
// 'avaiable_thread_idx_' and prefecther threads have no way to tell scheduler
|
|
|
|
|
// to release their resource. So a check is needed before scheduler ends.
|
|
|
|
|
for (auto& p : prefetchers_) {
|
|
|
|
|
if (p.joinable()) {
|
|
|
|
|
p.join();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
VLOG(5) << "MultipleReader schedule thread terminates.";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MultipleReader::PrefetchThreadFunc(std::string file_name) {
|
|
|
|
|
void MultipleReader::PrefetchThreadFunc(std::string file_name,
|
|
|
|
|
size_t thread_idx) {
|
|
|
|
|
VLOG(5) << "The prefetch thread of file '" << file_name << "' starts.";
|
|
|
|
|
std::unique_ptr<framework::ReaderBase> reader =
|
|
|
|
|
CreateReaderByFileName(file_name, dims_);
|
|
|
|
@ -131,8 +150,10 @@ void MultipleReader::PrefetchThreadFunc(std::string file_name) {
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
Quota quota;
|
|
|
|
|
thread_quotas_->Send("a);
|
|
|
|
|
if (!available_thread_idx_->Send(&thread_idx)) {
|
|
|
|
|
VLOG(5) << "WARNING: The available_thread_idx_ channel has been closed. "
|
|
|
|
|
"Fail to send thread_idx.";
|
|
|
|
|
}
|
|
|
|
|
VLOG(5) << "The prefetch thread of file '" << file_name << "' terminates.";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|