commit 10343123e3
@@ -1,246 +0,0 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"

namespace paddle {
namespace operators {

static std::vector<framework::DDim> RestoreShapes(
    const std::vector<int>& shape_concat, const std::vector<int>& ranks) {
  std::vector<framework::DDim> res;
  int offset = 0;
  for (int len : ranks) {
    auto start_it = shape_concat.begin() + offset;
    auto end_it = start_it + len;
    res.push_back(framework::make_ddim(std::vector<int>(start_it, end_it)));
    offset += len;
  }
  return res;
}

// general infershape for file readers
class CreateFileReaderInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "The output file reader should not be null.");
    const auto shape_concat =
        ctx->Attrs().Get<std::vector<int>>("shape_concat");
    const auto ranks = ctx->Attrs().Get<std::vector<int>>("ranks");
    std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
    ctx->SetReaderDims("Out", shapes);

    if (ctx->IsRuntime()) {
      const auto lod_levels = ctx->Attrs().Get<std::vector<int>>("lod_levels");
      PADDLE_ENFORCE_EQ(
          lod_levels.size(), shapes.size(),
          "The number of 'lod_levels'(%d) doesn't match the number "
          "of 'shapes'(%d).",
          lod_levels.size(), shapes.size());
      framework::VarDesc* reader =
          boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
      reader->SetLoDLevels(lod_levels);
    }
  }
};

// general infershape for decorated readers
class CreateDecoratedReaderInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("UnderlyingReader"),
                   "Input(UnderlyingReader) should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "The output decorated reader should not be null.");
    ctx->SetReaderDims("Out", ctx->GetReaderDims("UnderlyingReader"));

    if (ctx->IsRuntime()) {
      framework::VarDesc* in_reader = boost::get<framework::VarDesc*>(
          ctx->GetInputVarPtrs("UnderlyingReader")[0]);
      framework::VarDesc* out_reader =
          boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
      out_reader->SetLoDLevels(in_reader->GetLoDLevels());
    }
  }
};

// general var type inference for file readers
class CreateFileReaderInferVarType : public framework::VarTypeInference {
 public:
  void operator()(const framework::OpDesc& op_desc,
                  framework::BlockDesc* block) const override {
    std::string reader_name = op_desc.Output("Out")[0];
    framework::VarDesc* reader = block->FindVarRecursive(reader_name);
    reader->SetType(framework::proto::VarType::READER);
  }
};

// general var type inference for decorated readers
class CreateDecoratedReaderInferVarType : public framework::VarTypeInference {
 public:
  void operator()(const framework::OpDesc& op_desc,
                  framework::BlockDesc* block) const override {
    std::string in_reader_name = op_desc.Input("UnderlyingReader")[0];
    framework::VarDesc* in_reader = block->FindVarRecursive(in_reader_name);
    std::string out_reader_name = op_desc.Output("Out")[0];
    framework::VarDesc* out_reader = block->FindVarRecursive(out_reader_name);
    out_reader->SetType(framework::proto::VarType::READER);
    out_reader->SetDataTypes(in_reader->GetDataTypes());
  }
};

template <typename T>
class CreateRandomDataGeneratorOp : public framework::OperatorBase {
 public:
  using framework::OperatorBase::OperatorBase;

 private:
  void RunImpl(const framework::Scope& scope,
               const platform::Place& dev_place) const override {
    const auto& shape_concat = Attr<std::vector<int>>("shape_concat");
    const auto& ranks = Attr<std::vector<int>>("ranks");
    PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty());
    PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0),
                      int(shape_concat.size()),
                      "The accumulate of all ranks should be equal to the "
                      "shape concat's length.");
    std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
    auto* out = scope.FindVar(Output("Out"))
                    ->template GetMutable<framework::ReaderHolder>();
    out->Reset(new framework::RandomDataGenerator<T>(shapes, Attr<float>("min"),
                                                     Attr<float>("max")));
  }
};

class CreateRandomDataGeneratorOpMaker
    : public framework::OpProtoAndCheckerMaker {
 public:
  CreateRandomDataGeneratorOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(op_proto, op_checker) {
    AddOutput("Out", "(ReaderHolder) The created random reader.");
    AddAttr<std::vector<int>>("shape_concat",
                              "The concat of all data's shapes.");
    AddAttr<std::vector<int>>(
        "ranks",
        "The ranks of each data."
        "e.g."
        "shape_concat = [2,3,4,5,6]"
        "ranks = [3,2]"
        "It means the reader will generate two data each time,"
        "whose shapes are [2,3,4] and [5,6] respectively.");
    AddAttr<std::vector<int>>("lod_levels", "The LoD levels of each data.");
    AddAttr<float>("min", "The lower bound of reader's uniform distribution.");
    AddAttr<float>("max", "The upper bound of reader's uniform distribution.");
    AddComment(R"DOC(
CreateRandomDataGenerator Operator

This Op creates a random reader.
The reader generates random data instead of really reading from files.
Generated data follow an uniform distribution between 'min' and 'max'.
)DOC");
  }
};

class CreateShuffleReaderOp : public framework::OperatorBase {
 public:
  using framework::OperatorBase::OperatorBase;

 private:
  void RunImpl(const framework::Scope& scope,
               const platform::Place& dev_place) const override {
    const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
                                        ->Get<framework::ReaderHolder>();
    auto* out = scope.FindVar(Output("Out"))
                    ->template GetMutable<framework::ReaderHolder>();
    out->Reset(new framework::ShuffleReader(underlying_reader.Get(),
                                            Attr<int>("buffer_size")));
  }
};

class CreateShuffleReaderOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  CreateShuffleReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(op_proto, op_checker) {
    AddInput(
        "UnderlyingReader",
        "(ReaderHolder) The underlying reader for creating a shuffle reader.");
    AddOutput("Out", "(ReaderHolder) The created shuffle reader.");
    AddAttr<int>("buffer_size", "The shuffle buffer size.").GreaterThan(0);
    AddComment(R"DOC(
CreateShuffleReader Operator

A shuffle reader takes another reader as its 'underlying reader'
and yields the underlying reader's outputs in a shuffled order.
)DOC");
  }
};

class CreateBatchReaderOp : public framework::OperatorBase {
 public:
  using framework::OperatorBase::OperatorBase;

 private:
  void RunImpl(const framework::Scope& scope,
               const platform::Place& dev_place) const override {
    const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
                                        ->Get<framework::ReaderHolder>();
    auto* out = scope.FindVar(Output("Out"))
                    ->template GetMutable<framework::ReaderHolder>();
    out->Reset(new framework::BatchReader(underlying_reader.Get(),
                                          Attr<int>("batch_size")));
  }
};

class CreateBatchReaderOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  CreateBatchReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(op_proto, op_checker) {
    AddInput(
        "UnderlyingReader",
        "(ReaderHolder) The underlying reader for creating a batch reader.");
    AddOutput("Out", "(ReaderHolder) The created batch reader.");
    AddAttr<int>("batch_size",
                 "How many instances the batch reader yields each time.")
        .GreaterThan(0);
    AddComment(R"DOC(
CreateBatchReader Operator

A batch reader takes another reader as its 'underlying reader',
gathers the underlying reader's outputs and then yields them in batches.
)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(create_random_data_generator,
                  ops::CreateRandomDataGeneratorOp<float>,
                  ops::CreateFileReaderInferShape,
                  ops::CreateRandomDataGeneratorOpMaker,
                  paddle::framework::EmptyGradOpMaker,
                  ops::CreateFileReaderInferVarType);
REGISTER_OPERATOR(create_shuffle_reader, ops::CreateShuffleReaderOp,
                  ops::CreateDecoratedReaderInferShape,
                  ops::CreateShuffleReaderOpMaker,
                  paddle::framework::EmptyGradOpMaker,
                  ops::CreateDecoratedReaderInferVarType);
REGISTER_OPERATOR(create_batch_reader, ops::CreateBatchReaderOp,
                  ops::CreateDecoratedReaderInferShape,
                  ops::CreateBatchReaderOpMaker,
                  paddle::framework::EmptyGradOpMaker,
                  ops::CreateDecoratedReaderInferVarType);

@@ -1 +1,3 @@
if(WITH_DISTRIBUTE)
  grpc_library(sendrecvop_grpc SRCS sendrecvop_utils.cc grpc_client.cc grpc_server.cc PROTO send_recv.proto DEPS lod_tensor selected_rows)
endif()

@@ -0,0 +1,5 @@
cc_library(reader_op_registry SRCS reader_op_registry.cc DEPS operator op_registry reader)
op_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc DEPS reader_op_registry)
op_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc DEPS reader_op_registry)
op_library(create_batch_reader_op SRCS create_batch_reader_op.cc DEPS reader_op_registry)
set(READER_LIBRARY create_random_data_generator_op create_shuffle_reader_op create_batch_reader_op PARENT_SCOPE)

@@ -0,0 +1,137 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/reader/reader_op_registry.h"

namespace paddle {
namespace operators {
namespace reader {

class BatchReader : public framework::DecoratedReader {
 public:
  BatchReader(ReaderBase* reader, int batch_size)
      : DecoratedReader(reader), batch_size_(batch_size) {
    buffer_.reserve(batch_size_);
  }

  void ReadNext(std::vector<framework::LoDTensor>* out) override;

 private:
  int batch_size_;
  std::vector<std::vector<framework::LoDTensor>> buffer_;
};

class CreateBatchReaderOp : public framework::OperatorBase {
 public:
  using framework::OperatorBase::OperatorBase;

 private:
  void RunImpl(const framework::Scope& scope,
               const platform::Place& dev_place) const override {
    const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
                                        ->Get<framework::ReaderHolder>();
    auto* out = scope.FindVar(Output("Out"))
                    ->template GetMutable<framework::ReaderHolder>();
    out->Reset(
        new BatchReader(underlying_reader.Get(), Attr<int>("batch_size")));
  }
};

class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase {
 public:
  CreateBatchReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
      : DecoratedReaderMakerBase(op_proto, op_checker) {
    AddAttr<int>("batch_size",
                 "How many instances the batch reader yields each time.")
        .GreaterThan(0);
    AddComment(R"DOC(
CreateBatchReader Operator

A batch reader takes another reader as its 'underlying reader',
gathers the underlying reader's outputs and then yields them in batches.
)DOC");
  }
};

void BatchReader::ReadNext(std::vector<framework::LoDTensor>* out) {
  buffer_.clear();
  buffer_.reserve(batch_size_);
  for (int i = 0; i < batch_size_; ++i) {
    if (reader_->HasNext()) {
      buffer_.push_back(std::vector<framework::LoDTensor>());
      reader_->ReadNext(&buffer_.back());
    } else {
      break;
    }
  }
  // Concat instances
  out->clear();
  if (buffer_.empty()) {
    // if buffer_ is empty, the 'out' will return as an empty vector.
    return;
  }
  int out_num = buffer_[0].size();
  out->reserve(out_num);
  for (int j = 0; j < out_num; ++j) {
    // Merge shape and check data type
    std::type_index batch_type = buffer_[0][j].type();
    framework::DDim batch_shape = buffer_[0][j].dims();
    for (size_t i = 1; i < buffer_.size(); ++i) {
      std::type_index ins_type = buffer_[i][j].type();
      framework::DDim ins_shape = buffer_[i][j].dims();
      PADDLE_ENFORCE_EQ(batch_type, ins_type);
      PADDLE_ENFORCE_EQ(slice_ddim(batch_shape, 1, batch_shape.size()),
                        slice_ddim(ins_shape, 1, ins_shape.size()));
      PADDLE_ENFORCE_GT(ins_shape[0], 0);
      batch_shape[0] += ins_shape[0];
    }

    framework::LoDTensor out_tensor;
    out_tensor.Resize(batch_shape);
    out_tensor.mutable_data(platform::CPUPlace(), batch_type);
    int64_t dst_offset = 0;

    // Merge lod and data
    framework::LoD batch_lod;
    for (size_t i = 0; i < buffer_.size(); ++i) {
      framework::DDim ins_shape = buffer_[i][j].dims();
      framework::LoD ins_lod = buffer_[i][j].lod();
      if (i == 0) {
        batch_lod = ins_lod;
      } else {
        PADDLE_ENFORCE_EQ(batch_lod.size(), ins_lod.size());
        for (size_t level_idx = 0; level_idx < batch_lod.size(); ++level_idx) {
          auto& lod_level = batch_lod[level_idx];
          for (size_t k = 1; k < ins_lod[level_idx].size(); ++k) {
            lod_level.push_back(ins_lod[level_idx][k] + lod_level.back());
          }
        }
      }
      auto dst = out_tensor.Slice(dst_offset, dst_offset + ins_shape[0]);
      TensorCopy(buffer_[i][j], platform::CPUPlace(), &dst);
      dst_offset += ins_shape[0];
    }
    out_tensor.set_lod(batch_lod);
    out->push_back(out_tensor);
  }
}

}  // namespace reader
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators::reader;
REGISTER_DECORATED_READER_OPERATOR(create_batch_reader,
                                   ops::CreateBatchReaderOp,
                                   ops::CreateBatchReaderOpMaker);

@@ -0,0 +1,110 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/reader/reader_op_registry.h"

namespace paddle {
namespace operators {
namespace reader {

template <typename T>
class RandomDataGenerator : public framework::FileReader {
 public:
  RandomDataGenerator(const std::vector<framework::DDim>& shapes, float min,
                      float max)
      : FileReader(shapes), min_(min), max_(max) {
    PADDLE_ENFORCE_LE(
        min, max, "'min' shouldn't be greater than 'max'.(%f vs %f)", min, max);
    unsigned int seed = std::random_device()();
    engine_.seed(seed);
    dist_ = std::uniform_real_distribution<float>(min_, max_);
  }

  void ReadNext(std::vector<framework::LoDTensor>* out) override {
    out->clear();
    out->reserve(shapes_.size());
    for (const framework::DDim& shape : shapes_) {
      PADDLE_ENFORCE_GE(
          shape.size(), 2,
          "The rank of reader's output data should be 2 at least.(Now it's %d)",
          shape.size());
      framework::LoDTensor out_tensor;
      out_tensor.Resize(shape);
      T* data = out_tensor.mutable_data<T>(platform::CPUPlace());
      int64_t numel = framework::product(shape);
      for (int64_t i = 0; i < numel; ++i) {
        data[i] = dist_(engine_);
      }
      out->push_back(out_tensor);
    }
  }

  bool HasNext() const override { return true; }

  void ReInit() override { return; }

 private:
  float min_;
  float max_;
  std::minstd_rand engine_;
  std::uniform_real_distribution<float> dist_;
};

template <typename T>
class CreateRandomDataGeneratorOp : public framework::OperatorBase {
 public:
  using framework::OperatorBase::OperatorBase;

 private:
  void RunImpl(const framework::Scope& scope,
               const platform::Place& dev_place) const override {
    const auto& shape_concat = Attr<std::vector<int>>("shape_concat");
    const auto& ranks = Attr<std::vector<int>>("ranks");
    PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty());
    PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0),
                      int(shape_concat.size()),
                      "The accumulate of all ranks should be equal to the "
                      "shape concat's length.");
    std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
    auto* out = scope.FindVar(Output("Out"))
                    ->template GetMutable<framework::ReaderHolder>();
    out->Reset(new RandomDataGenerator<T>(shapes, Attr<float>("min"),
                                          Attr<float>("max")));
  }
};

class CreateRandomDataGeneratorOpMaker : public FileReaderMakerBase {
 public:
  CreateRandomDataGeneratorOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
      : FileReaderMakerBase(op_proto, op_checker) {
    AddAttr<float>("min", "The lower bound of reader's uniform distribution.");
    AddAttr<float>("max", "The upper bound of reader's uniform distribution.");
    AddComment(R"DOC(
CreateRandomDataGenerator Operator

This Op creates a random reader.
The reader generates random data instead of really reading from files.
Generated data follow a uniform distribution between 'min' and 'max'.
)DOC");
  }
};

}  // namespace reader
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators::reader;
REGISTER_FILE_READER_OPERATOR(create_random_data_generator,
                              ops::CreateRandomDataGeneratorOp<float>,
                              ops::CreateRandomDataGeneratorOpMaker);

@@ -0,0 +1,97 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/reader/reader_op_registry.h"

namespace paddle {
namespace operators {
namespace reader {

class ShuffleReader : public framework::DecoratedReader {
 public:
  ShuffleReader(ReaderBase* reader, int buffer_size)
      : DecoratedReader(reader), buffer_size_(buffer_size), iteration_pos_(0) {
    buffer_.reserve(buffer_size);
  }

  void ReadNext(std::vector<framework::LoDTensor>* out) override;

 private:
  int buffer_size_;
  std::vector<std::vector<framework::LoDTensor>> buffer_;
  size_t iteration_pos_;
};

void ShuffleReader::ReadNext(std::vector<framework::LoDTensor>* out) {
  if (iteration_pos_ >= buffer_.size()) {
    // Reload buffer with new data
    buffer_.clear();
    buffer_.reserve(buffer_size_);
    for (int i = 0; i < buffer_size_; ++i) {
      if (reader_->HasNext()) {
        buffer_.push_back(std::vector<framework::LoDTensor>());
        reader_->ReadNext(&buffer_.back());
      } else {
        break;
      }
    }
    // TODO(fengjiayi): 'std::random_shuffle' can be very slow. It needs to be
    // optimized.
    std::random_shuffle(buffer_.begin(), buffer_.end());
    iteration_pos_ = 0;
  }
  out->clear();
  if (!buffer_.empty()) {
    std::swap(*out, buffer_[iteration_pos_++]);
  }
  // if buffer_ is empty, the 'out' will return as an empty vector.
}

class CreateShuffleReaderOp : public framework::OperatorBase {
 public:
  using framework::OperatorBase::OperatorBase;

 private:
  void RunImpl(const framework::Scope& scope,
               const platform::Place& dev_place) const override {
    const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
                                        ->Get<framework::ReaderHolder>();
    auto* out = scope.FindVar(Output("Out"))
                    ->template GetMutable<framework::ReaderHolder>();
    out->Reset(
        new ShuffleReader(underlying_reader.Get(), Attr<int>("buffer_size")));
  }
};

class CreateShuffleReaderOpMaker : public DecoratedReaderMakerBase {
 public:
  CreateShuffleReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
      : DecoratedReaderMakerBase(op_proto, op_checker) {
    AddAttr<int>("buffer_size", "The shuffle buffer size.").GreaterThan(0);
    AddComment(R"DOC(
CreateShuffleReader Operator

A shuffle reader takes another reader as its 'underlying reader'
and yields the underlying reader's outputs in a shuffled order.
)DOC");
  }
};
}  // namespace reader
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators::reader;
REGISTER_DECORATED_READER_OPERATOR(create_shuffle_reader,
                                   ops::CreateShuffleReaderOp,
                                   ops::CreateShuffleReaderOpMaker);

@@ -0,0 +1,116 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "reader_op_registry.h"

namespace paddle {
namespace operators {
namespace reader {

std::vector<framework::DDim> RestoreShapes(const std::vector<int>& shape_concat,
                                           const std::vector<int>& ranks) {
  std::vector<framework::DDim> res;
  int offset = 0;
  for (int len : ranks) {
    auto start_it = shape_concat.begin() + offset;
    auto end_it = start_it + len;
    res.push_back(framework::make_ddim(std::vector<int>(start_it, end_it)));
    offset += len;
  }
  return res;
}

FileReaderMakerBase::FileReaderMakerBase(
    framework::OpProtoAndCheckerMaker::OpProto* op_proto,
    framework::OpAttrChecker* op_checker)
    : OpProtoAndCheckerMaker(op_proto, op_checker) {
  AddOutput("Out", "(ReaderHolder) The created random reader.");
  AddAttr<std::vector<int>>("shape_concat", "The concat of all data's shapes.");
  AddAttr<std::vector<int>>(
      "ranks",
      "The ranks of each data."
      "e.g."
      "shape_concat = [2,3,4,5,6]"
      "ranks = [3,2]"
      "It means the reader will generate two data each time,"
      "whose shapes are [2,3,4] and [5,6] respectively.");
  AddAttr<std::vector<int>>("lod_levels", "The LoD levels of each data.");
}

void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
  PADDLE_ENFORCE(ctx->HasOutput("Out"),
                 "The output file reader should not be null.");
  const auto shape_concat = ctx->Attrs().Get<std::vector<int>>("shape_concat");
  const auto ranks = ctx->Attrs().Get<std::vector<int>>("ranks");
  std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
  ctx->SetReaderDims("Out", shapes);

  if (ctx->IsRuntime()) {
    const auto lod_levels = ctx->Attrs().Get<std::vector<int>>("lod_levels");
    PADDLE_ENFORCE_EQ(lod_levels.size(), shapes.size(),
                      "The number of 'lod_levels'(%d) doesn't match the number "
                      "of 'shapes'(%d).",
                      lod_levels.size(), shapes.size());
    framework::VarDesc* reader =
        boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
    reader->SetLoDLevels(lod_levels);
  }
}

void FileReaderInferVarType::operator()(const framework::OpDesc& op_desc,
                                        framework::BlockDesc* block) const {
  std::string reader_name = op_desc.Output("Out")[0];
  framework::VarDesc* reader = block->FindVarRecursive(reader_name);
  reader->SetType(framework::proto::VarType::READER);
}

void DecoratedReaderInferShape::operator()(
    framework::InferShapeContext* ctx) const {
  PADDLE_ENFORCE(ctx->HasInput("UnderlyingReader"),
                 "Input(UnderlyingReader) should not be null.");
  PADDLE_ENFORCE(ctx->HasOutput("Out"),
                 "The output decorated reader should not be null.");
  ctx->SetReaderDims("Out", ctx->GetReaderDims("UnderlyingReader"));

  if (ctx->IsRuntime()) {
    framework::VarDesc* in_reader = boost::get<framework::VarDesc*>(
        ctx->GetInputVarPtrs("UnderlyingReader")[0]);
    framework::VarDesc* out_reader =
        boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
    out_reader->SetLoDLevels(in_reader->GetLoDLevels());
  }
}
void DecoratedReaderInferVarType::operator()(
    const framework::OpDesc& op_desc, framework::BlockDesc* block) const {
  std::string in_reader_name = op_desc.Input("UnderlyingReader")[0];
  framework::VarDesc* in_reader = block->FindVarRecursive(in_reader_name);
  std::string out_reader_name = op_desc.Output("Out")[0];
  framework::VarDesc* out_reader = block->FindVarRecursive(out_reader_name);
  out_reader->SetType(framework::proto::VarType::READER);
  out_reader->SetDataTypes(in_reader->GetDataTypes());
}

DecoratedReaderMakerBase::DecoratedReaderMakerBase(
    framework::OpProtoAndCheckerMaker::OpProto* op_proto,
    framework::OpAttrChecker* op_checker)
    : OpProtoAndCheckerMaker(op_proto, op_checker) {
  AddInput("UnderlyingReader",
           "(ReaderHolder) The underlying reader for creating a batch reader.");
  AddOutput("Out", "(ReaderHolder) The created batch reader.");
}

}  // namespace reader

}  // namespace operators
}  // namespace paddle

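Aside (not part of the diff): the shape_concat / ranks convention that RestoreShapes implements can be checked with a small standalone sketch. It uses plain STL containers instead of framework::DDim, and the helper name SplitByRanks is made up for this illustration; the example values come from the 'ranks' attribute documentation in this commit.

// Standalone sketch: split "shape_concat" into per-output shapes by "ranks".
#include <cassert>
#include <vector>

std::vector<std::vector<int>> SplitByRanks(const std::vector<int>& shape_concat,
                                           const std::vector<int>& ranks) {
  std::vector<std::vector<int>> res;
  int offset = 0;
  for (int len : ranks) {
    // Take the next 'len' entries of shape_concat as one shape.
    res.emplace_back(shape_concat.begin() + offset,
                     shape_concat.begin() + offset + len);
    offset += len;
  }
  return res;
}

int main() {
  // shape_concat = [2,3,4,5,6], ranks = [3,2]
  // -> two outputs per ReadNext, with shapes [2,3,4] and [5,6].
  auto shapes = SplitByRanks({2, 3, 4, 5, 6}, {3, 2});
  assert(shapes.size() == 2);
  assert((shapes[0] == std::vector<int>{2, 3, 4}));
  assert((shapes[1] == std::vector<int>{5, 6}));
  return 0;
}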
@@ -0,0 +1,75 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"

namespace paddle {
namespace operators {
namespace reader {

extern std::vector<framework::DDim> RestoreShapes(
    const std::vector<int>& shape_concat, const std::vector<int>& ranks);

class FileReaderMakerBase : public framework::OpProtoAndCheckerMaker {
 public:
  FileReaderMakerBase(OpProto* op_proto, OpAttrChecker* op_checker);
};

class FileReaderInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext* ctx) const override;
};

class FileReaderInferVarType : public framework::VarTypeInference {
 public:
  void operator()(const framework::OpDesc& op_desc,
                  framework::BlockDesc* block) const override;
};

// general infershape for decorated reader
class DecoratedReaderInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext* ctx) const override;
};

// general var type inference for decorated reader
class DecoratedReaderInferVarType : public framework::VarTypeInference {
 public:
  void operator()(const framework::OpDesc& op_desc,
                  framework::BlockDesc* block) const override;
};

class DecoratedReaderMakerBase : public framework::OpProtoAndCheckerMaker {
 public:
  DecoratedReaderMakerBase(OpProto* op_proto, OpAttrChecker* op_checker);
};

}  // namespace reader
}  // namespace operators
}  // namespace paddle

#define REGISTER_FILE_READER_OPERATOR(op_name, ...)                   \
  REGISTER_OPERATOR(op_name, __VA_ARGS__,                             \
                    paddle::operators::reader::FileReaderInferShape,  \
                    paddle::framework::EmptyGradOpMaker,              \
                    paddle::operators::reader::FileReaderInferVarType)

#define REGISTER_DECORATED_READER_OPERATOR(op_name, ...)                   \
  REGISTER_OPERATOR(op_name, __VA_ARGS__,                                  \
                    paddle::operators::reader::DecoratedReaderInferShape,  \
                    paddle::framework::EmptyGradOpMaker,                   \
                    paddle::operators::reader::DecoratedReaderInferVarType)
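Aside (not part of the diff): with these maker bases and macros in place, a further decorated reader only needs the reader class, its create-op, and a maker derived from DecoratedReaderMakerBase. The sketch below is hypothetical (an "identity" pass-through reader that does not exist in this commit, with invented names) and builds only inside the Paddle source tree; it simply mirrors the pattern of create_shuffle_reader_op.cc and create_batch_reader_op.cc above.

// Hypothetical example following the new registry pattern; not part of the commit.
#include "paddle/fluid/operators/reader/reader_op_registry.h"

namespace paddle {
namespace operators {
namespace reader {

class IdentityReader : public framework::DecoratedReader {
 public:
  explicit IdentityReader(ReaderBase* reader) : DecoratedReader(reader) {}

  // Forwards each batch of the underlying reader unchanged.
  void ReadNext(std::vector<framework::LoDTensor>* out) override {
    reader_->ReadNext(out);
  }
};

class CreateIdentityReaderOp : public framework::OperatorBase {
 public:
  using framework::OperatorBase::OperatorBase;

 private:
  void RunImpl(const framework::Scope& scope,
               const platform::Place& dev_place) const override {
    const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
                                        ->Get<framework::ReaderHolder>();
    auto* out = scope.FindVar(Output("Out"))
                    ->template GetMutable<framework::ReaderHolder>();
    out->Reset(new IdentityReader(underlying_reader.Get()));
  }
};

class CreateIdentityReaderOpMaker : public DecoratedReaderMakerBase {
 public:
  CreateIdentityReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
      : DecoratedReaderMakerBase(op_proto, op_checker) {
    AddComment("A decorated reader that forwards its underlying reader's output.");
  }
};

}  // namespace reader
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators::reader;
REGISTER_DECORATED_READER_OPERATOR(create_identity_reader,
                                   ops::CreateIdentityReaderOp,
                                   ops::CreateIdentityReaderOpMaker);

The DecoratedReaderMakerBase constructor already declares the 'UnderlyingReader' input and the 'Out' output, so a derived maker only adds op-specific attributes and its comment; a matching op_library(...) entry with DEPS reader_op_registry would also be added to the new CMakeLists.txt.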