merge develop to solve conflict, also fix API doc, test=develop (#18823)
parent
50582071dc
commit
5b6673c44d
@ -1,65 +0,0 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
include (ExternalProject)
|
||||
|
||||
# NOTE: snappy is needed when linking with recordio
|
||||
|
||||
set(SNAPPY_SOURCES_DIR ${THIRD_PARTY_PATH}/snappy)
|
||||
set(SNAPPY_INSTALL_DIR ${THIRD_PARTY_PATH}/install/snappy)
|
||||
set(SNAPPY_INCLUDE_DIR "${SNAPPY_INSTALL_DIR}/include" CACHE PATH "snappy include directory." FORCE)
|
||||
|
||||
if(WIN32)
|
||||
SET(SNAPPY_CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /wd4244 /wd4267")
|
||||
else()
|
||||
SET(SNAPPY_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
|
||||
endif()
|
||||
|
||||
ExternalProject_Add(
|
||||
extern_snappy
|
||||
GIT_REPOSITORY "https://github.com/google/snappy"
|
||||
GIT_TAG "1.1.7"
|
||||
PREFIX ${SNAPPY_SOURCES_DIR}
|
||||
UPDATE_COMMAND ""
|
||||
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
|
||||
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
|
||||
-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
|
||||
-DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}
|
||||
-DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}
|
||||
-DCMAKE_CXX_FLAGS=${SNAPPY_CMAKE_CXX_FLAGS}
|
||||
-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}
|
||||
-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}
|
||||
-DCMAKE_INSTALL_PREFIX=${SNAPPY_INSTALL_DIR}
|
||||
-DCMAKE_INSTALL_LIBDIR=${SNAPPY_INSTALL_DIR}/lib
|
||||
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
|
||||
-DBUILD_TESTING=OFF
|
||||
-DSNAPPY_BUILD_TESTS:BOOL=OFF
|
||||
-DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
|
||||
${EXTERNAL_OPTIONAL_ARGS}
|
||||
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${SNAPPY_INSTALL_DIR}
|
||||
-DCMAKE_INSTALL_LIBDIR:PATH=${SNAPPY_INSTALL_DIR}/lib
|
||||
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
|
||||
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
|
||||
)
|
||||
IF(WIN32)
|
||||
set(SNAPPY_LIBRARIES "${SNAPPY_INSTALL_DIR}/lib/snappy.lib")
|
||||
else(WIN32)
|
||||
set(SNAPPY_LIBRARIES "${SNAPPY_INSTALL_DIR}/lib/libsnappy.a")
|
||||
endif (WIN32)
|
||||
|
||||
add_library(snappy STATIC IMPORTED GLOBAL)
|
||||
set_property(TARGET snappy PROPERTY IMPORTED_LOCATION ${SNAPPY_LIBRARIES})
|
||||
|
||||
include_directories(${SNAPPY_INCLUDE_DIR})
|
||||
add_dependencies(snappy extern_snappy)
|
@ -1,63 +0,0 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
include (ExternalProject)
|
||||
|
||||
set(SNAPPYSTREAM_SOURCES_DIR ${THIRD_PARTY_PATH}/snappy_stream)
|
||||
set(SNAPPYSTREAM_INSTALL_DIR ${THIRD_PARTY_PATH}/install/snappy_stream)
|
||||
set(SNAPPYSTREAM_INCLUDE_DIR "${SNAPPYSTREAM_INSTALL_DIR}/include" CACHE PATH "snappy stream include directory." FORCE)
|
||||
|
||||
if(WIN32)
|
||||
# Fix me, VS2015 come without VLA support
|
||||
set(SNAPPYSTREAM_LIBRARIES "${SNAPPYSTREAM_INSTALL_DIR}/lib/snappystream.lib")
|
||||
MESSAGE(WARNING, "In windows, snappystream has no compile support for windows,
|
||||
please build it manually and put it at " ${SNAPPYSTREAM_INSTALL_DIR})
|
||||
else(WIN32)
|
||||
set(SNAPPYSTREAM_LIBRARIES "${SNAPPYSTREAM_INSTALL_DIR}/lib/libsnappystream.a")
|
||||
|
||||
ExternalProject_Add(
|
||||
extern_snappystream
|
||||
GIT_REPOSITORY "https://github.com/hoxnox/snappystream.git"
|
||||
GIT_TAG "0.2.8"
|
||||
PREFIX ${SNAPPYSTREAM_SOURCES_DIR}
|
||||
UPDATE_COMMAND ""
|
||||
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
|
||||
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
|
||||
-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
|
||||
-DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}
|
||||
-DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}
|
||||
-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}
|
||||
-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}
|
||||
-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}
|
||||
-DCMAKE_INSTALL_PREFIX=${SNAPPY_INSTALL_DIR}
|
||||
-DCMAKE_INSTALL_LIBDIR=${SNAPPY_INSTALL_DIR}/lib
|
||||
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
|
||||
-DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
|
||||
-DSNAPPY_ROOT=${SNAPPY_INSTALL_DIR}
|
||||
${EXTERNAL_OPTIONAL_ARGS}
|
||||
CMAKE_CACHE_ARGS
|
||||
-DCMAKE_INSTALL_PREFIX:PATH=${SNAPPYSTREAM_INSTALL_DIR}
|
||||
-DCMAKE_INSTALL_LIBDIR:PATH=${SNAPPYSTREAM_INSTALL_DIR}/lib
|
||||
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
|
||||
DEPENDS snappy
|
||||
)
|
||||
endif(WIN32)
|
||||
|
||||
add_library(snappystream STATIC IMPORTED GLOBAL)
|
||||
set_property(TARGET snappystream PROPERTY IMPORTED_LOCATION ${SNAPPYSTREAM_LIBRARIES})
|
||||
|
||||
include_directories(${SNAPPYSTREAM_INCLUDE_DIR}) # For snappysteam to include its own headers.
|
||||
include_directories(${THIRD_PARTY_PATH}/install) # For Paddle to include snappy stream headers.
|
||||
|
||||
add_dependencies(snappystream extern_snappystream)
|
@ -1,151 +0,0 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/operators/reader/reader_op_registry.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace reader {
|
||||
|
||||
class BatchReader : public framework::DecoratedReader {
|
||||
public:
|
||||
BatchReader(const std::shared_ptr<ReaderBase>& reader, int batch_size,
|
||||
bool discard_leftover)
|
||||
: DecoratedReader(reader),
|
||||
batch_size_(static_cast<size_t>(batch_size)),
|
||||
discard_leftover_(discard_leftover) {
|
||||
buffer_.reserve(batch_size_);
|
||||
}
|
||||
|
||||
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override;
|
||||
|
||||
private:
|
||||
size_t batch_size_;
|
||||
bool discard_leftover_;
|
||||
std::vector<std::vector<framework::LoDTensor>> buffer_;
|
||||
};
|
||||
|
||||
class CreateBatchReaderOp : public framework::OperatorBase {
|
||||
public:
|
||||
using framework::OperatorBase::OperatorBase;
|
||||
|
||||
private:
|
||||
void RunImpl(const framework::Scope& scope,
|
||||
const platform::Place& dev_place) const override {
|
||||
auto* out = scope.FindVar(Output("Out"))
|
||||
->template GetMutable<framework::ReaderHolder>();
|
||||
if (out->Get() != nullptr) {
|
||||
return;
|
||||
}
|
||||
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
|
||||
->Get<framework::ReaderHolder>();
|
||||
out->Reset(framework::MakeDecoratedReader<BatchReader>(
|
||||
underlying_reader, Attr<int>("batch_size"),
|
||||
Attr<bool>("discard_leftover")));
|
||||
}
|
||||
};
|
||||
|
||||
class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase {
|
||||
protected:
|
||||
void Apply() override {
|
||||
AddAttr<int>("batch_size",
|
||||
"How many instances the batch reader yields each time.")
|
||||
.GreaterThan(0);
|
||||
AddAttr<bool>("discard_leftover",
|
||||
"If true, the leftover instances that are not enough for a "
|
||||
"new batch will be discarded.")
|
||||
.SetDefault(true);
|
||||
AddComment(R"DOC(
|
||||
CreateBatchReader Operator
|
||||
|
||||
A batch reader takes another reader as its 'underlying reader',
|
||||
gathers the underlying reader's outputs and then yields them in batches.
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
|
||||
void BatchReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
|
||||
buffer_.clear();
|
||||
buffer_.reserve(batch_size_);
|
||||
for (size_t i = 0; i < batch_size_; ++i) {
|
||||
buffer_.push_back(std::vector<framework::LoDTensor>());
|
||||
reader_->ReadNext(&buffer_.back());
|
||||
if (buffer_.back().empty()) {
|
||||
buffer_.pop_back();
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (discard_leftover_ && buffer_.size() < batch_size_) {
|
||||
buffer_.clear();
|
||||
}
|
||||
// Concat instances
|
||||
out->clear();
|
||||
if (buffer_.empty()) {
|
||||
// if buffer_ is empty, the 'out' will return as an empty vector.
|
||||
return;
|
||||
}
|
||||
size_t out_num = buffer_[0].size();
|
||||
out->reserve(out_num);
|
||||
for (size_t j = 0; j < out_num; ++j) {
|
||||
// Merge shape and check date type
|
||||
auto batch_type = buffer_[0][j].type();
|
||||
framework::DDim batch_shape = buffer_[0][j].dims();
|
||||
for (size_t i = 1; i < buffer_.size(); ++i) {
|
||||
auto ins_type = buffer_[i][j].type();
|
||||
framework::DDim ins_shape = buffer_[i][j].dims();
|
||||
PADDLE_ENFORCE_EQ(batch_type, ins_type);
|
||||
PADDLE_ENFORCE_EQ(slice_ddim(batch_shape, 1, batch_shape.size()),
|
||||
slice_ddim(ins_shape, 1, ins_shape.size()));
|
||||
PADDLE_ENFORCE_GT(ins_shape[0], 0);
|
||||
batch_shape[0] += ins_shape[0];
|
||||
}
|
||||
|
||||
framework::LoDTensor out_tensor;
|
||||
out_tensor.Resize(batch_shape);
|
||||
out_tensor.mutable_data(platform::CPUPlace(), batch_type);
|
||||
int64_t dst_offset = 0;
|
||||
|
||||
// Merge lod and data
|
||||
framework::LoD batch_lod;
|
||||
for (size_t i = 0; i < buffer_.size(); ++i) {
|
||||
framework::DDim ins_shape = buffer_[i][j].dims();
|
||||
framework::LoD ins_lod = buffer_[i][j].lod();
|
||||
if (i == 0) {
|
||||
batch_lod = ins_lod;
|
||||
} else {
|
||||
PADDLE_ENFORCE_EQ(batch_lod.size(), ins_lod.size());
|
||||
for (size_t level_idx = 0; level_idx < batch_lod.size(); ++level_idx) {
|
||||
auto& lod_level = batch_lod[level_idx];
|
||||
for (size_t k = 1; k < ins_lod[level_idx].size(); ++k) {
|
||||
lod_level.push_back(ins_lod[level_idx][k] + lod_level.back());
|
||||
}
|
||||
}
|
||||
}
|
||||
auto dst = out_tensor.Slice(dst_offset, dst_offset + ins_shape[0]);
|
||||
TensorCopy(buffer_[i][j], platform::CPUPlace(), &dst);
|
||||
dst_offset += ins_shape[0];
|
||||
}
|
||||
out_tensor.set_lod(batch_lod);
|
||||
out->push_back(out_tensor);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace reader
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators::reader;
|
||||
REGISTER_DECORATED_READER_OPERATOR(create_batch_reader,
|
||||
ops::CreateBatchReaderOp,
|
||||
ops::CreateBatchReaderOpMaker);
|
@ -1,93 +0,0 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/operators/detail/safe_ref.h"
|
||||
#include "paddle/fluid/operators/reader/reader_op_registry.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace reader {
|
||||
|
||||
class MultiPassReader : public framework::DecoratedReader {
|
||||
public:
|
||||
MultiPassReader(const std::shared_ptr<ReaderBase>& reader, int pass_num)
|
||||
: DecoratedReader(reader), pass_num_(pass_num), pass_count_(0) {}
|
||||
|
||||
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
|
||||
reader_->ReadNext(out);
|
||||
if (out->empty() && pass_count_ < pass_num_ - 1) {
|
||||
reader_->Shutdown();
|
||||
reader_->Start();
|
||||
reader_->ReadNext(out);
|
||||
++pass_count_;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
void StartImpl() override {
|
||||
pass_count_ = 0;
|
||||
reader_->Start();
|
||||
}
|
||||
|
||||
int pass_num_;
|
||||
mutable int pass_count_;
|
||||
};
|
||||
|
||||
class CreateMultiPassReaderOp : public framework::OperatorBase {
|
||||
public:
|
||||
using framework::OperatorBase::OperatorBase;
|
||||
|
||||
private:
|
||||
void RunImpl(const framework::Scope& scope,
|
||||
const platform::Place& dev_place) const override {
|
||||
auto* out = detail::Ref(scope.FindVar(Output("Out")))
|
||||
.GetMutable<framework::ReaderHolder>();
|
||||
if (out->Get() != nullptr) {
|
||||
return;
|
||||
}
|
||||
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
|
||||
->Get<framework::ReaderHolder>();
|
||||
int pass_num = Attr<int>("pass_num");
|
||||
out->Reset(framework::MakeDecoratedReader<MultiPassReader>(
|
||||
underlying_reader, pass_num));
|
||||
}
|
||||
};
|
||||
|
||||
class CreateMultiPassReaderOpMaker : public DecoratedReaderMakerBase {
|
||||
protected:
|
||||
void Apply() override {
|
||||
AddAttr<int>("pass_num", "The number of pass to run.").GreaterThan(0);
|
||||
AddComment(R"DOC(
|
||||
CreateMultiPassReader Operator
|
||||
|
||||
This operator creates a multi-pass reader. A multi-pass reader
|
||||
is used to yield data for several pass training continuously.
|
||||
It takes the number of passes to run as one of its attributes
|
||||
('pass_num'), and maintains a pass counter to record how many
|
||||
passes it has completed. When the underlying reader reaches the
|
||||
EOF, the multi-pass reader checks whether it has completed training
|
||||
of the given number of pass. If not, the underlying reader will
|
||||
be re-initialized and starts a new pass automatically.
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace reader
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators::reader;
|
||||
REGISTER_DECORATED_READER_OPERATOR(create_multi_pass_reader,
|
||||
ops::CreateMultiPassReaderOp,
|
||||
ops::CreateMultiPassReaderOpMaker);
|
@ -1,107 +0,0 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/operators/reader/reader_op_registry.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace reader {
|
||||
|
||||
template <typename T>
|
||||
class RandomDataGenerator : public framework::FileReader {
|
||||
public:
|
||||
RandomDataGenerator(const std::vector<framework::DDim>& shapes, float low,
|
||||
float high)
|
||||
: framework::FileReader(), low_(low), high_(high), shapes_(shapes) {
|
||||
PADDLE_ENFORCE_LE(low, high,
|
||||
"'low' shouldn't be greater than 'high'.(%f vs %f)", low,
|
||||
high);
|
||||
unsigned int seed = std::random_device()();
|
||||
engine_.seed(seed);
|
||||
dist_ = std::uniform_real_distribution<float>(low_, high_);
|
||||
}
|
||||
|
||||
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
|
||||
out->clear();
|
||||
out->reserve(shapes_.size());
|
||||
for (const framework::DDim& shape : shapes_) {
|
||||
PADDLE_ENFORCE_GE(
|
||||
shape.size(), 2,
|
||||
"The rank of reader's output data should be 2 at least.(Now it's %d)",
|
||||
shape.size());
|
||||
framework::LoDTensor out_tensor;
|
||||
out_tensor.Resize(shape);
|
||||
T* data = out_tensor.mutable_data<T>(platform::CPUPlace());
|
||||
int64_t numel = framework::product(shape);
|
||||
for (int64_t i = 0; i < numel; ++i) {
|
||||
data[i] = dist_(engine_);
|
||||
}
|
||||
out->push_back(out_tensor);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
float low_;
|
||||
float high_;
|
||||
std::minstd_rand engine_;
|
||||
std::uniform_real_distribution<float> dist_;
|
||||
std::vector<framework::DDim> shapes_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class CreateRandomDataGeneratorOp : public framework::OperatorBase {
|
||||
public:
|
||||
using framework::OperatorBase::OperatorBase;
|
||||
|
||||
private:
|
||||
void RunImpl(const framework::Scope& scope,
|
||||
const platform::Place& dev_place) const override {
|
||||
const auto& shape_concat = Attr<std::vector<int>>("shape_concat");
|
||||
const auto& ranks = Attr<std::vector<int>>("ranks");
|
||||
PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty());
|
||||
PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0),
|
||||
static_cast<int>(shape_concat.size()),
|
||||
"The accumulate of all ranks should be equal to the "
|
||||
"shape concat's length.");
|
||||
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
|
||||
auto* out = scope.FindVar(Output("Out"))
|
||||
->template GetMutable<framework::ReaderHolder>();
|
||||
out->Reset(std::make_shared<RandomDataGenerator<T>>(
|
||||
shapes, Attr<float>("low"), Attr<float>("high")));
|
||||
}
|
||||
};
|
||||
|
||||
class CreateRandomDataGeneratorOpMaker : public FileReaderMakerBase {
|
||||
protected:
|
||||
void Apply() override {
|
||||
AddAttr<float>("low", "The lower bound of reader's uniform distribution.");
|
||||
AddAttr<float>("high", "The upper bound of reader's uniform distribution.");
|
||||
AddComment(R"DOC(
|
||||
CreateRandomDataGenerator Operator
|
||||
|
||||
This Op creates a random reader.
|
||||
The reader generates random data instead of really reading from files.
|
||||
Generated data follow an uniform distribution between 'low' and 'high'.
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace reader
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators::reader;
|
||||
REGISTER_FILE_READER_OPERATOR(create_random_data_generator,
|
||||
ops::CreateRandomDataGeneratorOp<float>,
|
||||
ops::CreateRandomDataGeneratorOpMaker);
|
@ -1,93 +0,0 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/operators/reader/reader_op_registry.h"
|
||||
#include "paddle/fluid/platform/lock_guard_ptr.h"
|
||||
#include "paddle/fluid/recordio/scanner.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace reader {
|
||||
template <bool ThreadSafe>
|
||||
class RecordIOFileReader : public framework::FileReader {
|
||||
public:
|
||||
explicit RecordIOFileReader(const std::string& filename)
|
||||
: scanner_(filename),
|
||||
dev_ctx_(*platform::DeviceContextPool::Instance().Get(
|
||||
platform::CPUPlace())) {
|
||||
if (ThreadSafe) {
|
||||
mutex_.reset(new std::mutex());
|
||||
}
|
||||
LOG(INFO) << "Creating file reader" << filename;
|
||||
}
|
||||
|
||||
protected:
|
||||
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
|
||||
platform::LockGuardPtr<std::mutex> guard(mutex_);
|
||||
bool ok = framework::ReadFromRecordIO(&scanner_, dev_ctx_, out);
|
||||
if (!ok) {
|
||||
out->clear();
|
||||
}
|
||||
}
|
||||
|
||||
void StartImpl() override { scanner_.Reset(); }
|
||||
|
||||
private:
|
||||
std::unique_ptr<std::mutex> mutex_;
|
||||
recordio::Scanner scanner_;
|
||||
const platform::DeviceContext& dev_ctx_;
|
||||
};
|
||||
|
||||
class CreateRecordIOReaderOp : public framework::OperatorBase {
|
||||
public:
|
||||
using framework::OperatorBase::OperatorBase;
|
||||
|
||||
private:
|
||||
void RunImpl(const framework::Scope& scope,
|
||||
const platform::Place& dev_place) const override {
|
||||
std::string filename = Attr<std::string>("filename");
|
||||
auto* out = scope.FindVar(Output("Out"))
|
||||
->template GetMutable<framework::ReaderHolder>();
|
||||
|
||||
out->Reset(std::make_shared<RecordIOFileReader<true>>(filename));
|
||||
}
|
||||
};
|
||||
|
||||
class CreateRecordIOReaderOpMaker : public FileReaderMakerBase {
|
||||
protected:
|
||||
void Apply() override {
|
||||
AddAttr<std::string>(
|
||||
"filename",
|
||||
"The filename of record file. This file will given to reader.");
|
||||
AddComment(R"DOC(
|
||||
Open a recordio file and return the reader object. The returned reader object
|
||||
is thread-safe.
|
||||
|
||||
NOTE: This is a very low-level API. It is used for debugging data file or
|
||||
training. Please use `open_files` instead of this API for production usage.
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace reader
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace reader = paddle::operators::reader;
|
||||
|
||||
REGISTER_FILE_READER_OPERATOR(create_recordio_file_reader,
|
||||
reader::CreateRecordIOReaderOp,
|
||||
reader::CreateRecordIOReaderOpMaker);
|
||||
|
||||
REGISTER_FILE_READER(recordio, reader::RecordIOFileReader<false>);
|
@ -1,124 +0,0 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <random>
|
||||
#include "glog/logging.h"
|
||||
#include "paddle/fluid/operators/detail/safe_ref.h"
|
||||
#include "paddle/fluid/operators/reader/reader_op_registry.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace reader {
|
||||
|
||||
class ShuffleReader : public framework::DecoratedReader {
|
||||
public:
|
||||
ShuffleReader(const std::shared_ptr<ReaderBase>& reader, size_t buffer_size,
|
||||
size_t seed = 0)
|
||||
: DecoratedReader(reader), buffer_size_(buffer_size), seed_(seed) {
|
||||
VLOG(10) << "Create shuffle reader of " << reader_;
|
||||
if (seed_ == 0) {
|
||||
std::random_device device;
|
||||
seed_ = device();
|
||||
}
|
||||
ReloadBuffer();
|
||||
}
|
||||
|
||||
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
|
||||
out->clear();
|
||||
if (iteration_pos_ >= buffer_.size()) {
|
||||
VLOG(10) << "Resetting shuffle buffer";
|
||||
ReloadBuffer();
|
||||
if (buffer_.empty()) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
*out = buffer_[iteration_pos_++];
|
||||
}
|
||||
|
||||
private:
|
||||
void ShutdownImpl() override {
|
||||
reader_->Shutdown();
|
||||
buffer_.clear();
|
||||
iteration_pos_ = 0;
|
||||
}
|
||||
|
||||
void StartImpl() override {
|
||||
reader_->Start();
|
||||
ReloadBuffer();
|
||||
}
|
||||
|
||||
void ReloadBuffer() {
|
||||
buffer_.clear();
|
||||
buffer_.reserve(buffer_size_);
|
||||
iteration_pos_ = 0;
|
||||
for (size_t i = 0; i < buffer_size_; ++i) {
|
||||
std::vector<framework::LoDTensor> ins;
|
||||
reader_->ReadNext(&ins);
|
||||
if (ins.empty()) {
|
||||
break;
|
||||
}
|
||||
buffer_.emplace_back(ins);
|
||||
}
|
||||
std::mt19937 g(seed_);
|
||||
std::shuffle(buffer_.begin(), buffer_.end(), g);
|
||||
seed_ = g(); // update seed_;
|
||||
VLOG(10) << "random buffer size = " << buffer_.size();
|
||||
}
|
||||
|
||||
size_t buffer_size_;
|
||||
std::vector<std::vector<framework::LoDTensor>> buffer_;
|
||||
|
||||
size_t iteration_pos_;
|
||||
size_t seed_;
|
||||
};
|
||||
|
||||
class CreateShuffleReaderOp : public framework::OperatorBase {
|
||||
public:
|
||||
using framework::OperatorBase::OperatorBase;
|
||||
|
||||
private:
|
||||
void RunImpl(const framework::Scope& scope,
|
||||
const platform::Place& dev_place) const override {
|
||||
auto* out = detail::Ref(scope.FindVar(Output("Out")))
|
||||
.GetMutable<framework::ReaderHolder>();
|
||||
if (out->Get() != nullptr) {
|
||||
return;
|
||||
}
|
||||
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
|
||||
->Get<framework::ReaderHolder>();
|
||||
out->Reset(framework::MakeDecoratedReader<ShuffleReader>(
|
||||
underlying_reader, static_cast<size_t>(Attr<int>("buffer_size"))));
|
||||
}
|
||||
};
|
||||
|
||||
class CreateShuffleReaderOpMaker : public DecoratedReaderMakerBase {
|
||||
protected:
|
||||
void Apply() override {
|
||||
AddAttr<int>("buffer_size", "The shuffle buffer size.").GreaterThan(0);
|
||||
AddComment(R"DOC(
|
||||
CreateShuffleReader Operator
|
||||
|
||||
A shuffle reader takes another reader as its 'underlying reader'
|
||||
and yields the underlying reader's outputs in a shuffled order.
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
} // namespace reader
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators::reader;
|
||||
REGISTER_DECORATED_READER_OPERATOR(create_shuffle_reader,
|
||||
ops::CreateShuffleReaderOp,
|
||||
ops::CreateShuffleReaderOpMaker);
|
File diff suppressed because it is too large
Load Diff
@ -1,88 +0,0 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/pybind/recordio.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "paddle/fluid/framework/lod_tensor.h"
|
||||
#include "paddle/fluid/recordio/writer.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace pybind {
|
||||
|
||||
namespace {
|
||||
|
||||
class RecordIOWriter {
|
||||
public:
|
||||
RecordIOWriter(const std::string& filename, recordio::Compressor compressor,
|
||||
size_t max_num_record)
|
||||
: closed_(false),
|
||||
stream_(filename, std::ios::binary),
|
||||
writer_(&stream_, compressor, max_num_record) {}
|
||||
|
||||
void AppendTensor(const framework::LoDTensor& tensor) {
|
||||
tensors_.push_back(tensor);
|
||||
}
|
||||
|
||||
void CompleteAppendTensor() {
|
||||
auto& ctx =
|
||||
*platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
|
||||
framework::WriteToRecordIO(&writer_, tensors_, ctx);
|
||||
tensors_.clear();
|
||||
}
|
||||
|
||||
void Close() {
|
||||
PADDLE_ENFORCE(tensors_.empty());
|
||||
writer_.Flush();
|
||||
stream_.close();
|
||||
closed_ = true;
|
||||
}
|
||||
|
||||
~RecordIOWriter() {
|
||||
if (!closed_) {
|
||||
Close();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
bool closed_;
|
||||
std::vector<framework::LoDTensor> tensors_;
|
||||
std::ofstream stream_;
|
||||
recordio::Writer writer_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void BindRecordIOWriter(py::module* m) {
|
||||
py::class_<RecordIOWriter> writer(*m, "RecordIOWriter", "");
|
||||
py::enum_<recordio::Compressor>(writer, "Compressor", "")
|
||||
.value("Snappy", recordio::Compressor::kSnappy)
|
||||
.value("NoCompress", recordio::Compressor::kNoCompress);
|
||||
|
||||
writer
|
||||
.def("__init__",
|
||||
[](RecordIOWriter& self, const std::string& filename,
|
||||
recordio::Compressor compressor, size_t max_num_record) {
|
||||
new (&self) RecordIOWriter(filename, compressor, max_num_record);
|
||||
})
|
||||
.def("append_tensor", &RecordIOWriter::AppendTensor)
|
||||
.def("complete_append_tensor", &RecordIOWriter::CompleteAppendTensor)
|
||||
.def("close", &RecordIOWriter::Close);
|
||||
}
|
||||
|
||||
} // namespace pybind
|
||||
} // namespace paddle
|
@ -1,27 +0,0 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace paddle {
|
||||
namespace pybind {
|
||||
|
||||
void BindRecordIOWriter(py::module* m);
|
||||
|
||||
} // namespace pybind
|
||||
} // namespace paddle
|
@ -1,9 +0,0 @@
|
||||
# internal library.
|
||||
cc_library(header SRCS header.cc)
|
||||
cc_test(header_test SRCS header_test.cc DEPS header)
|
||||
cc_library(chunk SRCS chunk.cc DEPS snappystream snappy header zlib)
|
||||
cc_test(chunk_test SRCS chunk_test.cc DEPS chunk)
|
||||
cc_library(writer SRCS writer.cc DEPS chunk)
|
||||
cc_library(scanner SRCS scanner.cc DEPS chunk)
|
||||
cc_test(writer_scanner_test SRCS writer_scanner_test.cc DEPS writer scanner)
|
||||
cc_library(recordio DEPS chunk header writer scanner)
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue