|
|
|
@ -28,15 +28,15 @@ BufferedReader::BufferedReader(
|
|
|
|
|
buffer_size_(buffer_size) {
|
|
|
|
|
cpu_buffer_.resize(buffer_size);
|
|
|
|
|
gpu_buffer_.resize(buffer_size);
|
|
|
|
|
AppendFutureToBatchSize();
|
|
|
|
|
ReadTillBufferFullAsync();
|
|
|
|
|
}
|
|
|
|
|
void BufferedReader::AppendFutureToBatchSize() {
|
|
|
|
|
// Schedules a read for every buffer slot: enforces that position_ is empty,
// then walks i over [0, buffer_size_) and issues the per-slot call(s) below.
// NOTE(review): this region reads like a unified diff whose +/- markers were
// stripped. AppendFuture(i) appears to be the pre-rename spelling of
// ReadAsync(i); confirm that only one of the two calls exists in the real file.
void BufferedReader::ReadTillBufferFullAsync() {
|
|
|
|
|
// position_ must be empty before the buffer is refilled.
PADDLE_ENFORCE_EQ(position_.size(), 0U);
|
|
|
|
|
for (size_t i = 0; i < buffer_size_; ++i) {
|
|
|
|
|
AppendFuture(i);
|
|
|
|
|
// ReadAsync(i) enqueues a task on thread_pool_ that reads into cpu_buffer_[i].
ReadAsync(i);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
void BufferedReader::AppendFuture(size_t i) {
|
|
|
|
|
void BufferedReader::ReadAsync(size_t i) {
|
|
|
|
|
position_.emplace(thread_pool_.enqueue([this, i]() -> size_t {
|
|
|
|
|
TensorVec &cpu = cpu_buffer_[i];
|
|
|
|
|
reader_->ReadNext(&cpu);
|
|
|
|
@ -50,6 +50,7 @@ void BufferedReader::AppendFuture(size_t i) {
|
|
|
|
|
gpu.resize(cpu.size());
|
|
|
|
|
for (size_t i = 0; i < cpu.size(); ++i) {
|
|
|
|
|
framework::TensorCopySync(cpu[i], place_, &gpu[i]);
|
|
|
|
|
gpu[i].set_lod(cpu[i].lod());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return i;
|
|
|
|
@ -60,10 +61,11 @@ void BufferedReader::ShutdownImpl() {
|
|
|
|
|
while (!position_.empty()) {
|
|
|
|
|
position_.pop();
|
|
|
|
|
}
|
|
|
|
|
prev_pos_ = -1UL;
|
|
|
|
|
}
|
|
|
|
|
// Starts the underlying reader_ and then schedules reads to fill the buffer.
// NOTE(review): this region reads like a unified diff whose +/- markers were
// stripped. AppendFutureToBatchSize() appears to be the pre-rename spelling of
// ReadTillBufferFullAsync(); confirm that only one of the two calls exists in
// the real file.
void BufferedReader::StartImpl() {
|
|
|
|
|
reader_->Start();
|
|
|
|
|
AppendFutureToBatchSize();
|
|
|
|
|
// ReadTillBufferFullAsync() enforces that position_ is empty before refilling.
ReadTillBufferFullAsync();
|
|
|
|
|
}
|
|
|
|
|
void BufferedReader::ReadNextImpl(std::vector<framework::LoDTensor> *out) {
|
|
|
|
|
if (position_.empty()) {
|
|
|
|
@ -79,7 +81,14 @@ void BufferedReader::ReadNextImpl(std::vector<framework::LoDTensor> *out) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
*out = platform::is_gpu_place(place_) ? gpu_buffer_[i] : cpu_buffer_[i];
|
|
|
|
|
AppendFuture(i);
|
|
|
|
|
|
|
|
|
|
// Do not pass the current position to ReadAsync; pass the previous position
|
|
|
|
|
// instead. Since all computation in fluid is async, changing the data at the
|
|
|
|
|
// current position may cause data errors.
|
|
|
|
|
if (prev_pos_ != -1Ul) {
|
|
|
|
|
ReadAsync(prev_pos_);
|
|
|
|
|
}
|
|
|
|
|
prev_pos_ = i;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace reader
|
|
|
|
|