|
|
@ -24,11 +24,16 @@ static constexpr size_t kDoubleBufferSize = 2;
|
|
|
|
|
|
|
|
|
|
|
|
class DoubleBufferReader : public framework::DecoratedReader {
|
|
|
|
class DoubleBufferReader : public framework::DecoratedReader {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
explicit DoubleBufferReader(ReaderBase* reader)
|
|
|
|
explicit DoubleBufferReader(
|
|
|
|
: DecoratedReader(reader),
|
|
|
|
ReaderBase* reader, platform::Place target_place = platform::CPUPlace())
|
|
|
|
buffer_(framework::MakeChannel<std::vector<framework::LoDTensor>>(
|
|
|
|
: DecoratedReader(reader), place_(target_place) {
|
|
|
|
kDoubleBufferSize)) {
|
|
|
|
start_thread();
|
|
|
|
std::thread prefetch(&DoubleBufferReader::PrefetchThreadFunc, this);
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void start_thread() {
|
|
|
|
|
|
|
|
buffer_ = framework::MakeChannel<std::vector<framework::LoDTensor>>(
|
|
|
|
|
|
|
|
kDoubleBufferSize);
|
|
|
|
|
|
|
|
std::thread prefetch([this] { PrefetchThreadFunc(); });
|
|
|
|
prefetch.detach();
|
|
|
|
prefetch.detach();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -43,6 +48,8 @@ class DoubleBufferReader : public framework::DecoratedReader {
|
|
|
|
void PrefetchThreadFunc();
|
|
|
|
void PrefetchThreadFunc();
|
|
|
|
|
|
|
|
|
|
|
|
framework::Channel<std::vector<framework::LoDTensor>>* buffer_;
|
|
|
|
framework::Channel<std::vector<framework::LoDTensor>>* buffer_;
|
|
|
|
|
|
|
|
platform::Place place_;
|
|
|
|
|
|
|
|
mutable std::vector<framework::LoDTensor> local_buffer_;
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
class CreateDoubleBufferReaderOp : public framework::OperatorBase {
|
|
|
|
class CreateDoubleBufferReaderOp : public framework::OperatorBase {
|
|
|
@ -56,7 +63,20 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
|
|
|
|
->Get<framework::ReaderHolder>();
|
|
|
|
->Get<framework::ReaderHolder>();
|
|
|
|
auto* out = scope.FindVar(Output("Out"))
|
|
|
|
auto* out = scope.FindVar(Output("Out"))
|
|
|
|
->template GetMutable<framework::ReaderHolder>();
|
|
|
|
->template GetMutable<framework::ReaderHolder>();
|
|
|
|
out->Reset(new DoubleBufferReader(underlying_reader.Get()));
|
|
|
|
|
|
|
|
|
|
|
|
auto place_str = Attr<std::string>("place");
|
|
|
|
|
|
|
|
platform::Place place;
|
|
|
|
|
|
|
|
if (place_str == "CPU") {
|
|
|
|
|
|
|
|
place = platform::CPUPlace();
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
std::istringstream sin(place_str);
|
|
|
|
|
|
|
|
sin.seekg(std::string("CUDA:").size(), std::ios::beg);
|
|
|
|
|
|
|
|
size_t num;
|
|
|
|
|
|
|
|
sin >> num;
|
|
|
|
|
|
|
|
place = platform::CUDAPlace(static_cast<int>(num));
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
out->Reset(new DoubleBufferReader(underlying_reader.Get(), place));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
@ -71,44 +91,65 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
|
|
|
|
It launches another thread to execute the 'underlying reader' asynchronously,
|
|
|
|
It launches another thread to execute the 'underlying reader' asynchronously,
|
|
|
|
which prevents reading process from blocking subsequent training.
|
|
|
|
which prevents reading process from blocking subsequent training.
|
|
|
|
)DOC");
|
|
|
|
)DOC");
|
|
|
|
|
|
|
|
std::unordered_set<std::string> enum_range;
|
|
|
|
|
|
|
|
constexpr size_t kMaxCUDADevs = 128;
|
|
|
|
|
|
|
|
for (size_t i = 0; i < kMaxCUDADevs; ++i) {
|
|
|
|
|
|
|
|
enum_range.insert(string::Sprintf("CUDA:%d", i));
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
enum_range.insert("CPU");
|
|
|
|
|
|
|
|
AddAttr<std::string>("place", "The double buffer place, default is CPU")
|
|
|
|
|
|
|
|
.SetDefault("CPU")
|
|
|
|
|
|
|
|
.InEnum({enum_range});
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
// Hands the next batch to the caller through `out`.
//
// `out` is cleared first. If HasNext() has already staged a batch in
// local_buffer_, that batch is transferred to `out`; otherwise a batch is
// received from buffer_ directly. The result of Receive() is not checked
// here.
void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
  out->clear();
  if (local_buffer_.empty()) {
    buffer_->Receive(out);
    return;
  }
  // `out` is empty at this point, so swapping moves the staged batch into it
  // without copying the vector and leaves local_buffer_ empty for the next
  // HasNext() call.
  out->swap(local_buffer_);
}
|
|
|
|
|
|
|
|
|
|
|
|
void DoubleBufferReader::ReInit() {
|
|
|
|
void DoubleBufferReader::ReInit() {
|
|
|
|
reader_->ReInit();
|
|
|
|
reader_->ReInit();
|
|
|
|
buffer_->Close();
|
|
|
|
buffer_->Close();
|
|
|
|
// The existing prefetch thread will terminate for the buffer_ is closed.
|
|
|
|
start_thread();
|
|
|
|
buffer_ = framework::MakeChannel<std::vector<framework::LoDTensor>>(
|
|
|
|
|
|
|
|
kDoubleBufferSize);
|
|
|
|
|
|
|
|
std::thread prefetch(&DoubleBufferReader::PrefetchThreadFunc, this);
|
|
|
|
|
|
|
|
prefetch.detach();
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void DoubleBufferReader::PrefetchThreadFunc() {
|
|
|
|
void DoubleBufferReader::PrefetchThreadFunc() {
|
|
|
|
VLOG(5) << "A new prefetch thread starts.";
|
|
|
|
VLOG(5) << "A new prefetch thread starts.";
|
|
|
|
while (true) {
|
|
|
|
while (reader_->HasNext()) {
|
|
|
|
std::vector<framework::LoDTensor> batch;
|
|
|
|
std::vector<framework::LoDTensor> batch;
|
|
|
|
reader_->ReadNext(&batch);
|
|
|
|
reader_->ReadNext(&batch);
|
|
|
|
if (batch.empty()) {
|
|
|
|
if (platform::is_gpu_place(place_)) {
|
|
|
|
// EOF
|
|
|
|
std::vector<framework::LoDTensor> gpu_batch;
|
|
|
|
buffer_->Close();
|
|
|
|
gpu_batch.resize(batch.size());
|
|
|
|
VLOG(5) << "Reached the end of the file. The prefetch thread terminates.";
|
|
|
|
for (size_t i = 0; i < batch.size(); ++i) {
|
|
|
|
break;
|
|
|
|
framework::TensorCopy(batch[i], place_, &gpu_batch[i]);
|
|
|
|
|
|
|
|
gpu_batch[i].set_lod(batch[i].lod());
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if (!buffer_->Send(&batch)) {
|
|
|
|
if (!buffer_->Send(&batch)) {
|
|
|
|
VLOG(5) << "WARNING: The double buffer channel has been closed. The "
|
|
|
|
VLOG(5) << "WARNING: The double buffer channel has been closed. The "
|
|
|
|
"prefetch thread terminates.";
|
|
|
|
"prefetch thread terminates.";
|
|
|
|
break;
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
buffer_->Close();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
bool DoubleBufferReader::HasNext() const { PADDLE_THROW("Not Implemented"); }
|
|
|
|
bool DoubleBufferReader::HasNext() const {
|
|
|
|
|
|
|
|
if (local_buffer_.empty()) {
|
|
|
|
|
|
|
|
bool ok = buffer_->Receive(&local_buffer_);
|
|
|
|
|
|
|
|
return ok;
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
return true;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
} // namespace reader
|
|
|
|
} // namespace reader
|
|
|
|
} // namespace operators
|
|
|
|
} // namespace operators
|
|
|
|