|
|
|
@ -38,8 +38,9 @@ class MultipleReader : public framework::ReaderBase {
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
MultipleReader(const std::vector<std::string>& file_names,
|
|
|
|
|
const std::vector<framework::DDim>& dims, size_t thread_num)
|
|
|
|
|
: file_names_(file_names), dims_(dims) {
|
|
|
|
|
const std::vector<framework::DDim>& dims, size_t thread_num,
|
|
|
|
|
size_t buffer_size)
|
|
|
|
|
: file_names_(file_names), dims_(dims), buffer_size_(buffer_size) {
|
|
|
|
|
prefetchers_.resize(thread_num);
|
|
|
|
|
StartNewScheduler();
|
|
|
|
|
}
|
|
|
|
@ -60,6 +61,7 @@ class MultipleReader : public framework::ReaderBase {
|
|
|
|
|
std::vector<framework::DDim> dims_;
|
|
|
|
|
std::thread scheduler_;
|
|
|
|
|
std::vector<std::thread> prefetchers_;
|
|
|
|
|
size_t buffer_size_;
|
|
|
|
|
framework::Channel<size_t>* waiting_file_idx_;
|
|
|
|
|
framework::Channel<size_t>* available_thread_idx_;
|
|
|
|
|
framework::Channel<std::vector<framework::LoDTensor>>* buffer_;
|
|
|
|
@ -92,7 +94,7 @@ void MultipleReader::StartNewScheduler() {
|
|
|
|
|
waiting_file_idx_ = framework::MakeChannel<size_t>(file_names_.size());
|
|
|
|
|
available_thread_idx_ = framework::MakeChannel<size_t>(thread_num);
|
|
|
|
|
buffer_ =
|
|
|
|
|
framework::MakeChannel<std::vector<framework::LoDTensor>>(thread_num);
|
|
|
|
|
framework::MakeChannel<std::vector<framework::LoDTensor>>(buffer_size_);
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < file_names_.size(); ++i) {
|
|
|
|
|
waiting_file_idx_->Send(&i);
|
|
|
|
@ -197,11 +199,13 @@ class OpenFilesOp : public framework::OperatorBase {
|
|
|
|
|
const auto& file_names = Attr<std::vector<std::string>>("file_names");
|
|
|
|
|
PADDLE_ENFORCE(!file_names.empty(), "No file to be read!");
|
|
|
|
|
const size_t thread_num = Attr<int>("thread_num");
|
|
|
|
|
const size_t buffer_size = Attr<int>("buffer_size");
|
|
|
|
|
|
|
|
|
|
auto* out = scope.FindVar(Output("Out"))
|
|
|
|
|
->template GetMutable<framework::ReaderHolder>();
|
|
|
|
|
out->Reset(new MultipleReader(
|
|
|
|
|
file_names, RestoreShapes(shape_concat, ranks), thread_num));
|
|
|
|
|
out->Reset(new MultipleReader(file_names,
|
|
|
|
|
RestoreShapes(shape_concat, ranks),
|
|
|
|
|
thread_num, buffer_size));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -212,6 +216,7 @@ class OpenFilesOpMaker : public FileReaderMakerBase {
|
|
|
|
|
AddAttr<std::vector<std::string>>("file_names", "Files to be read.");
|
|
|
|
|
AddAttr<int>("thread_num", "The maximal concurrent prefetch thread number.")
|
|
|
|
|
.GreaterThan(0);
|
|
|
|
|
AddAttr<int>("buffer_size", "The size of prefetch buffer.").GreaterThan(0);
|
|
|
|
|
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
OpenFiles Operator
|
|
|
|
|