Merge pull request #15799 from sneaxiy/feature/decoupled_reader
Try to decouple reader with program_descmove-code
commit
4cc9809cae
@ -0,0 +1,42 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/operators/reader/py_reader.h"
|
||||
#include <memory>
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace reader {
|
||||
|
||||
PyReader::PyReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue)
|
||||
: framework::FileReader() {
|
||||
PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null");
|
||||
queue_ = queue;
|
||||
}
|
||||
|
||||
void PyReader::ReadNext(std::vector<framework::LoDTensor>* out) {
|
||||
bool success;
|
||||
*out = queue_->Pop(&success);
|
||||
if (!success) out->clear();
|
||||
}
|
||||
|
||||
PyReader::~PyReader() { queue_->Close(); }
|
||||
|
||||
void PyReader::Shutdown() { queue_->Close(); }
|
||||
|
||||
void PyReader::Start() { queue_->ReOpen(); }
|
||||
|
||||
} // namespace reader
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,45 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <atomic>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/reader.h"
|
||||
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace reader {
|
||||
|
||||
class PyReader : public framework::FileReader {
|
||||
public:
|
||||
explicit PyReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue);
|
||||
|
||||
void ReadNext(std::vector<framework::LoDTensor>* out) override;
|
||||
|
||||
~PyReader();
|
||||
|
||||
void Shutdown() override;
|
||||
|
||||
void Start() override;
|
||||
|
||||
private:
|
||||
std::shared_ptr<LoDTensorBlockingQueue> queue_;
|
||||
};
|
||||
|
||||
} // namespace reader
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,161 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/pybind/reader_py.h"
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/reader.h"
|
||||
#include "paddle/fluid/operators/reader/buffered_reader.h"
|
||||
#include "paddle/fluid/operators/reader/py_reader.h"
|
||||
#include "paddle/fluid/platform/place.h"
|
||||
#include "pybind11/stl.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace pybind {
|
||||
|
||||
class MultiDeviceFeedReader {
|
||||
public:
|
||||
using ResultDictList =
|
||||
std::vector<std::unordered_map<std::string, framework::LoDTensor>>;
|
||||
|
||||
MultiDeviceFeedReader(
|
||||
const std::shared_ptr<operators::reader::LoDTensorBlockingQueue> &queue,
|
||||
const std::vector<std::string> &names,
|
||||
const std::vector<platform::Place> &dst_places, bool use_double_buffer)
|
||||
: queue_(queue),
|
||||
names_(names),
|
||||
pool_(new ::ThreadPool(dst_places.size())) {
|
||||
std::shared_ptr<framework::ReaderBase> reader(
|
||||
new operators::reader::PyReader(queue));
|
||||
|
||||
readers_.reserve(dst_places.size());
|
||||
for (auto &p : dst_places) {
|
||||
auto *holder = new framework::ReaderHolder();
|
||||
if (use_double_buffer) {
|
||||
holder->Reset(
|
||||
framework::MakeDecoratedReader<operators::reader::BufferedReader>(
|
||||
reader, p, 2));
|
||||
} else {
|
||||
if (platform::is_gpu_place(p)) {
|
||||
PADDLE_THROW(
|
||||
"Place cannot be CUDAPlace when use_double_buffer is False");
|
||||
}
|
||||
holder->Reset(reader);
|
||||
}
|
||||
readers_.emplace_back(holder);
|
||||
}
|
||||
|
||||
futures_.resize(dst_places.size());
|
||||
ret_.resize(dst_places.size());
|
||||
ReadAsync();
|
||||
}
|
||||
|
||||
ResultDictList ReadNext() {
|
||||
bool success = WaitFutures();
|
||||
|
||||
if (!success) {
|
||||
return {};
|
||||
}
|
||||
|
||||
ResultDictList result(ret_.size());
|
||||
for (size_t i = 0; i < ret_.size(); ++i) {
|
||||
for (size_t j = 0; j < names_.size(); ++j) {
|
||||
result[i].emplace(names_[j], std::move(ret_[i][j]));
|
||||
}
|
||||
}
|
||||
ReadAsync();
|
||||
return result;
|
||||
}
|
||||
|
||||
void Reset() {
|
||||
Shutdown();
|
||||
Start();
|
||||
ReadAsync();
|
||||
}
|
||||
|
||||
~MultiDeviceFeedReader() {
|
||||
queue_->Close();
|
||||
pool_.reset();
|
||||
}
|
||||
|
||||
private:
|
||||
bool WaitFutures() {
|
||||
bool success = true;
|
||||
for (auto &f : futures_) {
|
||||
success &= f.get();
|
||||
}
|
||||
return success;
|
||||
}
|
||||
|
||||
void Shutdown() {
|
||||
for (auto &r : readers_) r->Shutdown();
|
||||
}
|
||||
|
||||
void Start() {
|
||||
for (auto &r : readers_) r->Start();
|
||||
}
|
||||
|
||||
void ReadAsync() {
|
||||
for (size_t i = 0; i < readers_.size(); ++i) {
|
||||
futures_[i] = pool_->enqueue([this, i] {
|
||||
readers_[i]->ReadNext(&ret_[i]);
|
||||
return !ret_[i].empty();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<operators::reader::LoDTensorBlockingQueue> queue_;
|
||||
std::vector<std::string> names_;
|
||||
std::unique_ptr<::ThreadPool> pool_;
|
||||
|
||||
std::vector<std::unique_ptr<framework::ReaderHolder>> readers_;
|
||||
|
||||
std::vector<std::future<bool>> futures_;
|
||||
std::vector<std::vector<framework::LoDTensor>> ret_;
|
||||
};
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
void BindReader(py::module *module) {
|
||||
auto &m = *module;
|
||||
|
||||
namespace reader = ::paddle::operators::reader;
|
||||
|
||||
py::class_<framework::ReaderHolder>(m, "Reader", "")
|
||||
.def("start", &framework::ReaderHolder::Start)
|
||||
.def("reset", &framework::ReaderHolder::ResetAll);
|
||||
|
||||
py::class_<MultiDeviceFeedReader>(m, "MultiDeviceFeedReader", "")
|
||||
.def("read_next", &MultiDeviceFeedReader::ReadNext,
|
||||
py::call_guard<py::gil_scoped_release>())
|
||||
.def("reset", &MultiDeviceFeedReader::Reset,
|
||||
py::call_guard<py::gil_scoped_release>());
|
||||
|
||||
m.def("create_py_reader",
|
||||
[](const std::shared_ptr<operators::reader::LoDTensorBlockingQueue>
|
||||
&queue,
|
||||
const std::vector<std::string> &names,
|
||||
const std::vector<platform::Place> &dst_places,
|
||||
bool use_double_buffer) {
|
||||
return new MultiDeviceFeedReader(queue, names, dst_places,
|
||||
use_double_buffer);
|
||||
},
|
||||
py::return_value_policy::take_ownership);
|
||||
}
|
||||
|
||||
} // namespace pybind
|
||||
} // namespace paddle
|
@ -0,0 +1,25 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace pybind {
|
||||
|
||||
void BindReader(pybind11::module *module);
|
||||
|
||||
} // namespace pybind
|
||||
} // namespace paddle
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,175 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import paddle
|
||||
import paddle.fluid as fluid
|
||||
import numpy as np
|
||||
import time
|
||||
import six
|
||||
import unittest
|
||||
|
||||
EPOCH_NUM = 60
|
||||
BATCH_SIZE = 32
|
||||
CLASS_NUM = 10
|
||||
|
||||
|
||||
def random_reader():
|
||||
np.random.seed(1)
|
||||
for i in range(BATCH_SIZE * 40):
|
||||
image = np.random.random([784])
|
||||
label = np.random.random_integers(low=0, high=CLASS_NUM - 1)
|
||||
yield image, label
|
||||
|
||||
|
||||
def simple_fc_net(places, use_legacy_py_reader, use_double_buffer):
|
||||
startup_prog = fluid.Program()
|
||||
main_prog = fluid.Program()
|
||||
startup_prog.random_seed = 1
|
||||
main_prog.random_seed = 1
|
||||
|
||||
with fluid.unique_name.guard():
|
||||
with fluid.program_guard(main_prog, startup_prog):
|
||||
image = fluid.layers.data(
|
||||
name='image', shape=[784], dtype='float32')
|
||||
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
|
||||
py_reader = fluid.io.PyReader(
|
||||
feed_list=[image, label],
|
||||
capacity=4,
|
||||
iterable=not use_legacy_py_reader,
|
||||
use_double_buffer=use_double_buffer)
|
||||
hidden = image
|
||||
for hidden_size in [10, 20, 30]:
|
||||
hidden = fluid.layers.fc(
|
||||
hidden,
|
||||
size=hidden_size,
|
||||
act='tanh',
|
||||
bias_attr=fluid.ParamAttr(
|
||||
initializer=fluid.initializer.Constant(value=1.0)))
|
||||
|
||||
predict_label = fluid.layers.fc(hidden,
|
||||
size=CLASS_NUM,
|
||||
act='softmax')
|
||||
loss = fluid.layers.mean(
|
||||
fluid.layers.cross_entropy(
|
||||
input=predict_label, label=label))
|
||||
|
||||
optimizer = fluid.optimizer.Adam()
|
||||
optimizer.minimize(loss)
|
||||
return startup_prog, main_prog, py_reader, loss
|
||||
|
||||
|
||||
class TestBase(unittest.TestCase):
|
||||
def run_main(self, use_legacy_py_reader, with_data_parallel, places,
|
||||
use_double_buffer):
|
||||
scope = fluid.Scope()
|
||||
with fluid.scope_guard(scope):
|
||||
startup_prog, main_prog, py_reader, loss = simple_fc_net(
|
||||
places, use_legacy_py_reader, use_double_buffer)
|
||||
|
||||
reader = paddle.batch(random_reader, batch_size=BATCH_SIZE)
|
||||
|
||||
ps = places if use_double_buffer else fluid.cpu_places(len(places))
|
||||
|
||||
py_reader.decorate_sample_list_generator(
|
||||
reader, places=ps if py_reader.iterable else None)
|
||||
|
||||
exe = fluid.Executor(place=places[0])
|
||||
exe.run(startup_prog)
|
||||
|
||||
prog = fluid.CompiledProgram(main_prog)
|
||||
if with_data_parallel:
|
||||
prog = prog.with_data_parallel(
|
||||
loss_name=loss.name, places=places)
|
||||
|
||||
step = 0
|
||||
step_list = []
|
||||
loss_list = []
|
||||
start_t = time.time()
|
||||
if not py_reader.iterable:
|
||||
for _ in six.moves.range(EPOCH_NUM):
|
||||
step = 0
|
||||
py_reader.start()
|
||||
while True:
|
||||
try:
|
||||
L, = exe.run(program=prog,
|
||||
fetch_list=[loss],
|
||||
use_program_cache=True)
|
||||
loss_list.append(np.mean(L))
|
||||
step += 1
|
||||
except fluid.core.EOFException:
|
||||
py_reader.reset()
|
||||
break
|
||||
step_list.append(step)
|
||||
else:
|
||||
for _ in six.moves.range(EPOCH_NUM):
|
||||
step = 0
|
||||
for d in py_reader():
|
||||
assert len(d) == len(places)
|
||||
for i, item in enumerate(d):
|
||||
image = item['image']
|
||||
label = item['label']
|
||||
assert image.shape() == [BATCH_SIZE, 784]
|
||||
assert label.shape() == [BATCH_SIZE, 1]
|
||||
assert image._place()._equals(ps[i])
|
||||
assert label._place()._equals(ps[i])
|
||||
L, = exe.run(program=prog,
|
||||
feed=d,
|
||||
fetch_list=[loss],
|
||||
use_program_cache=True)
|
||||
loss_list.append(np.mean(L))
|
||||
step += 1
|
||||
step_list.append(step)
|
||||
end_t = time.time()
|
||||
ret = {
|
||||
"time": end_t - start_t,
|
||||
"step": step_list,
|
||||
"loss": np.array(loss_list)
|
||||
}
|
||||
return ret
|
||||
|
||||
def prepare_places(self, with_data_parallel, with_cpu=True, with_gpu=True):
|
||||
places = []
|
||||
if with_cpu:
|
||||
places.append([fluid.CPUPlace()])
|
||||
if with_data_parallel:
|
||||
places.append([fluid.CPUPlace()] * 2)
|
||||
|
||||
if with_gpu and fluid.core.is_compiled_with_cuda():
|
||||
tmp = fluid.cuda_places()
|
||||
assert len(tmp) > 0, "no gpu detected"
|
||||
if with_data_parallel:
|
||||
places.append(tmp)
|
||||
places.append([tmp[0]])
|
||||
return places
|
||||
|
||||
def test_main(self):
|
||||
for with_data_parallel in [True, False]:
|
||||
for p in self.prepare_places(with_data_parallel):
|
||||
for use_double_buffer in [False, True]:
|
||||
results = []
|
||||
for use_legacy_py_reader in [False, True]:
|
||||
ret = self.run_main(
|
||||
use_legacy_py_reader=use_legacy_py_reader,
|
||||
with_data_parallel=with_data_parallel,
|
||||
places=p,
|
||||
use_double_buffer=use_double_buffer)
|
||||
results.append(ret)
|
||||
if not use_double_buffer:
|
||||
diff = np.max(
|
||||
np.abs(results[0]['loss'] - results[1]['loss']))
|
||||
self.assertLess(diff, 1e-3)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
@ -0,0 +1,137 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import paddle
|
||||
import paddle.fluid as fluid
|
||||
import math
|
||||
import unittest
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
os.environ['CPU_NUM'] = '1'
|
||||
|
||||
|
||||
def random_reader(sample_num):
|
||||
def __impl__():
|
||||
for _ in range(sample_num):
|
||||
yield np.random.random(
|
||||
size=[784]).astype('float32'), np.random.random_integers(
|
||||
low=0, high=9, size=[1]).astype('int64')
|
||||
|
||||
return paddle.reader.cache(__impl__)
|
||||
|
||||
|
||||
class TestCaseBase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.batch_size = 32
|
||||
self.epoch_num = 2
|
||||
self.sample_num = 165
|
||||
|
||||
def generate_all_data(self, reader):
|
||||
ret = []
|
||||
for d in reader():
|
||||
slots = [[], []]
|
||||
for item in d:
|
||||
slots[0].append(item[0])
|
||||
slots[1].append(item[1])
|
||||
slots = [np.array(slot) for slot in slots]
|
||||
ret.append(slots)
|
||||
return ret
|
||||
|
||||
def run_main(self, reader, use_sample_generator, iterable, drop_last):
|
||||
image = fluid.layers.data(name='image', dtype='float32', shape=[784])
|
||||
label = fluid.layers.data(name='label', dtype='int64', shape=[1])
|
||||
py_reader = fluid.io.PyReader(
|
||||
feed_list=[image, label],
|
||||
capacity=16,
|
||||
iterable=iterable,
|
||||
use_double_buffer=False)
|
||||
|
||||
batch_reader = paddle.batch(reader, self.batch_size, drop_last)
|
||||
all_datas = self.generate_all_data(batch_reader)
|
||||
|
||||
if not use_sample_generator:
|
||||
py_reader.decorate_sample_list_generator(
|
||||
batch_reader, places=fluid.cpu_places())
|
||||
else:
|
||||
py_reader.decorate_sample_generator(
|
||||
reader, self.batch_size, drop_last, places=fluid.cpu_places())
|
||||
|
||||
if drop_last:
|
||||
batch_num = int(self.sample_num / self.batch_size)
|
||||
else:
|
||||
batch_num = math.ceil(float(self.sample_num) / self.batch_size)
|
||||
|
||||
exe = fluid.Executor(fluid.CPUPlace())
|
||||
exe.run(fluid.default_startup_program())
|
||||
for _ in range(self.epoch_num):
|
||||
if py_reader.iterable:
|
||||
step = 0
|
||||
for data in py_reader():
|
||||
img, lbl = exe.run(feed=data, fetch_list=[image, label])
|
||||
self.assertArrayEqual(img, all_datas[step][0])
|
||||
self.assertArrayEqual(lbl, all_datas[step][1])
|
||||
step += 1
|
||||
self.assertEqual(step, len(all_datas))
|
||||
else:
|
||||
step = 0
|
||||
try:
|
||||
py_reader.start()
|
||||
while True:
|
||||
img, lbl = exe.run(fetch_list=[image, label])
|
||||
self.assertArrayEqual(img, all_datas[step][0])
|
||||
self.assertArrayEqual(lbl, all_datas[step][1])
|
||||
step += 1
|
||||
except fluid.core.EOFException:
|
||||
py_reader.reset()
|
||||
self.assertEqual(step, len(all_datas))
|
||||
break
|
||||
|
||||
def assertArrayEqual(self, arr1, arr2):
|
||||
self.assertEqual(arr1.shape, arr2.shape)
|
||||
self.assertTrue((arr1 == arr2).all())
|
||||
|
||||
def test_main(self):
|
||||
reader = random_reader(self.sample_num)
|
||||
for use_sample_generator in [False, True]:
|
||||
for iterable in [False, True]:
|
||||
for drop_last in [False, True]:
|
||||
with fluid.program_guard(fluid.Program(), fluid.Program()):
|
||||
self.run_main(reader, use_sample_generator, iterable,
|
||||
drop_last)
|
||||
|
||||
|
||||
class TestCase1(TestCaseBase):
|
||||
def setUp(self):
|
||||
self.batch_size = 32
|
||||
self.epoch_num = 10
|
||||
self.sample_num = 160
|
||||
|
||||
|
||||
class TestCase2(TestCaseBase):
|
||||
def setUp(self):
|
||||
self.batch_size = 32
|
||||
self.epoch_num = 2
|
||||
self.sample_num = 200
|
||||
|
||||
|
||||
class TestCase3(TestCaseBase):
|
||||
def setUp(self):
|
||||
self.batch_size = 32
|
||||
self.epoch_num = 2
|
||||
self.sample_num = 159
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue