Merge pull request #8009 from JiayiFeng/dev_reader
Fundamental Data Reading in C++emailweixu-patch-1
commit
812cf15196
@ -0,0 +1,122 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/framework/reader.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
// Returns the declared static shape of the reader's idx-th output.
// Enforces that 'idx' is a valid index into 'shapes_'.
DDim ReaderBase::shape(size_t idx) const {
  PADDLE_ENFORCE_LT(
      idx, shapes_.size(),
      "Cannot get the %d'th shape, 'shapes_' only has %d elements.", idx,
      shapes_.size());
  return shapes_[idx];
}
|
||||
|
||||
void ShuffleReader::ReadNext(std::vector<LoDTensor>* out) {
|
||||
if (iteration_pos_ >= buffer_.size()) {
|
||||
// Reload buffer with new data
|
||||
buffer_.clear();
|
||||
buffer_.reserve(buffer_size_);
|
||||
for (int i = 0; i < buffer_size_; ++i) {
|
||||
if (reader_->HasNext()) {
|
||||
buffer_.push_back(std::vector<LoDTensor>());
|
||||
reader_->ReadNext(&buffer_.back());
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
// TODO(fengjiayi): 'std::random_shuffle' can be very slow. It needs to be
|
||||
// optimize.
|
||||
std::random_shuffle(buffer_.begin(), buffer_.end());
|
||||
iteration_pos_ = 0;
|
||||
}
|
||||
out->clear();
|
||||
if (!buffer_.empty()) {
|
||||
std::swap(*out, buffer_[iteration_pos_++]);
|
||||
}
|
||||
// if buffer_ is empty, the 'out' will return as an empty vector.
|
||||
}
|
||||
|
||||
// Reads up to 'batch_size_' instances from the underlying reader and
// concatenates them along dimension 0 into one batch of LoDTensors,
// merging each instance's LoD and adding a new top LoD level that marks
// instance boundaries inside the batch.
void BatchReader::ReadNext(std::vector<LoDTensor>* out) {
  buffer_.clear();
  buffer_.reserve(batch_size_);
  for (int i = 0; i < batch_size_; ++i) {
    if (reader_->HasNext()) {
      buffer_.push_back(std::vector<LoDTensor>());
      reader_->ReadNext(&buffer_.back());
    } else {
      // Underlying reader ran dry; yield a smaller final batch.
      break;
    }
  }
  // Concat instances
  out->clear();
  if (buffer_.empty()) {
    // if buffer_ is empty, the 'out' will return as an empty vector.
    return;
  }
  // Number of LoDTensors each instance carries.
  int out_num = buffer_[0].size();
  out->reserve(out_num);
  for (int j = 0; j < out_num; ++j) {
    // Merge shape and check data type
    std::type_index batch_type = buffer_[0][j].type();
    DDim batch_shape = buffer_[0][j].dims();
    for (size_t i = 1; i < buffer_.size(); ++i) {
      std::type_index ins_type = buffer_[i][j].type();
      DDim ins_shape = buffer_[i][j].dims();
      // All instances must agree on dtype and trailing dims; dim 0 is the
      // concat axis and accumulates.
      PADDLE_ENFORCE_EQ(batch_type, ins_type);
      PADDLE_ENFORCE_EQ(slice_ddim(batch_shape, 1, batch_shape.size()),
                        slice_ddim(ins_shape, 1, ins_shape.size()));
      PADDLE_ENFORCE_GT(ins_shape[0], 0);
      batch_shape[0] += ins_shape[0];
    }

    LoDTensor out_tensor;
    out_tensor.Resize(batch_shape);
    out_tensor.mutable_data(platform::CPUPlace(), batch_type);
    // Write cursor along dim 0 of 'out_tensor'.
    int64_t dst_offset = 0;

    // Merge lod and data
    LoD batch_lod;
    std::vector<size_t> top_level_lod({0});
    for (size_t i = 0; i < buffer_.size(); ++i) {
      DDim ins_shape = buffer_[i][j].dims();
      LoD ins_lod = buffer_[i][j].lod();
      if (i == 0) {
        batch_lod = ins_lod;
      } else {
        PADDLE_ENFORCE_EQ(batch_lod.size(), ins_lod.size());
        for (size_t level_idx = 0; level_idx < batch_lod.size(); ++level_idx) {
          auto& lod_level = batch_lod[level_idx];
          // Shift this instance's offsets by the length accumulated so far
          // (skip element 0, which is always the leading 0).
          for (size_t k = 1; k < ins_lod[level_idx].size(); ++k) {
            lod_level.push_back(ins_lod[level_idx][k] + lod_level.back());
          }
        }
      }
      // New top level: one entry per instance. Without an LoD, an instance
      // spans ins_shape[0] rows; with one, it spans its top-level segments.
      top_level_lod.push_back(
          top_level_lod.back() +
          (ins_lod.empty() ? ins_shape[0] : (ins_lod[0].size() - 1)));

      Tensor dst = out_tensor.Slice(dst_offset, dst_offset + ins_shape[0]);
      Copy(buffer_[i][j], platform::CPUPlace(), &dst);
      dst_offset += ins_shape[0];
    }
    batch_lod.insert(batch_lod.begin(), top_level_lod);
    out_tensor.set_lod(batch_lod);
    out->push_back(out_tensor);
  }
}
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,161 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/framework/ddim.h"
|
||||
#include "paddle/framework/lod_tensor_array.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
// Abstract base for all readers. Each ReadNext() call yields one
// "instance": a vector of LoDTensors whose static shapes are described
// by 'shapes_'.
class ReaderBase {
 public:
  explicit ReaderBase(const std::vector<DDim>& shapes) : shapes_(shapes) {
    // A reader must declare at least one output shape.
    PADDLE_ENFORCE(!shapes_.empty());
  }
  // Fills '*out' with the next instance's tensors.
  virtual void ReadNext(std::vector<LoDTensor>* out) = 0;
  // Whether another instance is available.
  virtual bool HasNext() const = 0;

  // Resets the reader so reading can start over from the beginning.
  virtual void ReInit() = 0;

  // Shape of the idx-th output tensor (bounds-checked; see reader.cc).
  DDim shape(size_t idx) const;
  std::vector<DDim> shapes() const { return shapes_; }
  void set_shapes(const std::vector<DDim>& shapes) { shapes_ = shapes; }

  virtual ~ReaderBase() {}

 protected:
  std::vector<DDim> shapes_;
};
|
||||
|
||||
// Base for readers that produce data themselves (e.g. from files or a
// generator) instead of decorating another reader.
class FileReader : public ReaderBase {
 public:
  explicit FileReader(const std::vector<DDim>& shapes) : ReaderBase(shapes) {}
};
|
||||
|
||||
// Base for readers that wrap an underlying reader and transform its
// output (decorator pattern). Adopts the underlying reader's shapes and
// forwards HasNext()/ReInit() by default.
class DecoratedReader : public ReaderBase {
 public:
  explicit DecoratedReader(ReaderBase* reader)
      : ReaderBase(reader->shapes()), reader_(reader) {
    PADDLE_ENFORCE_NOT_NULL(reader_);
  }

  bool HasNext() const override { return reader_->HasNext(); }

  void ReInit() override { reader_->ReInit(); }

 protected:
  // Non-owning; NOTE(review): the pointee appears to be owned by a
  // ReaderHolder elsewhere — confirm lifetime outlives this decorator.
  ReaderBase* reader_;
};
|
||||
|
||||
// file readers
|
||||
|
||||
// A FileReader that fabricates uniformly distributed random data rather
// than reading real files. Useful for exercising the reader pipeline.
template <typename T>
class RandomDataGenerator : public FileReader {
 public:
  RandomDataGenerator(const std::vector<DDim>& shapes, float min, float max)
      : FileReader(shapes), min_(min), max_(max) {
    PADDLE_ENFORCE_LE(
        min, max, "'min' shouldn't be greater than 'max'.(%f vs %f)", min, max);
    // Non-deterministic seed so each generator produces different data.
    unsigned int seed = std::random_device()();
    engine_.seed(seed);
    dist_ = std::uniform_real_distribution<float>(min_, max_);
  }

  // Generates one random LoDTensor per declared shape.
  void ReadNext(std::vector<LoDTensor>* out) override {
    out->clear();
    out->reserve(shapes_.size());
    for (const DDim& shape : shapes_) {
      PADDLE_ENFORCE_GE(
          shape.size(), 2,
          "The rank of reader's output data should be 2 at least.(Now it's %d)",
          shape.size());
      LoDTensor out_tensor;
      out_tensor.Resize(shape);
      T* data = out_tensor.mutable_data<T>(platform::CPUPlace());
      int64_t numel = product(shape);
      for (int64_t i = 0; i < numel; ++i) {
        data[i] = dist_(engine_);
      }
      out->push_back(out_tensor);
    }
  }

  // An endless source: there is always another instance.
  bool HasNext() const override { return true; }

  // Nothing to reset — data is generated on demand.
  void ReInit() override { return; }

 private:
  float min_;
  float max_;
  std::minstd_rand engine_;
  std::uniform_real_distribution<float> dist_;
};
|
||||
|
||||
// decorated readers
|
||||
|
||||
// A DecoratedReader that buffers up to 'buffer_size_' instances from the
// underlying reader and yields them in random order (see reader.cc).
class ShuffleReader : public DecoratedReader {
 public:
  ShuffleReader(ReaderBase* reader, int buffer_size)
      : DecoratedReader(reader), buffer_size_(buffer_size), iteration_pos_(0) {
    buffer_.reserve(buffer_size);
  }

  void ReadNext(std::vector<LoDTensor>* out) override;

 private:
  int buffer_size_;  // max instances fetched per refill
  std::vector<std::vector<LoDTensor>> buffer_;  // shuffled instance pool
  size_t iteration_pos_;  // next index in 'buffer_' to yield
};
|
||||
|
||||
// A DecoratedReader that gathers up to 'batch_size_' instances from the
// underlying reader and concatenates them into one batch (see reader.cc).
class BatchReader : public DecoratedReader {
 public:
  BatchReader(ReaderBase* reader, int batch_size)
      : DecoratedReader(reader), batch_size_(batch_size) {
    buffer_.reserve(batch_size_);
  }

  void ReadNext(std::vector<LoDTensor>* out) override;

 private:
  int batch_size_;  // instances per output batch
  std::vector<std::vector<LoDTensor>> buffer_;  // instances gathered so far
};
|
||||
|
||||
// The ReaderHolder is used as readers' unified wrapper,
// making it easier to access different type readers in Variables.
class ReaderHolder {
 public:
  // Takes ownership of 'reader'; any previously held reader is destroyed.
  void Reset(ReaderBase* reader) { reader_.reset(reader); }

  // Borrowed pointer; the holder retains ownership.
  ReaderBase* Get() const { return reader_.get(); }

  // The calls below simply forward to the wrapped reader.
  void ReadNext(std::vector<LoDTensor>* out) { reader_->ReadNext(out); }
  bool HasNext() const { return reader_->HasNext(); }
  void ReInit() { reader_->ReInit(); }

  DDim shape(size_t idx) const { return reader_->shape(idx); }
  std::vector<DDim> shapes() const { return reader_->shapes(); }
  void set_shapes(const std::vector<DDim>& shapes) {
    reader_->set_shapes(shapes);
  }

 private:
  std::unique_ptr<ReaderBase> reader_;
};
|
||||
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,205 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/framework/op_registry.h"
|
||||
#include "paddle/framework/reader.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
static std::vector<framework::DDim> RestoreShapes(
|
||||
const std::vector<int>& shape_concat, const std::vector<int>& ranks) {
|
||||
std::vector<framework::DDim> res;
|
||||
int offset = 0;
|
||||
for (int len : ranks) {
|
||||
auto start_it = shape_concat.begin() + offset;
|
||||
auto end_it = start_it + len;
|
||||
res.push_back(framework::make_ddim(std::vector<int>(start_it, end_it)));
|
||||
offset += len;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
// general infershape for file readers
// File readers have no reader input, so the output reader's dims are
// reconstructed from the 'shape_concat'/'ranks' attributes.
class CreateFileReaderInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "The output file reader should not be null.");
    const auto shape_concat =
        ctx->Attrs().Get<std::vector<int>>("shape_concat");
    const auto ranks = ctx->Attrs().Get<std::vector<int>>("ranks");
    std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
    ctx->SetReaderDims("Out", shapes);
  }
};
|
||||
|
||||
// general infershape for decorated readers
// Decorated readers simply inherit the dims of their underlying reader.
class CreateDecoratedReaderInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("UnderlyingReader"),
                   "Input(UnderlyingReader) should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "The output decorated reader should not be null.");
    ctx->SetReaderDims("Out", ctx->GetReaderDims("UnderlyingReader"));
  }
};
|
||||
|
||||
// general var type inference for all readers
// Marks the op's output variable as a READER.
class CreateReaderInferVarType : public framework::VarTypeInference {
 public:
  void operator()(const framework::OpDesc& op_desc,
                  framework::BlockDesc* block) const override {
    std::string reader_name = op_desc.Output("Out")[0];
    framework::VarDesc* reader = block->FindVarRecursive(reader_name);
    reader->SetType(framework::proto::VarDesc::READER);
  }
};
|
||||
|
||||
template <typename T>
|
||||
class CreateRandomDataGeneratorOp : public framework::OperatorBase {
|
||||
public:
|
||||
using framework::OperatorBase::OperatorBase;
|
||||
void Run(const framework::Scope& scope,
|
||||
const platform::Place& dev_place) const override {
|
||||
const auto& shape_concat = Attr<std::vector<int>>("shape_concat");
|
||||
const auto& ranks = Attr<std::vector<int>>("ranks");
|
||||
PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty());
|
||||
PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0),
|
||||
int(shape_concat.size()),
|
||||
"The accumulate of all ranks should be equal to the "
|
||||
"shape concat's length.");
|
||||
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
|
||||
auto* out = scope.FindVar(Output("Out"))
|
||||
->template GetMutable<framework::ReaderHolder>();
|
||||
out->Reset(new framework::RandomDataGenerator<T>(shapes, Attr<float>("min"),
|
||||
Attr<float>("max")));
|
||||
}
|
||||
};
|
||||
|
||||
// Op proto/attribute definition for 'create_random_data_generator'.
class CreateRandomDataGeneratorOpMaker
    : public framework::OpProtoAndCheckerMaker {
 public:
  CreateRandomDataGeneratorOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(op_proto, op_checker) {
    AddOutput("Out", "(ReaderHolder) The created random reader.");
    AddAttr<std::vector<int>>("shape_concat",
                              "The concat of all data's shapes.");
    AddAttr<std::vector<int>>(
        "ranks",
        "The ranks of each data."
        "e.g."
        "shape_concat = [2,3,4,5,6]"
        "ranks = [3,2]"
        "It means the reader will generate two data each time,"
        "whose shapes are [2,3,4] and [5,6] respectively.");
    AddAttr<float>("min", "The lower bound of reader's uniform distribution.");
    AddAttr<float>("max", "The upper bound of reader's uniform distribution.");
    AddComment(R"DOC(
CreateRandomDataGenerator Operator

This Op creates a random reader.
The reader generates random data instead of really reading from files.
Generated data follow an uniform distribution between 'min' and 'max'.
)DOC");
  }
};
|
||||
|
||||
// Creates a ShuffleReader wrapping the input reader and stores it in the
// output ReaderHolder variable.
class CreateShuffleReaderOp : public framework::OperatorBase {
 public:
  using framework::OperatorBase::OperatorBase;
  void Run(const framework::Scope& scope,
           const platform::Place& dev_place) const override {
    const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
                                        ->Get<framework::ReaderHolder>();
    auto* out = scope.FindVar(Output("Out"))
                    ->template GetMutable<framework::ReaderHolder>();
    out->Reset(new framework::ShuffleReader(underlying_reader.Get(),
                                            Attr<int>("buffer_size")));
  }
};
|
||||
|
||||
// Op proto/attribute definition for 'create_shuffle_reader'.
class CreateShuffleReaderOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  CreateShuffleReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(op_proto, op_checker) {
    AddInput(
        "UnderlyingReader",
        "(ReaderHolder) The underlying reader for creating a shuffle reader.");
    AddOutput("Out", "(ReaderHolder) The created shuffle reader.");
    AddAttr<int>("buffer_size", "The shuffle buffer size.").GreaterThan(0);
    AddComment(R"DOC(
CreateShuffleReader Operator

A shuffle reader takes another reader as its 'underlying reader'
and yields the underlying reader's outputs in a shuffled order.
)DOC");
  }
};
|
||||
|
||||
// Creates a BatchReader wrapping the input reader and stores it in the
// output ReaderHolder variable.
class CreateBatchReaderOp : public framework::OperatorBase {
 public:
  using framework::OperatorBase::OperatorBase;
  void Run(const framework::Scope& scope,
           const platform::Place& dev_place) const override {
    const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
                                        ->Get<framework::ReaderHolder>();
    auto* out = scope.FindVar(Output("Out"))
                    ->template GetMutable<framework::ReaderHolder>();
    out->Reset(new framework::BatchReader(underlying_reader.Get(),
                                          Attr<int>("batch_size")));
  }
};
|
||||
|
||||
// Op proto/attribute definition for 'create_batch_reader'.
class CreateBatchReaderOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  CreateBatchReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(op_proto, op_checker) {
    AddInput(
        "UnderlyingReader",
        "(ReaderHolder) The underlying reader for creating a batch reader.");
    AddOutput("Out", "(ReaderHolder) The created batch reader.");
    AddAttr<int>("batch_size",
                 "How many instances the batch reader yields each time.")
        .GreaterThan(0);
    AddComment(R"DOC(
CreateBatchReader Operator

A batch reader takes another reader as its 'underlying reader',
gathers the underlying reader's outputs and then yields them in batches.
)DOC");
  }
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
// All reader-creation ops share CreateReaderInferVarType and have no
// gradient (EmptyGradOpMaker).
REGISTER_OPERATOR(create_random_data_generator,
                  ops::CreateRandomDataGeneratorOp<float>,
                  ops::CreateFileReaderInferShape,
                  ops::CreateRandomDataGeneratorOpMaker,
                  paddle::framework::EmptyGradOpMaker,
                  ops::CreateReaderInferVarType);
REGISTER_OPERATOR(create_shuffle_reader, ops::CreateShuffleReaderOp,
                  ops::CreateDecoratedReaderInferShape,
                  ops::CreateShuffleReaderOpMaker,
                  paddle::framework::EmptyGradOpMaker,
                  ops::CreateReaderInferVarType);
REGISTER_OPERATOR(create_batch_reader, ops::CreateBatchReaderOp,
                  ops::CreateDecoratedReaderInferShape,
                  ops::CreateBatchReaderOpMaker,
                  paddle::framework::EmptyGradOpMaker,
                  ops::CreateReaderInferVarType);
|
@ -0,0 +1,99 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/framework/op_registry.h"
|
||||
#include "paddle/framework/reader.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
// Shape inference for the read op: each output takes the corresponding
// static shape declared by the reader input.
class ReadInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("Reader"),
                   "The ReadOp must take a reader as input.");
    PADDLE_ENFORCE(ctx->HasOutputs("Out"),
                   "The ReadOp should be assigned with output.");
    std::vector<framework::DDim> reader_dims = ctx->GetReaderDims("Reader");
    std::vector<std::string> out_names = ctx->Outputs("Out");
    PADDLE_ENFORCE_EQ(
        reader_dims.size(), out_names.size(),
        "The reader's dim number doesn't match the output number.");
    ctx->SetOutputsDim("Out", reader_dims);
  }
};
|
||||
|
||||
// Var-type inference for the read op: each output becomes a LOD_TENSOR
// with the dtype the reader variable declares for that position.
class ReadInferVarType : public framework::VarTypeInference {
 public:
  void operator()(const framework::OpDesc& op_desc,
                  framework::BlockDesc* block) const override {
    std::string reader_name = op_desc.Input("Reader")[0];
    std::vector<std::string> out_names = op_desc.Output("Out");
    framework::VarDesc* reader = block->FindVarRecursive(reader_name);
    auto dtypes = reader->GetDataTypes();
    // One declared dtype per output variable.
    PADDLE_ENFORCE_EQ(dtypes.size(), out_names.size());
    for (size_t i = 0; i < dtypes.size(); ++i) {
      framework::VarDesc& out = block->FindRecursiveOrCreateVar(out_names[i]);
      out.SetType(framework::proto::VarDesc::LOD_TENSOR);
      out.SetDataType(dtypes[i]);
    }
  }
};
|
||||
|
||||
// Executes a reader once: pulls the next instance from the input
// ReaderHolder and shares its tensors into the output variables.
class ReadOp : public framework::OperatorBase {
 public:
  using framework::OperatorBase::OperatorBase;
  void Run(const framework::Scope& scope,
           const platform::Place& dev_place) const override {
    framework::ReaderHolder* reader =
        scope.FindVar(Input("Reader"))->GetMutable<framework::ReaderHolder>();
    if (!reader->HasNext()) {
      // An exhausted reader is re-initialized so the next epoch can start.
      reader->ReInit();
      PADDLE_ENFORCE(
          reader->HasNext(),
          "Reader can not read the next data even it has been re-initialized.");
    }
    std::vector<std::string> out_arg_names = Outputs("Out");
    std::vector<framework::LoDTensor> ins;
    reader->ReadNext(&ins);
    PADDLE_ENFORCE_EQ(ins.size(), out_arg_names.size());
    for (size_t i = 0; i < ins.size(); ++i) {
      auto* out =
          scope.FindVar(out_arg_names[i])->GetMutable<framework::LoDTensor>();
      // Share the underlying storage instead of copying; propagate LoD.
      out->ShareDataWith(ins[i]);
      out->set_lod(ins[i].lod());
    }
  }
};
|
||||
|
||||
// Op proto definition for the read op.
class ReadOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  ReadOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(op_proto, op_checker) {
    AddInput("Reader", "(ReaderHolder) The executed reader.");
    AddOutput("Out", "(LoDTensor) The output data.").AsDuplicable();
    AddComment(R"DOC(
Read Operator

Execute a given reader once and output data.
)DOC");
  }
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
// The read op has no gradient (EmptyGradOpMaker).
REGISTER_OPERATOR(read, ops::ReadOp, ops::ReadInferShape, ops::ReadOpMaker,
                  paddle::framework::EmptyGradOpMaker, ops::ReadInferVarType);
|
@ -0,0 +1,62 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import paddle.v2 as paddle
|
||||
import paddle.v2.fluid as fluid
|
||||
import numpy as np
|
||||
|
||||
# Manually builds a minimal fluid Program that chains a
# 'create_random_data_generator' op into a 'read' op, runs it on CPU, and
# exits non-zero if either fetched output comes back empty.
prog = fluid.framework.Program()
block = prog.current_block()

# Reader variable that will hold the RandomDataGenerator.
random_reader = block.create_var(
    type=fluid.core.VarDesc.VarType.READER, name="RandomDataGenerator")
# One lod level entry per reader output (both flat here).
random_reader.desc.set_lod_levels([0, 0])

# shape_concat=[1,2,1,1] with ranks=[2,2] -> two outputs per instance,
# shaped [1,2] and [1,1], sampled uniformly from [0.0, 1.0].
create_random_data_generator_op = block.append_op(
    type="create_random_data_generator",
    outputs={"Out": random_reader},
    attrs={
        "shape_concat": [1, 2, 1, 1],
        "ranks": [2, 2],
        "min": 0.0,
        "max": 1.0
    })

# Destination LoDTensor variables for the read op's outputs.
out1 = block.create_var(
    type=fluid.core.VarDesc.VarType.LOD_TENSOR,
    name="Out1",
    shape=[10, 2],
    dtype="float32",
    lod_level=1)
out2 = block.create_var(
    type=fluid.core.VarDesc.VarType.LOD_TENSOR,
    name="Out2",
    shape=[10, 1],
    dtype="float32",
    lod_level=1)

# The read op pulls one instance from the reader into Out1/Out2.
read_op = block.append_op(
    type="read",
    inputs={"Reader": random_reader},
    outputs={"Out": [out1, out2]})

place = fluid.CPUPlace()
exe = fluid.Executor(place)

[res1, res2] = exe.run(prog, fetch_list=[out1, out2])

# Non-zero exit code signals failure to the test harness.
if len(res1) == 0 or len(res2) == 0:
    exit(1)

exit(0)
|
Loading…
Reference in new issue