|
|
|
@ -22,34 +22,18 @@ namespace framework {
|
|
|
|
|
|
|
|
|
|
class ReaderBase {
|
|
|
|
|
public:
|
|
|
|
|
explicit ReaderBase(const std::vector<DDim>& shapes) : shapes_(shapes) {
|
|
|
|
|
PADDLE_ENFORCE(!shapes_.empty());
|
|
|
|
|
}
|
|
|
|
|
virtual void ReadNext(std::vector<LoDTensor>* out) = 0;
|
|
|
|
|
|
|
|
|
|
virtual void ReInit() = 0;
|
|
|
|
|
|
|
|
|
|
DDim shape(size_t idx) const;
|
|
|
|
|
std::vector<DDim> shapes() const { return shapes_; }
|
|
|
|
|
void set_shapes(const std::vector<DDim>& shapes) { shapes_ = shapes; }
|
|
|
|
|
|
|
|
|
|
virtual bool HasNext() const = 0;
|
|
|
|
|
|
|
|
|
|
virtual ~ReaderBase() {}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
std::vector<DDim> shapes_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class FileReader : public ReaderBase {
|
|
|
|
|
public:
|
|
|
|
|
explicit FileReader(const std::vector<DDim>& shapes) : ReaderBase(shapes) {}
|
|
|
|
|
virtual ~ReaderBase();
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class DecoratedReader : public ReaderBase {
|
|
|
|
|
public:
|
|
|
|
|
explicit DecoratedReader(ReaderBase* reader)
|
|
|
|
|
: ReaderBase(reader->shapes()), reader_(reader) {
|
|
|
|
|
explicit DecoratedReader(ReaderBase* reader) : ReaderBase(), reader_(reader) {
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(reader_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -72,12 +56,6 @@ class ReaderHolder {
|
|
|
|
|
void ReadNext(std::vector<LoDTensor>* out) { reader_->ReadNext(out); }
|
|
|
|
|
void ReInit() { reader_->ReInit(); }
|
|
|
|
|
|
|
|
|
|
DDim shape(size_t idx) const { return reader_->shape(idx); }
|
|
|
|
|
std::vector<DDim> shapes() const { return reader_->shapes(); }
|
|
|
|
|
void set_shapes(const std::vector<DDim>& shapes) {
|
|
|
|
|
reader_->set_shapes(shapes);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool HasNext() const { return reader_->HasNext(); }
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|