|
|
|
@ -25,77 +25,107 @@
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace pybind {
|
|
|
|
|
|
|
|
|
|
class FeedReader {
|
|
|
|
|
class MultiDeviceFeedReader {
|
|
|
|
|
public:
|
|
|
|
|
using ResultDictList =
|
|
|
|
|
std::vector<std::unordered_map<std::string, framework::LoDTensor>>;
|
|
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
FeedReader(std::unique_ptr<framework::ReaderHolder> reader,
|
|
|
|
|
const std::vector<std::string> &names, size_t num_places,
|
|
|
|
|
bool drop_last = true)
|
|
|
|
|
: reader_(std::move(reader)),
|
|
|
|
|
MultiDeviceFeedReader(
|
|
|
|
|
const std::shared_ptr<operators::reader::LoDTensorBlockingQueue> &queue,
|
|
|
|
|
const std::vector<std::string> &names,
|
|
|
|
|
const std::vector<platform::Place> &dst_places, bool use_double_buffer)
|
|
|
|
|
: queue_(queue),
|
|
|
|
|
names_(names),
|
|
|
|
|
num_places_(num_places),
|
|
|
|
|
drop_last_(drop_last) {}
|
|
|
|
|
|
|
|
|
|
ResultDictList ReadNext() {
|
|
|
|
|
std::vector<framework::LoDTensor> tensors;
|
|
|
|
|
reader_->ReadNext(&tensors);
|
|
|
|
|
if (tensors.empty()) return ResultDictList();
|
|
|
|
|
pool_(new ::ThreadPool(dst_places.size())) {
|
|
|
|
|
std::shared_ptr<framework::ReaderBase> reader(
|
|
|
|
|
new operators::reader::PyReader(queue));
|
|
|
|
|
|
|
|
|
|
readers_.reserve(dst_places.size());
|
|
|
|
|
for (auto &p : dst_places) {
|
|
|
|
|
auto *holder = new framework::ReaderHolder();
|
|
|
|
|
if (use_double_buffer) {
|
|
|
|
|
holder->Reset(
|
|
|
|
|
framework::MakeDecoratedReader<operators::reader::BufferedReader>(
|
|
|
|
|
reader, p, 2));
|
|
|
|
|
} else {
|
|
|
|
|
if (platform::is_gpu_place(p)) {
|
|
|
|
|
PADDLE_THROW(
|
|
|
|
|
"Place cannot be CUDAPlace when use_double_buffer is False");
|
|
|
|
|
}
|
|
|
|
|
holder->Reset(reader);
|
|
|
|
|
}
|
|
|
|
|
readers_.emplace_back(holder);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(tensors.size() % names_.size() == 0,
|
|
|
|
|
"Tensor size: %d, names size: %d", tensors.size(),
|
|
|
|
|
names_.size());
|
|
|
|
|
futures_.resize(dst_places.size());
|
|
|
|
|
ret_.resize(dst_places.size());
|
|
|
|
|
ReadAsync();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t read_place_num = tensors.size() / names_.size();
|
|
|
|
|
ResultDictList ReadNext() {
|
|
|
|
|
bool success = WaitFutures();
|
|
|
|
|
|
|
|
|
|
if (drop_last_ && read_place_num != num_places_) {
|
|
|
|
|
return ResultDictList();
|
|
|
|
|
if (!success) {
|
|
|
|
|
return {};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ResultDictList ret(read_place_num);
|
|
|
|
|
for (size_t i = 0; i < tensors.size(); ++i) {
|
|
|
|
|
ret[i / names_.size()].emplace(names_[i % names_.size()],
|
|
|
|
|
std::move(tensors[i]));
|
|
|
|
|
ResultDictList result(ret_.size());
|
|
|
|
|
for (size_t i = 0; i < ret_.size(); ++i) {
|
|
|
|
|
for (size_t j = 0; j < names_.size(); ++j) {
|
|
|
|
|
result[i].emplace(names_[j], std::move(ret_[i][j]));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return ret;
|
|
|
|
|
ReadAsync();
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Start() { reader_->Start(); }
|
|
|
|
|
void Reset() {
|
|
|
|
|
Shutdown();
|
|
|
|
|
Start();
|
|
|
|
|
|
|
|
|
|
void Reset() { reader_->ResetAll(); }
|
|
|
|
|
ReadAsync();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Destructor: closes the feeding queue first, then destroys the ThreadPool
// owned by pool_.
// NOTE(review): presumably the queue has to be closed before the pool is
// released so that read tasks still blocked on the queue can return and the
// pool can be torn down -- confirm against the ThreadPool and queue
// implementations.
~MultiDeviceFeedReader() {
|
|
|
|
|
queue_->Close();
|
|
|
|
|
pool_.reset();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
std::unique_ptr<framework::ReaderHolder> reader_;
|
|
|
|
|
std::vector<std::string> names_;
|
|
|
|
|
size_t num_places_;
|
|
|
|
|
bool drop_last_;
|
|
|
|
|
};
|
|
|
|
|
bool WaitFutures() {
|
|
|
|
|
bool success = true;
|
|
|
|
|
for (auto &f : futures_) {
|
|
|
|
|
success &= f.get();
|
|
|
|
|
}
|
|
|
|
|
return success;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::unique_ptr<framework::ReaderHolder> CreatePyReader(
|
|
|
|
|
const std::vector<
|
|
|
|
|
std::shared_ptr<operators::reader::LoDTensorBlockingQueue>> &queues,
|
|
|
|
|
const std::vector<platform::Place> &dst_places) {
|
|
|
|
|
std::shared_ptr<framework::ReaderBase> reader;
|
|
|
|
|
if (queues.size() == 1) {
|
|
|
|
|
reader.reset(new operators::reader::PyReader(queues[0]));
|
|
|
|
|
} else {
|
|
|
|
|
reader.reset(new operators::reader::MultiQueuePyReader(queues));
|
|
|
|
|
void Shutdown() {
|
|
|
|
|
for (auto &r : readers_) r->Shutdown();
|
|
|
|
|
}
|
|
|
|
|
std::vector<std::shared_ptr<framework::ReaderBase>> buffered_reader;
|
|
|
|
|
buffered_reader.reserve(dst_places.size());
|
|
|
|
|
for (auto &p : dst_places) {
|
|
|
|
|
buffered_reader.emplace_back(
|
|
|
|
|
framework::MakeDecoratedReader<operators::reader::BufferedReader>(
|
|
|
|
|
reader, p, 2));
|
|
|
|
|
|
|
|
|
|
void Start() {
|
|
|
|
|
for (auto &r : readers_) r->Start();
|
|
|
|
|
}
|
|
|
|
|
reader = framework::MakeDecoratedReader<operators::reader::ComposeReader>(
|
|
|
|
|
buffered_reader);
|
|
|
|
|
|
|
|
|
|
auto *holder = new framework::ReaderHolder();
|
|
|
|
|
holder->Reset(reader);
|
|
|
|
|
return std::unique_ptr<framework::ReaderHolder>(holder);
|
|
|
|
|
}
|
|
|
|
|
void ReadAsync() {
|
|
|
|
|
for (size_t i = 0; i < readers_.size(); ++i) {
|
|
|
|
|
futures_[i] = pool_->enqueue([this, i] {
|
|
|
|
|
readers_[i]->ReadNext(&ret_[i]);
|
|
|
|
|
return !ret_[i].empty();
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<std::string> names_;
|
|
|
|
|
std::unique_ptr<::ThreadPool> pool_;
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<operators::reader::LoDTensorBlockingQueue> queue_;
|
|
|
|
|
std::vector<std::unique_ptr<framework::ReaderHolder>> readers_;
|
|
|
|
|
std::vector<std::future<bool>> futures_;
|
|
|
|
|
std::vector<std::vector<framework::LoDTensor>> ret_;
|
|
|
|
|
bool drop_last_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
namespace py = pybind11;
|
|
|
|
|
|
|
|
|
@ -108,22 +138,20 @@ void BindReader(py::module *module) {
|
|
|
|
|
.def("start", &framework::ReaderHolder::Start)
|
|
|
|
|
.def("reset", &framework::ReaderHolder::ResetAll);
|
|
|
|
|
|
|
|
|
|
py::class_<FeedReader>(m, "FeedReader", "")
|
|
|
|
|
.def("read_next", &FeedReader::ReadNext,
|
|
|
|
|
py::call_guard<py::gil_scoped_release>())
|
|
|
|
|
.def("start", &FeedReader::Start,
|
|
|
|
|
py::class_<MultiDeviceFeedReader>(m, "MultiDeviceFeedReader", "")
|
|
|
|
|
.def("read_next", &MultiDeviceFeedReader::ReadNext,
|
|
|
|
|
py::call_guard<py::gil_scoped_release>())
|
|
|
|
|
.def("reset", &FeedReader::Reset,
|
|
|
|
|
.def("reset", &MultiDeviceFeedReader::Reset,
|
|
|
|
|
py::call_guard<py::gil_scoped_release>());
|
|
|
|
|
|
|
|
|
|
m.def("create_py_reader",
|
|
|
|
|
[](const std::vector<
|
|
|
|
|
std::shared_ptr<operators::reader::LoDTensorBlockingQueue>>
|
|
|
|
|
queues,
|
|
|
|
|
[](const std::shared_ptr<operators::reader::LoDTensorBlockingQueue>
|
|
|
|
|
&queue,
|
|
|
|
|
const std::vector<std::string> &names,
|
|
|
|
|
const std::vector<platform::Place> &dst_places, bool drop_last) {
|
|
|
|
|
return new FeedReader(CreatePyReader(queues, dst_places), names,
|
|
|
|
|
dst_places.size(), drop_last);
|
|
|
|
|
const std::vector<platform::Place> &dst_places,
|
|
|
|
|
bool use_double_buffer) {
|
|
|
|
|
return new MultiDeviceFeedReader(queues, names, dst_places,
|
|
|
|
|
use_double_buffer);
|
|
|
|
|
},
|
|
|
|
|
py::return_value_policy::take_ownership);
|
|
|
|
|
}
|
|
|
|
|