|
|
|
@ -21,6 +21,22 @@ namespace reader {
|
|
|
|
|
|
|
|
|
|
class MultipleReader : public framework::ReaderBase {
|
|
|
|
|
public:
|
|
|
|
|
class ThreadBufferMap {
|
|
|
|
|
public:
|
|
|
|
|
std::vector<framework::LoDTensor>& operator[](
|
|
|
|
|
const std::thread::id& thread_id) {
|
|
|
|
|
std::lock_guard<std::mutex> lock(mutex_);
|
|
|
|
|
return buffer_[thread_id];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Clear() { buffer_.clear(); }
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
std::mutex mutex_;
|
|
|
|
|
std::unordered_map<std::thread::id, std::vector<framework::LoDTensor>>
|
|
|
|
|
buffer_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
MultipleReader(const std::vector<std::string>& file_names,
|
|
|
|
|
const std::vector<framework::DDim>& dims, size_t thread_num)
|
|
|
|
|
: file_names_(file_names), dims_(dims) {
|
|
|
|
@ -47,28 +63,27 @@ class MultipleReader : public framework::ReaderBase {
|
|
|
|
|
framework::Channel<size_t>* waiting_file_idx_;
|
|
|
|
|
framework::Channel<size_t>* available_thread_idx_;
|
|
|
|
|
framework::Channel<std::vector<framework::LoDTensor>>* buffer_;
|
|
|
|
|
mutable std::vector<framework::LoDTensor> local_buffer_;
|
|
|
|
|
mutable ThreadBufferMap thread_buffer_map_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
void MultipleReader::ReadNext(std::vector<framework::LoDTensor>* out) {
|
|
|
|
|
if (!HasNext()) {
|
|
|
|
|
PADDLE_THROW("There is no next data!");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (local_buffer_.empty()) {
|
|
|
|
|
buffer_->Receive(&local_buffer_);
|
|
|
|
|
}
|
|
|
|
|
*out = local_buffer_;
|
|
|
|
|
local_buffer_.clear();
|
|
|
|
|
auto& thread_local_buffer = thread_buffer_map_[std::this_thread::get_id()];
|
|
|
|
|
*out = thread_local_buffer;
|
|
|
|
|
thread_local_buffer.clear();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool MultipleReader::HasNext() const {
|
|
|
|
|
return local_buffer_.empty() ? buffer_->Receive(&local_buffer_) : true;
|
|
|
|
|
auto& thread_local_buffer = thread_buffer_map_[std::this_thread::get_id()];
|
|
|
|
|
return thread_local_buffer.empty() ? buffer_->Receive(&thread_local_buffer)
|
|
|
|
|
: true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MultipleReader::ReInit() {
|
|
|
|
|
EndScheduler();
|
|
|
|
|
local_buffer_.clear();
|
|
|
|
|
thread_buffer_map_.Clear();
|
|
|
|
|
StartNewScheduler();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -176,7 +191,7 @@ class OpenFilesOp : public framework::OperatorBase {
|
|
|
|
|
const auto& ranks = Attr<std::vector<int>>("ranks");
|
|
|
|
|
PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty());
|
|
|
|
|
PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0),
|
|
|
|
|
int(shape_concat.size()),
|
|
|
|
|
static_cast<int>(shape_concat.size()),
|
|
|
|
|
"The accumulate of all ranks should be equal to the "
|
|
|
|
|
"shape concat's length.");
|
|
|
|
|
const auto& file_names = Attr<std::vector<std::string>>("file_names");
|
|
|
|
|