|
|
|
@ -23,7 +23,7 @@ class BatchReader : public framework::DecoratedReader {
|
|
|
|
|
BatchReader(const std::shared_ptr<ReaderBase>& reader, int batch_size,
|
|
|
|
|
bool discard_leftover)
|
|
|
|
|
: DecoratedReader(reader),
|
|
|
|
|
batch_size_(batch_size),
|
|
|
|
|
batch_size_(static_cast<size_t>(batch_size)),
|
|
|
|
|
discard_leftover_(discard_leftover) {
|
|
|
|
|
buffer_.reserve(batch_size_);
|
|
|
|
|
}
|
|
|
|
@ -31,7 +31,7 @@ class BatchReader : public framework::DecoratedReader {
|
|
|
|
|
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override;
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
int batch_size_;
|
|
|
|
|
size_t batch_size_;
|
|
|
|
|
bool discard_leftover_;
|
|
|
|
|
std::vector<std::vector<framework::LoDTensor>> buffer_;
|
|
|
|
|
};
|
|
|
|
@ -78,7 +78,7 @@ class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase {
|
|
|
|
|
void BatchReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
|
|
|
|
|
buffer_.clear();
|
|
|
|
|
buffer_.reserve(batch_size_);
|
|
|
|
|
for (int i = 0; i < batch_size_; ++i) {
|
|
|
|
|
for (size_t i = 0; i < batch_size_; ++i) {
|
|
|
|
|
buffer_.push_back(std::vector<framework::LoDTensor>());
|
|
|
|
|
reader_->ReadNext(&buffer_.back());
|
|
|
|
|
if (buffer_.back().empty()) {
|
|
|
|
@ -95,9 +95,9 @@ void BatchReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
|
|
|
|
|
// if buffer_ is empty, the 'out' will return as an empty vector.
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
int out_num = buffer_[0].size();
|
|
|
|
|
size_t out_num = buffer_[0].size();
|
|
|
|
|
out->reserve(out_num);
|
|
|
|
|
for (int j = 0; j < out_num; ++j) {
|
|
|
|
|
for (size_t j = 0; j < out_num; ++j) {
|
|
|
|
|
// Merge shape and check date type
|
|
|
|
|
std::type_index batch_type = buffer_[0][j].type();
|
|
|
|
|
framework::DDim batch_shape = buffer_[0][j].dims();
|
|
|
|
|