|
|
|
@ -20,32 +20,48 @@
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
|
|
|
|
|
|
class Reader {
|
|
|
|
|
class ReaderBase {
|
|
|
|
|
public:
|
|
|
|
|
Reader() {}
|
|
|
|
|
explicit Reader(const std::vector<DDim>& shapes) : shapes_(shapes) {}
|
|
|
|
|
|
|
|
|
|
virtual std::vector<LoDTensor> ReadNext() = 0;
|
|
|
|
|
virtual bool HasNext() const = 0;
|
|
|
|
|
|
|
|
|
|
virtual DDim shape(size_t idx) const;
|
|
|
|
|
virtual std::vector<DDim> shapes() const { return shapes_; }
|
|
|
|
|
virtual DDim shape(size_t idx) const = 0;
|
|
|
|
|
virtual std::vector<DDim> shapes() const = 0;
|
|
|
|
|
|
|
|
|
|
virtual ~Reader() {}
|
|
|
|
|
virtual ~ReaderBase() {}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
// set private to prevent directly access in decorators
|
|
|
|
|
// a decorator should access its underlying reader_'s shape, not its own.
|
|
|
|
|
class FileReader : public ReaderBase {
|
|
|
|
|
public:
|
|
|
|
|
explicit FileReader(const std::vector<DDim>& shapes) : shapes_(shapes) {}
|
|
|
|
|
|
|
|
|
|
DDim shape(size_t idx) const override;
|
|
|
|
|
std::vector<DDim> shapes() const override { return shapes_; }
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
std::vector<DDim> shapes_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class ReaderDecorator : public ReaderBase {
|
|
|
|
|
public:
|
|
|
|
|
explicit ReaderDecorator(ReaderBase* reader) : reader_(reader) {}
|
|
|
|
|
|
|
|
|
|
bool HasNext() const override { return reader_->HasNext(); }
|
|
|
|
|
|
|
|
|
|
DDim shape(size_t idx) const override { return reader_->shape(idx); }
|
|
|
|
|
std::vector<DDim> shapes() const override { return reader_->shapes(); }
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
ReaderBase* reader_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// file readers
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class RandomReader : public Reader {
|
|
|
|
|
class RandomReader : public FileReader {
|
|
|
|
|
public:
|
|
|
|
|
RandomReader(const std::vector<DDim>& shapes, float min, float max)
|
|
|
|
|
: Reader(shapes), min_(min), max_(max) {
|
|
|
|
|
: FileReader(shapes), min_(min), max_(max) {
|
|
|
|
|
PADDLE_ENFORCE_LE(min, max,
|
|
|
|
|
"'min' should be less than or equal to 'max'.(%f vs %f)",
|
|
|
|
|
min, max);
|
|
|
|
@ -58,8 +74,8 @@ class RandomReader : public Reader {
|
|
|
|
|
std::uniform_real_distribution<float> dist(min_, max_);
|
|
|
|
|
|
|
|
|
|
std::vector<LoDTensor> res;
|
|
|
|
|
res.reserve(shapes().size());
|
|
|
|
|
for (const DDim& shape : shapes()) {
|
|
|
|
|
res.reserve(shapes_.size());
|
|
|
|
|
for (const DDim& shape : shapes_) {
|
|
|
|
|
PADDLE_ENFORCE_GE(
|
|
|
|
|
shape.size(), 2,
|
|
|
|
|
"The rank of input data should be 2 at least.(Now it's %d)",
|
|
|
|
@ -85,37 +101,27 @@ class RandomReader : public Reader {
|
|
|
|
|
|
|
|
|
|
// decorators
|
|
|
|
|
|
|
|
|
|
class ShuffleReader : public Reader {
|
|
|
|
|
class ShuffleReader : public ReaderDecorator {
|
|
|
|
|
public:
|
|
|
|
|
ShuffleReader(Reader* reader, int buffer_size)
|
|
|
|
|
: reader_(reader), buffer_size_(buffer_size), iteration_pos_(0) {
|
|
|
|
|
ShuffleReader(ReaderBase* reader, int buffer_size)
|
|
|
|
|
: ReaderDecorator(reader), buffer_size_(buffer_size), iteration_pos_(0) {
|
|
|
|
|
buffer_.reserve(buffer_size);
|
|
|
|
|
}
|
|
|
|
|
std::vector<LoDTensor> ReadNext() override;
|
|
|
|
|
bool HasNext() const override { return reader_->HasNext(); }
|
|
|
|
|
|
|
|
|
|
DDim shape(size_t idx) const override { return reader_->shape(idx); }
|
|
|
|
|
std::vector<DDim> shapes() const override { return reader_->shapes(); }
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
Reader* reader_;
|
|
|
|
|
int buffer_size_;
|
|
|
|
|
std::vector<std::vector<LoDTensor>> buffer_;
|
|
|
|
|
size_t iteration_pos_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class BatchReader : public Reader {
|
|
|
|
|
class BatchReader : public ReaderDecorator {
|
|
|
|
|
public:
|
|
|
|
|
BatchReader(Reader* reader, int batch_size)
|
|
|
|
|
: reader_(reader), batch_size_(batch_size) {}
|
|
|
|
|
BatchReader(ReaderBase* reader, int batch_size)
|
|
|
|
|
: ReaderDecorator(reader), batch_size_(batch_size) {}
|
|
|
|
|
std::vector<LoDTensor> ReadNext() override;
|
|
|
|
|
bool HasNext() const override { return reader_->HasNext(); };
|
|
|
|
|
|
|
|
|
|
DDim shape(size_t idx) const override { return reader_->shape(idx); }
|
|
|
|
|
std::vector<DDim> shapes() const override { return reader_->shapes(); }
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
Reader* reader_;
|
|
|
|
|
int batch_size_;
|
|
|
|
|
std::vector<std::vector<LoDTensor>> buffer_;
|
|
|
|
|
};
|
|
|
|
|