|
|
|
@ -18,6 +18,7 @@
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
namespace reader {
|
|
|
|
|
template <bool ThreadSafe>
|
|
|
|
|
class RecordIOFileReader : public framework::FileReader {
|
|
|
|
|
public:
|
|
|
|
|
RecordIOFileReader(const std::string& filename,
|
|
|
|
@ -26,11 +27,19 @@ class RecordIOFileReader : public framework::FileReader {
|
|
|
|
|
scanner_(filename),
|
|
|
|
|
dev_ctx_(*platform::DeviceContextPool::Instance().Get(
|
|
|
|
|
platform::CPUPlace())) {
|
|
|
|
|
if (ThreadSafe) {
|
|
|
|
|
mutex_.reset(new std::mutex());
|
|
|
|
|
}
|
|
|
|
|
LOG(INFO) << "Creating file reader" << filename;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ReadNext(std::vector<framework::LoDTensor>* out) override {
|
|
|
|
|
*out = framework::ReadFromRecordIO(scanner_, dev_ctx_);
|
|
|
|
|
if (ThreadSafe) {
|
|
|
|
|
std::lock_guard<std::mutex> guard(*mutex_);
|
|
|
|
|
*out = framework::ReadFromRecordIO(scanner_, dev_ctx_);
|
|
|
|
|
} else {
|
|
|
|
|
*out = framework::ReadFromRecordIO(scanner_, dev_ctx_);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool HasNext() const override { return scanner_.HasNext(); }
|
|
|
|
@ -38,6 +47,7 @@ class RecordIOFileReader : public framework::FileReader {
|
|
|
|
|
void ReInit() override { scanner_.Reset(); }
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
std::unique_ptr<std::mutex> mutex_;
|
|
|
|
|
recordio::Scanner scanner_;
|
|
|
|
|
const platform::DeviceContext& dev_ctx_;
|
|
|
|
|
};
|
|
|
|
@ -61,7 +71,7 @@ class CreateRecordIOReaderOp : public framework::OperatorBase {
|
|
|
|
|
|
|
|
|
|
auto* out = scope.FindVar(Output("Out"))
|
|
|
|
|
->template GetMutable<framework::ReaderHolder>();
|
|
|
|
|
out->Reset(new RecordIOFileReader(filename, shapes));
|
|
|
|
|
out->Reset(new RecordIOFileReader<true>(filename, shapes));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|