commit cd8700f172

@@ -0,0 +1,187 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"

namespace paddle {
namespace operators {
namespace reader {

class CustomReader : public framework::DecoratedReader {
 public:
  CustomReader(ReaderBase* reader, const framework::BlockDesc& sub_block,
               const platform::Place& dev_place,
               const std::vector<std::string>& source_var_names,
               const std::vector<std::string>& sink_var_names)
      : DecoratedReader(reader),
        program_(*sub_block.Program()),
        sub_block_id_(sub_block.ID()),
        exe_(framework::Executor(dev_place)),
        source_var_names_(source_var_names),
        sink_var_names_(sink_var_names) {}

  void ReadNext(std::vector<framework::LoDTensor>* out) override;

 private:
  const framework::ProgramDesc program_;
  int sub_block_id_;
  framework::Executor exe_;

  std::vector<std::string> source_var_names_;
  std::vector<std::string> sink_var_names_;
};

class CreateCustomReaderOp : public framework::OperatorBase {
 public:
  using framework::OperatorBase::OperatorBase;

 private:
  void RunImpl(const framework::Scope& scope,
               const platform::Place& dev_place) const override {
    auto* out = scope.FindVar(Output("Out"))
                    ->template GetMutable<framework::ReaderHolder>();
    auto* sub_block = Attr<framework::BlockDesc*>("sub_block");
    if (out->Get() != nullptr) {
      return;
    }
    const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
                                        ->Get<framework::ReaderHolder>();
    out->Reset(
        new CustomReader(underlying_reader.Get(), *sub_block, dev_place,
                         Attr<std::vector<std::string>>("source_var_names"),
                         Attr<std::vector<std::string>>("sink_var_names")));
  }
};

class CreateCustomReaderOpMaker : public DecoratedReaderMakerBase {
 protected:
  void Apply() override {
    AddAttr<framework::BlockDesc*>(
        "sub_block", "The block to hold all preprocessing operators.");
    AddAttr<std::vector<std::string>>(
        "source_var_names",
        "Source variables are the starting points of data preprocessing. They "
        "hold preprocessing's input tensors. Each source variable corresponds "
        "to one of the underlying reader's outputs.");
    AddAttr<std::vector<std::string>>(
        "sink_var_names",
        "Sink variables are the ending points of data preprocessing. They "
        "hold preprocessing's output tensors. Each sink variable corresponds "
        "to one of the custom reader's outputs.");
    AddComment(R"DOC(
      CreateCustomReader Operator

      A custom reader can be used for input data preprocessing.
      A custom reader holds its own sub-block, which will be executed in its
      'ReadNext()' function. Users can configure their own preprocessing
      pipelines by inserting operators into the custom reader's sub-block.
    )DOC");
  }
};

class CustomReaderInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(!ctx->IsRuntime(),
                   "'CustomReaderInferShape' should only be invoked during "
                   "compile time.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "The output decorated reader should not be null.");
    const auto* sub_block =
        ctx->Attrs().Get<framework::BlockDesc*>("sub_block");
    const auto sink_var_names =
        ctx->Attrs().Get<std::vector<std::string>>("sink_var_names");
    std::vector<std::vector<int64_t>> res_dims;
    std::vector<int32_t> res_lod_levels;
    for (const std::string& var_name : sink_var_names) {
      auto* sink_var = sub_block->FindVar(var_name);
      PADDLE_ENFORCE_NOT_NULL(sink_var);
      res_dims.emplace_back(sink_var->GetShape());
      res_lod_levels.push_back(sink_var->GetLoDLevel());
    }
    auto* out_reader =
        boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
    out_reader->SetShapes(res_dims);
    out_reader->SetLoDLevels(res_lod_levels);
  }
};

class CustomReaderInferVarType : public framework::VarTypeInference {
 public:
  void operator()(const framework::OpDesc& op_desc,
                  framework::BlockDesc* block) const override {
    framework::VarDesc* out_reader = block->FindVar(op_desc.Output("Out")[0]);
    PADDLE_ENFORCE_NOT_NULL(out_reader);
    out_reader->SetType(framework::proto::VarType::READER);

    auto sink_var_names =
        boost::get<std::vector<std::string>>(op_desc.GetAttr("sink_var_names"));
    const auto* sub_block =
        boost::get<framework::BlockDesc*>(op_desc.GetAttr("sub_block"));
    std::vector<framework::proto::VarType::Type> res_data_types;
    for (const std::string& var_name : sink_var_names) {
      framework::VarDesc* var = sub_block->FindVar(var_name);
      PADDLE_ENFORCE_NOT_NULL(var);
      res_data_types.emplace_back(var->GetDataType());
    }
    out_reader->SetDataTypes(res_data_types);
  }
};

void CustomReader::ReadNext(std::vector<framework::LoDTensor>* out) {
  out->clear();
  std::vector<framework::LoDTensor> underlying_outs;
  reader_->ReadNext(&underlying_outs);
  if (underlying_outs.empty()) {
    // There is no next data.
    return;
  }
  PADDLE_ENFORCE(source_var_names_.size() == underlying_outs.size(),
                 "The size of source_var_names(%d) and the size of "
                 "underlying_outs(%d) are not consistent. Each feeding element "
                 "must have its own source variable.",
                 source_var_names_.size(), underlying_outs.size());
  // The scope for CustomReader's sub-block should be independent and shouldn't
  // be any other computation scope's child. Otherwise, data preprocessing and
  // computation cannot be concurrent.
  framework::Scope scope;
  // 1. Copy LoDTensors from the underlying reader's output to the source
  //    variables.
  for (size_t i = 0; i < source_var_names_.size(); ++i) {
    framework::Variable* var = scope.Var(source_var_names_[i]);
    framework::LoDTensor* tensor = var->GetMutable<framework::LoDTensor>();
    tensor->ShareDataWith(underlying_outs[i]);
    tensor->set_lod(underlying_outs[i].lod());
  }
  // 2. Run the sub-block.
  exe_.Run(program_, &scope, sub_block_id_, false, true);
  // 3. Copy LoDTensors from the sink variables to out.
  out->resize(sink_var_names_.size());
  for (size_t i = 0; i < sink_var_names_.size(); ++i) {
    const auto& tensor = detail::Ref(scope.FindVar(sink_var_names_[i]))
                             .Get<framework::LoDTensor>();
    framework::TensorCopySync(tensor, platform::CPUPlace(), &(*out)[i]);
  }
}

}  // namespace reader
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators::reader;
REGISTER_OPERATOR(create_custom_reader, ops::CreateCustomReaderOp,
                  ops::CreateCustomReaderOpMaker, ops::CustomReaderInferShape,
                  ops::CustomReaderInferVarType,
                  paddle::framework::EmptyGradOpMaker)
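For context, the Python-side entry point for this operator is the fluid.layers.io.Preprocessor layer, which (as exercised by the test below) fills in the sub_block, source_var_names, and sink_var_names attributes that CreateCustomReaderOp consumes. A minimal sketch of that pattern, distilled from the accompanying test and assuming the same RecordIO file the test writes in setUp():

    import paddle.fluid as fluid

    # Open the RecordIO file produced by the test's setUp(); shapes and
    # dtypes must match what the file was written with.
    data_file = fluid.layers.io.open_recordio_file(
        './mnist_for_preprocessor_test.recordio',
        shapes=[[-1, 784], [-1, 1]],
        lod_levels=[0, 0],
        dtypes=['float32', 'int64'])

    # Everything added inside block() becomes the sub-block that
    # CustomReader::ReadNext() executes for each batch.
    preprocessor = fluid.layers.io.Preprocessor(reader=data_file)
    with preprocessor.block():
        img, lbl = preprocessor.inputs()  # source variables, fed from the underlying reader
        img_out = img / 2
        lbl_out = lbl + 1
        preprocessor.outputs(img_out, lbl_out)  # sink variables, copied out as the reader's output

    # The decorated reader is then consumed like any other file reader.
    img, lbl = fluid.layers.io.read_file(preprocessor())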
@@ -0,0 +1,93 @@

# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np

import paddle.fluid as fluid
import paddle.v2 as paddle
import paddle.v2.dataset.mnist as mnist


class TestPreprocessor(unittest.TestCase):
    def setUp(self):
        with fluid.program_guard(fluid.Program(), fluid.Program()):
            reader = paddle.batch(mnist.train(), batch_size=32)
            feeder = fluid.DataFeeder(
                feed_list=[  # order is image and label
                    fluid.layers.data(
                        name='image', shape=[784]),
                    fluid.layers.data(
                        name='label', shape=[1], dtype='int64'),
                ],
                place=fluid.CPUPlace())
            self.num_batches = fluid.recordio_writer.convert_reader_to_recordio_file(
                './mnist_for_preprocessor_test.recordio', reader, feeder)

    def test_main(self):
        N = 10

        img_expected_res = []
        lbl_expected_res = []
        with fluid.program_guard(fluid.Program(), fluid.Program()):
            data_file = fluid.layers.io.open_recordio_file(
                './mnist_for_preprocessor_test.recordio',
                shapes=[[-1, 784], [-1, 1]],
                lod_levels=[0, 0],
                dtypes=['float32', 'int64'])
            img, lbl = fluid.layers.io.read_file(data_file)

            if fluid.core.is_compiled_with_cuda():
                place = fluid.CUDAPlace(0)
            else:
                place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            exe.run(fluid.default_startup_program())
            for _ in range(N):
                img_v, lbl_v = exe.run(fetch_list=[img, lbl])
                img_expected_res.append(img_v / 2)
                lbl_expected_res.append(lbl_v + 1)

        img_actual_res = []
        lbl_actual_res = []
        with fluid.program_guard(fluid.Program(), fluid.Program()):
            data_file = fluid.layers.io.open_recordio_file(
                './mnist_for_preprocessor_test.recordio',
                shapes=[[-1, 784], [-1, 1]],
                lod_levels=[0, 0],
                dtypes=['float32', 'int64'])
            preprocessor = fluid.layers.io.Preprocessor(reader=data_file)
            with preprocessor.block():
                img, lbl = preprocessor.inputs()
                img_out = img / 2
                lbl_out = lbl + 1
                preprocessor.outputs(img_out, lbl_out)

            data_file = fluid.layers.io.double_buffer(preprocessor())
            img, lbl = fluid.layers.io.read_file(data_file)

            if fluid.core.is_compiled_with_cuda():
                place = fluid.CUDAPlace(0)
            else:
                place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            exe.run(fluid.default_startup_program())
            for _ in range(N):
                img_v, lbl_v = exe.run(fetch_list=[img, lbl])
                img_actual_res.append(img_v)
                lbl_actual_res.append(lbl_v)

        for idx in range(N):
            self.assertTrue(np.allclose(img_expected_res[idx], img_actual_res[idx]))
            self.assertTrue(np.allclose(lbl_expected_res[idx], lbl_actual_res[idx]))