Merge pull request #8830 from reyoung/feature/recordio_file_reader
Feature/recordio file reader
shanyi15-patch-2
commit
e13aec601a
@ -1,6 +1,24 @@
|
||||
# Shared registry library; every reader operator below lists it in DEPS.
# NOTE(review): the hunk header reads "-1,6 +1,24", so these six statements
# look like the pre-change side of the diff that the reader_library() based
# section further down replaces -- confirm before treating both as live.
cc_library(reader_op_registry SRCS reader_op_registry.cc DEPS operator op_registry reader)
|
||||
# One op_library() per reader operator, each built from the source file that
# shares the target's name.
op_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc DEPS reader_op_registry)
|
||||
op_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc DEPS reader_op_registry)
|
||||
op_library(create_batch_reader_op SRCS create_batch_reader_op.cc DEPS reader_op_registry)
|
||||
op_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc DEPS reader_op_registry)
|
||||
# Publish the hand-maintained list of reader operator targets to the parent
# directory scope.
set(READER_LIBRARY create_random_data_generator_op create_shuffle_reader_op create_batch_reader_op create_double_buffer_reader_op PARENT_SCOPE)
|
||||
# Start with no recorded reader targets; reader_library() prepends to this.
set(LOCAL_READER_LIBS)

# reader_library(<target> SRCS <file>... [DEPS <lib>...])
#
# Declares a reader operator through op_library(). reader_op_registry is
# always added in front of the caller-supplied DEPS, and <target> is
# prepended to LOCAL_READER_LIBS in the calling scope.
function(reader_library TARGET_NAME)
  # No option flags and no single-value keywords; SRCS and DEPS take lists.
  cmake_parse_arguments(reader_library "" "" "SRCS;DEPS" ${ARGN})

  op_library(${TARGET_NAME}
    SRCS ${reader_library_SRCS}
    DEPS reader_op_registry ${reader_library_DEPS})

  set(LOCAL_READER_LIBS ${TARGET_NAME} ${LOCAL_READER_LIBS} PARENT_SCOPE)
endfunction()
|
||||
|
||||
# Declare each reader operator; every target is built from the single source
# file that carries the same name. Order matches the original declarations.
foreach(reader_op
    create_random_data_generator_op
    create_shuffle_reader_op
    create_batch_reader_op
    create_recordio_file_reader_op
    create_double_buffer_reader_op)
  reader_library(${reader_op} SRCS ${reader_op}.cc)
endforeach()

# Export local libraries to parent
set(READER_LIBRARY ${LOCAL_READER_LIBS} PARENT_SCOPE)
|
||||
|
@ -0,0 +1,87 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/operators/reader/reader_op_registry.h"
|
||||
#include "paddle/fluid/recordio/scanner.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace reader {
|
||||
class RecordIOFileReader : public framework::FileReader {
|
||||
public:
|
||||
RecordIOFileReader(const std::string& filename,
|
||||
const std::vector<framework::DDim>& shapes)
|
||||
: FileReader(shapes),
|
||||
scanner_(filename),
|
||||
dev_ctx_(*platform::DeviceContextPool::Instance().Get(
|
||||
platform::CPUPlace())) {}
|
||||
|
||||
void ReadNext(std::vector<framework::LoDTensor>* out) override {
|
||||
*out = framework::ReadFromRecordIO(scanner_, dev_ctx_);
|
||||
}
|
||||
|
||||
bool HasNext() const override { return scanner_.HasNext(); }
|
||||
|
||||
void ReInit() override { scanner_.Reset(); }
|
||||
|
||||
private:
|
||||
recordio::Scanner scanner_;
|
||||
const platform::DeviceContext& dev_ctx_;
|
||||
};
|
||||
|
||||
class CreateRecordIOReaderOp : public framework::OperatorBase {
|
||||
public:
|
||||
using framework::OperatorBase::OperatorBase;
|
||||
|
||||
private:
|
||||
void RunImpl(const framework::Scope& scope,
|
||||
const platform::Place& dev_place) const override {
|
||||
const auto& shape_concat = Attr<std::vector<int>>("shape_concat");
|
||||
const auto& ranks = Attr<std::vector<int>>("ranks");
|
||||
PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty());
|
||||
PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0),
|
||||
int(shape_concat.size()),
|
||||
"The accumulate of all ranks should be equal to the "
|
||||
"shape concat's length.");
|
||||
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
|
||||
std::string filename = Attr<std::string>("filename");
|
||||
|
||||
auto* out = scope.FindVar(Output("Out"))
|
||||
->template GetMutable<framework::ReaderHolder>();
|
||||
out->Reset(new RecordIOFileReader(filename, shapes));
|
||||
}
|
||||
};
|
||||
|
||||
// Proto maker for the RecordIO file reader operator. On top of whatever
// FileReaderMakerBase declares, it adds the string attribute "filename" and
// the operator's documentation comment.
class CreateRecordIOReaderOpMaker : public FileReaderMakerBase {
|
||||
public:
|
||||
// op_proto / op_checker are forwarded unchanged to FileReaderMakerBase.
CreateRecordIOReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
|
||||
: FileReaderMakerBase(op_proto, op_checker) {
|
||||
// Read back in CreateRecordIOReaderOp::RunImpl via Attr<std::string>.
AddAttr<std::string>("filename", "The filename of record io reader");
|
||||
AddComment(R"DOC(
|
||||
CreateRecordIOReader Operator
|
||||
|
||||
Create a reader from a record io file
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace reader
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
// Short alias so the registration macro arguments stay readable.
namespace reader = paddle::operators::reader;
|
||||
|
||||
// Registers the operator under the name create_recordio_file_reader with
// its implementation class and its proto maker.
REGISTER_FILE_READER_OPERATOR(create_recordio_file_reader,
|
||||
reader::CreateRecordIOReaderOp,
|
||||
reader::CreateRecordIOReaderOpMaker);
|
@ -0,0 +1,70 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/pybind/recordio.h"
|
||||
#include <fstream>
|
||||
#include "paddle/fluid/framework/lod_tensor.h"
|
||||
#include "paddle/fluid/recordio/writer.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace pybind {
|
||||
|
||||
class RecordIOWriter {
|
||||
public:
|
||||
RecordIOWriter(const std::string& filename, recordio::Compressor compressor,
|
||||
size_t max_num_record)
|
||||
: stream_(filename), writer_(&stream_, compressor, max_num_record) {}
|
||||
|
||||
void AppendTensor(const framework::LoDTensor& tensor) {
|
||||
tensors_.push_back(tensor);
|
||||
}
|
||||
|
||||
void CompleteAppendTensor() {
|
||||
auto& ctx =
|
||||
*platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
|
||||
framework::WriteToRecordIO(writer_, tensors_, ctx);
|
||||
tensors_.clear();
|
||||
}
|
||||
|
||||
void Close() {
|
||||
PADDLE_ENFORCE(tensors_.empty());
|
||||
writer_.Flush();
|
||||
stream_.close();
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<framework::LoDTensor> tensors_;
|
||||
std::ofstream stream_;
|
||||
recordio::Writer writer_;
|
||||
};
|
||||
|
||||
// Registers the RecordIOWriter class, and its nested Compressor enum, on the
// Python module `m`.
//
// Python-visible surface:
//   RecordIOWriter(filename, compressor, max_num_record)
//   RecordIOWriter.Compressor.{Snappy, NoCompress}
//   append_tensor / complete_append_tensor / close
void BindRecordIOWriter(py::module& m) {
  py::class_<RecordIOWriter> writer(m, "RecordIOWriter", "");

  // The enum is scoped inside the writer class on the Python side.
  py::enum_<recordio::Compressor>(writer, "Compressor", "")
      .value("Snappy", recordio::Compressor::kSnappy)
      .value("NoCompress", recordio::Compressor::kNoCompress);

  writer
      // py::init<> replaces the hand-written placement-new "__init__"
      // lambda, which pybind11 deprecates; the constructor arguments are
      // the same.
      .def(py::init<const std::string&, recordio::Compressor, size_t>())
      .def("append_tensor", &RecordIOWriter::AppendTensor)
      .def("complete_append_tensor", &RecordIOWriter::CompleteAppendTensor)
      .def("close", &RecordIOWriter::Close);
}
|
||||
|
||||
} // namespace pybind
|
||||
} // namespace paddle
|
@ -0,0 +1,26 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace paddle {
|
||||
namespace pybind {
|
||||
|
||||
// Adds the RecordIOWriter Python bindings to module `m`.
void BindRecordIOWriter(py::module& m);
|
||||
} // namespace pybind
|
||||
} // namespace paddle
|
@ -0,0 +1,51 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/recordio/scanner.h"
|
||||
#include "paddle/fluid/platform/enforce.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace recordio {
|
||||
// Takes ownership of `stream` and positions the scanner on its first chunk.
// Raises an enforce failure if `stream` is null, because Reset() would
// otherwise dereference it.
Scanner::Scanner(std::unique_ptr<std::istream> &&stream)
    : stream_(std::move(stream)) {
  PADDLE_ENFORCE(stream_ != nullptr, "Scanner requires a non-null stream");
  Reset();
}
|
||||
|
||||
// Opens `filename` for reading and positions the scanner on its first chunk.
// Raises an enforce failure if the file cannot be opened; previously a bad
// path was indistinguishable from an empty file.
Scanner::Scanner(const std::string &filename)
    // RecordIO chunks are binary data, so avoid text-mode newline
    // translation on platforms that perform it.
    : stream_(new std::ifstream(filename, std::ios::in | std::ios::binary)) {
  PADDLE_ENFORCE(static_cast<bool>(*stream_), "Cannot open file %s", filename);
  Reset();
}
|
||||
|
||||
// Rewinds the stream to its beginning and parses the first chunk, so the
// next call to Next() returns the first record again.
void Scanner::Reset() {
  // After the stream has been read to the end its failbit/eofbit are set,
  // and seekg() does nothing while failbit is set. Clear the state first so
  // the rewind actually happens.
  stream_->clear();
  stream_->seekg(0, std::ios::beg);
  ParseNextChunk();
}
|
||||
|
||||
std::string Scanner::Next() {
|
||||
PADDLE_ENFORCE(!eof_, "StopIteration");
|
||||
auto rec = cur_chunk_.Record(offset_++);
|
||||
if (offset_ == cur_chunk_.NumRecords()) {
|
||||
ParseNextChunk();
|
||||
}
|
||||
return rec;
|
||||
}
|
||||
|
||||
// Parses the next chunk from the stream into cur_chunk_. eof_ becomes true
// when Chunk::Parse reports failure; the in-chunk position restarts at 0.
void Scanner::ParseNextChunk() {
  const bool parsed = cur_chunk_.Parse(*stream_);
  eof_ = !parsed;
  offset_ = 0;
}
|
||||
|
||||
// True until a chunk parse has failed (see ParseNextChunk).
bool Scanner::HasNext() const {
  return !eof_;
}
|
||||
} // namespace recordio
|
||||
} // namespace paddle
|
@ -0,0 +1,44 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <fstream>
|
||||
#include <memory>
|
||||
#include "paddle/fluid/recordio/chunk.h"
|
||||
namespace paddle {
|
||||
namespace recordio {
|
||||
|
||||
class Scanner {
|
||||
public:
|
||||
explicit Scanner(std::unique_ptr<std::istream>&& stream);
|
||||
|
||||
explicit Scanner(const std::string& filename);
|
||||
|
||||
void Reset();
|
||||
|
||||
std::string Next();
|
||||
|
||||
bool HasNext() const;
|
||||
|
||||
private:
|
||||
std::unique_ptr<std::istream> stream_;
|
||||
Chunk cur_chunk_;
|
||||
size_t offset_;
|
||||
bool eof_;
|
||||
|
||||
void ParseNextChunk();
|
||||
};
|
||||
} // namespace recordio
|
||||
} // namespace paddle
|
@ -0,0 +1,35 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#include "paddle/fluid/recordio/writer.h"
|
||||
#include "paddle/fluid/platform/enforce.h"
|
||||
namespace paddle {
|
||||
namespace recordio {
|
||||
void Writer::Write(const std::string& record) {
|
||||
cur_chunk_.Add(record);
|
||||
if (cur_chunk_.NumRecords() >= max_num_records_in_chunk_) {
|
||||
Flush();
|
||||
}
|
||||
}
|
||||
|
||||
// Writes the current chunk to the output stream using the configured
// compressor, then clears the chunk.
void Writer::Flush() {
  cur_chunk_.Write(stream_, compressor_);  // result of Write() is not used
  cur_chunk_.Clear();
}
|
||||
|
||||
// Destroying a Writer that still holds unflushed records is treated as a
// usage error: callers are expected to call Flush() first.
// NOTE(review): a failing PADDLE_ENFORCE presumably throws; destructors are
// implicitly noexcept since C++11, so that would end in std::terminate
// rather than a catchable error -- confirm this is the intended behaviour.
Writer::~Writer() {
  PADDLE_ENFORCE(cur_chunk_.Empty(), "Writer must be flushed when destroy.");
}
|
||||
|
||||
} // namespace recordio
|
||||
} // namespace paddle
|
@ -0,0 +1,43 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include "paddle/fluid/recordio/chunk.h"
|
||||
namespace paddle {
|
||||
namespace recordio {
|
||||
|
||||
class Writer {
|
||||
public:
|
||||
Writer(std::ostream* sout,
|
||||
Compressor compressor,
|
||||
size_t max_num_records_in_chunk = 1000)
|
||||
: stream_(*sout),
|
||||
max_num_records_in_chunk_(max_num_records_in_chunk),
|
||||
compressor_(compressor) {}
|
||||
|
||||
void Write(const std::string& record);
|
||||
|
||||
void Flush();
|
||||
|
||||
~Writer();
|
||||
|
||||
private:
|
||||
std::ostream& stream_;
|
||||
size_t max_num_records_in_chunk_;
|
||||
Chunk cur_chunk_;
|
||||
Compressor compressor_;
|
||||
};
|
||||
|
||||
} // namespace recordio
|
||||
} // namespace paddle
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue