@@ -24,15 +24,31 @@ static constexpr size_t kDoubleBufferSize = 2;
 
 class DoubleBufferReader : public framework::DecoratedReader {
  public:
+  struct Item {
+    Item() : ctx_(nullptr) {}
+
+    std::vector<framework::LoDTensor> payloads_;
+    platform::DeviceContext* ctx_;
+  };
+
   explicit DoubleBufferReader(
       ReaderBase* reader, platform::Place target_place = platform::CPUPlace())
       : DecoratedReader(reader), place_(target_place) {
+    for (size_t i = 0; i < kDoubleBufferSize; ++i) {
+      if (platform::is_gpu_place(place_)) {
+#ifdef PADDLE_WITH_CUDA
+        ctxs_.emplace_back(new platform::CUDADeviceContext(
+            boost::get<platform::CUDAPlace>(place_)));
+#else
+#endif
+      }
+    }
+
     start_thread();
   }
 
   void start_thread() {
-    buffer_ = framework::MakeChannel<std::vector<framework::LoDTensor>>(
-        kDoubleBufferSize);
+    buffer_ = framework::MakeChannel<Item>(kDoubleBufferSize);
     std::thread prefetch([this] { PrefetchThreadFunc(); });
     prefetch.detach();
   }
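
For readers skimming the diff: the new Item struct bundles a prefetched batch with the device context that issued its (possibly asynchronous) copy, so the consuming side can synchronize before touching the data. The following self-contained sketch shows the same shape in plain C++; SketchItem, BoundedChannel, and the use of std::future<void> in place of platform::DeviceContext* are illustrative stand-ins, not Paddle APIs.

// Illustrative sketch only -- not part of the patch.
#include <algorithm>
#include <condition_variable>
#include <cstddef>
#include <deque>
#include <future>
#include <iostream>
#include <memory>
#include <mutex>
#include <numeric>
#include <thread>
#include <vector>

struct SketchItem {                       // ~ DoubleBufferReader::Item
  std::vector<float> payload;             // ~ payloads_
  std::future<void> done;                 // ~ ctx_ (wait() ~ ctx_->Wait())
};

class BoundedChannel {                    // ~ framework::Channel<Item>
 public:
  explicit BoundedChannel(size_t cap) : cap_(cap) {}

  void Send(SketchItem item) {
    std::unique_lock<std::mutex> lk(mu_);
    not_full_.wait(lk, [this] { return q_.size() < cap_; });
    q_.push_back(std::move(item));
    not_empty_.notify_one();
  }

  SketchItem Receive() {
    std::unique_lock<std::mutex> lk(mu_);
    not_empty_.wait(lk, [this] { return !q_.empty(); });
    SketchItem item = std::move(q_.front());
    q_.pop_front();
    not_full_.notify_one();
    return item;
  }

 private:
  const size_t cap_;
  std::mutex mu_;
  std::condition_variable not_full_, not_empty_;
  std::deque<SketchItem> q_;
};

int main() {
  constexpr size_t kBatches = 4;
  BoundedChannel buffer(2);               // capacity ~ kDoubleBufferSize

  // Prefetch thread: read a batch, start its "copy" asynchronously, and ship
  // the payload together with the completion handle (~ PrefetchThreadFunc).
  std::thread prefetch([&buffer] {
    for (size_t b = 0; b < kBatches; ++b) {
      auto src = std::make_shared<std::vector<float>>(8, static_cast<float>(b));
      SketchItem item;
      item.payload.resize(src->size());
      float* dst = item.payload.data();   // heap buffer stays valid across moves
      item.done = std::async(std::launch::async, [src, dst] {
        std::copy(src->begin(), src->end(), dst);  // ~ TensorCopy(..., *gpu_ctx, ...)
      });
      buffer.Send(std::move(item));
    }
  });

  // Consumer: receive, wait for the copy to finish, then use the data
  // (~ ReadNext(): buffer_->Receive(...), then ctx_->Wait()).
  for (size_t b = 0; b < kBatches; ++b) {
    SketchItem item = buffer.Receive();
    if (item.done.valid()) item.done.wait();
    float sum = std::accumulate(item.payload.begin(), item.payload.end(), 0.0f);
    std::cout << "batch " << b << " sum = " << sum << "\n";
  }
  prefetch.join();
  return 0;
}

The capacity of two plays the role of kDoubleBufferSize: the prefetch thread can run at most that many batches ahead of the consumer.
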
@@ -47,9 +63,10 @@ class DoubleBufferReader : public framework::DecoratedReader {
  private:
   void PrefetchThreadFunc();
 
-  framework::Channel<std::vector<framework::LoDTensor>>* buffer_;
+  framework::Channel<Item>* buffer_;
   platform::Place place_;
-  mutable std::vector<framework::LoDTensor> local_buffer_;
+  std::vector<std::unique_ptr<platform::DeviceContext>> ctxs_;
+  mutable Item local_buffer_;
 };
 
 class CreateDoubleBufferReaderOp : public framework::OperatorBase {
@@ -104,12 +121,14 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
 };
 
 void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
-  out->clear();
-  if (local_buffer_.empty()) {
-    buffer_->Receive(out);
-  } else {
-    *out = local_buffer_;
-    local_buffer_.clear();
+  if (local_buffer_.payloads_.empty()) {
+    buffer_->Receive(&local_buffer_);
+  }
+
+  *out = local_buffer_.payloads_;
+  local_buffer_.payloads_.clear();
+  if (local_buffer_.ctx_) {
+    local_buffer_.ctx_->Wait();
   }
 }
@@ -121,16 +140,22 @@ void DoubleBufferReader::ReInit() {
 
 void DoubleBufferReader::PrefetchThreadFunc() {
   VLOG(5) << "A new prefetch thread starts.";
+  size_t gpu_ctx_offset = 0;
   while (reader_->HasNext()) {
-    std::vector<framework::LoDTensor> batch;
-    reader_->ReadNext(&batch);
+    Item batch;
+    reader_->ReadNext(&batch.payloads_);
     if (platform::is_gpu_place(place_)) {
       std::vector<framework::LoDTensor> gpu_batch;
-      gpu_batch.resize(batch.size());
-      for (size_t i = 0; i < batch.size(); ++i) {
-        framework::TensorCopy(batch[i], place_, &gpu_batch[i]);
-        gpu_batch[i].set_lod(batch[i].lod());
+      auto& gpu_ctx = this->ctxs_[gpu_ctx_offset++];
+      gpu_ctx_offset %= this->ctxs_.size();
+      gpu_batch.resize(batch.payloads_.size());
+      for (size_t i = 0; i < batch.payloads_.size(); ++i) {
+        framework::TensorCopy(batch.payloads_[i], place_, *gpu_ctx,
+                              &gpu_batch[i]);
+        gpu_batch[i].set_lod(batch.payloads_[i].lod());
       }
+      batch.ctx_ = gpu_ctx.get();
+      std::swap(gpu_batch, batch.payloads_);
     }
 
     if (!buffer_->Send(&batch)) {
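
One detail in the hunk above is the context rotation: each batch takes the next context from the pool created in the constructor (ctxs_[gpu_ctx_offset++], then gpu_ctx_offset %= ctxs_.size()) and records it in the item, so successive buffer slots issue their copies through different contexts and ReadNext() knows which one to Wait() on. A tiny stand-alone illustration of the rotation, with Context and NextContext as placeholders rather than Paddle types:

#include <cstddef>
#include <iostream>
#include <vector>

// Placeholder for a per-slot device context (e.g. one stream per slot).
struct Context { int id; };

// Rotate through a fixed pool, as PrefetchThreadFunc() does with
// ctxs_[gpu_ctx_offset++] followed by gpu_ctx_offset %= ctxs_.size().
Context* NextContext(std::vector<Context>& pool, size_t* offset) {
  Context* ctx = &pool[(*offset)++];
  *offset %= pool.size();
  return ctx;
}

int main() {
  std::vector<Context> pool = {{0}, {1}};  // ~ ctxs_ with kDoubleBufferSize == 2
  size_t offset = 0;
  for (int batch = 0; batch < 5; ++batch) {
    std::cout << "batch " << batch << " -> context "
              << NextContext(pool, &offset)->id << "\n";
  }
  return 0;
}
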
@@ -143,7 +168,7 @@ void DoubleBufferReader::PrefetchThreadFunc() {
 }
 
 bool DoubleBufferReader::HasNext() const {
-  if (local_buffer_.empty()) {
+  if (local_buffer_.payloads_.empty()) {
     bool ok = buffer_->Receive(&local_buffer_);
     return ok;
   } else {