|
|
|
@ -20,6 +20,7 @@
|
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/ddim.h"
|
|
|
|
|
#include "paddle/fluid/framework/framework.pb.h"
|
|
|
|
|
#include "paddle/fluid/framework/lod_tensor_array.h"
|
|
|
|
|
#include "paddle/fluid/platform/place.h"
|
|
|
|
|
|
|
|
|
@ -28,6 +29,20 @@ namespace framework {
|
|
|
|
|
|
|
|
|
|
class ReaderBase {
|
|
|
|
|
public:
|
|
|
|
|
explicit ReaderBase(const std::vector<DDim>& shapes,
|
|
|
|
|
const std::vector<proto::VarType::Type>& var_types,
|
|
|
|
|
const std::vector<bool>& need_check_feed)
|
|
|
|
|
: shapes_(shapes),
|
|
|
|
|
var_types_(var_types),
|
|
|
|
|
need_check_feed_(need_check_feed) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(shapes_.size(), need_check_feed_.size(),
|
|
|
|
|
"Construct ReaderBase with mismatched sizes of shapes "
|
|
|
|
|
"and need_check_feed");
|
|
|
|
|
PADDLE_ENFORCE_EQ(var_types_.size(), need_check_feed_.size(),
|
|
|
|
|
"Construct ReaderBase with mismatched sizes of var_types "
|
|
|
|
|
"and need_check_feed");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
virtual void ReadNext(std::vector<LoDTensor>* out);
|
|
|
|
|
|
|
|
|
|
virtual void Shutdown();
|
|
|
|
@ -38,6 +53,18 @@ class ReaderBase {
|
|
|
|
|
// they are readers just before read op.
|
|
|
|
|
std::unordered_set<ReaderBase*> GetEndPoints();
|
|
|
|
|
|
|
|
|
|
// Returns the shapes of the feeded variables
|
|
|
|
|
const std::vector<DDim>& Shapes() const { return shapes_; }
|
|
|
|
|
|
|
|
|
|
// Returns the dtypes of the feeded variables
|
|
|
|
|
const std::vector<proto::VarType::Type>& VarTypes() const {
|
|
|
|
|
return var_types_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// For Backward compatibility, old fluid.layers.data doesn't check shape.
|
|
|
|
|
// This function returns whether you have the check shape for this Reader.
|
|
|
|
|
const std::vector<bool>& NeedCheckFeed() const { return need_check_feed_; }
|
|
|
|
|
|
|
|
|
|
virtual ~ReaderBase();
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
@ -53,6 +80,17 @@ class ReaderBase {
|
|
|
|
|
|
|
|
|
|
mutable std::mutex mu_;
|
|
|
|
|
|
|
|
|
|
// The shapes of the feeded variables.
|
|
|
|
|
std::vector<DDim> shapes_;
|
|
|
|
|
|
|
|
|
|
// The dtypes of the feeded variables.
|
|
|
|
|
std::vector<proto::VarType::Type> var_types_;
|
|
|
|
|
|
|
|
|
|
// Whether to check the shape and dtype of feeded variables.
|
|
|
|
|
// For Backward compatibility, variables created by old API fluid.layers.data
|
|
|
|
|
// doesn't check shape but fluid.data checks.
|
|
|
|
|
std::vector<bool> need_check_feed_;
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
friend class DecoratedReader;
|
|
|
|
|
// These methods can be only invoked inside DecoratedReader to record the
|
|
|
|
@ -67,7 +105,9 @@ class DecoratedReader : public ReaderBase,
|
|
|
|
|
public std::enable_shared_from_this<DecoratedReader> {
|
|
|
|
|
public:
|
|
|
|
|
explicit DecoratedReader(const std::shared_ptr<ReaderBase>& reader)
|
|
|
|
|
: ReaderBase(), reader_(reader) {
|
|
|
|
|
: ReaderBase(reader->Shapes(), reader->VarTypes(),
|
|
|
|
|
reader->NeedCheckFeed()),
|
|
|
|
|
reader_(reader) {
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(reader_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -89,7 +129,13 @@ class DecoratedReader : public ReaderBase,
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// FileReader is just a conceptual class.
|
|
|
|
|
class FileReader : public ReaderBase {};
|
|
|
|
|
class FileReader : public ReaderBase {
|
|
|
|
|
public:
|
|
|
|
|
explicit FileReader(const std::vector<DDim>& shapes,
|
|
|
|
|
const std::vector<proto::VarType::Type>& var_types,
|
|
|
|
|
const std::vector<bool>& need_check_feed)
|
|
|
|
|
: ReaderBase(shapes, var_types, need_check_feed) {}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// The ReaderHolder is used as reader' unified wrapper,
|
|
|
|
|
// making it easier to access different type reader in Variables.
|
|
|
|
@ -134,6 +180,16 @@ class ReaderHolder {
|
|
|
|
|
reader_->Start();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const std::vector<DDim>& Shapes() const { return reader_->Shapes(); }
|
|
|
|
|
|
|
|
|
|
const std::vector<proto::VarType::Type>& VarTypes() const {
|
|
|
|
|
return reader_->VarTypes();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const std::vector<bool>& NeedCheckFeed() const {
|
|
|
|
|
return reader_->NeedCheckFeed();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
operator const std::shared_ptr<ReaderBase>&() const { return this->reader_; }
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|