parent
fc87ef741b
commit
7160cb0f32
@ -0,0 +1,39 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/operators/reader/compose_reader.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace reader {
|
||||
|
||||
ComposeReader::ComposeReader(
|
||||
const std::vector<std::shared_ptr<framework::ReaderBase>> &readers)
|
||||
: framework::MultiDecoratedReader(readers) {}
|
||||
|
||||
void ComposeReader::ReadNext(std::vector<framework::LoDTensor> *out) {
|
||||
out->clear();
|
||||
std::vector<framework::LoDTensor> each_ret;
|
||||
for (auto &r : readers_) {
|
||||
r->ReadNext(&each_ret);
|
||||
out->reserve(out->size() + each_ret.size());
|
||||
for (auto &data : each_ret) {
|
||||
out->emplace_back(std::move(data));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace reader
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,34 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/reader.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace reader {
|
||||
|
||||
class ComposeReader : public framework::MultiDecoratedReader {
|
||||
public:
|
||||
explicit ComposeReader(
|
||||
const std::vector<std::shared_ptr<framework::ReaderBase>> &readers);
|
||||
|
||||
void ReadNext(std::vector<framework::LoDTensor> *out) override;
|
||||
};
|
||||
|
||||
} // namespace reader
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,78 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/operators/reader/py_reader.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace reader {
|
||||
|
||||
PyReader::PyReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue)
|
||||
: framework::FileReader() {
|
||||
PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null");
|
||||
queue_ = queue;
|
||||
}
|
||||
|
||||
void PyReader::ReadNext(std::vector<framework::LoDTensor>* out) {
|
||||
bool success;
|
||||
*out = queue_->Pop(&success);
|
||||
if (!success) out->clear();
|
||||
}
|
||||
|
||||
PyReader::~PyReader() { queue_->Close(); }
|
||||
|
||||
void PyReader::Shutdown() { queue_->Close(); }
|
||||
|
||||
void PyReader::Start() { queue_->ReOpen(); }
|
||||
|
||||
MultiQueuePyReader::MultiQueuePyReader(
|
||||
const std::vector<std::shared_ptr<LoDTensorBlockingQueue>>& queues)
|
||||
: queues_(queues) {
|
||||
PADDLE_ENFORCE(!queues_.empty());
|
||||
for (auto& q : queues_) {
|
||||
PADDLE_ENFORCE_NOT_NULL(q);
|
||||
}
|
||||
}
|
||||
|
||||
void MultiQueuePyReader::ReadNext(std::vector<framework::LoDTensor>* out) {
|
||||
auto idx = read_out_idx_.fetch_add(1) % queues_.size();
|
||||
for (size_t i = 0; i < queues_.size(); ++i) {
|
||||
*out = queues_[idx]->Pop();
|
||||
if (!out->empty()) return;
|
||||
idx = (idx + 1) % queues_.size();
|
||||
}
|
||||
}
|
||||
|
||||
MultiQueuePyReader::~MultiQueuePyReader() {
|
||||
for (auto& q : queues_) {
|
||||
q->Close();
|
||||
}
|
||||
}
|
||||
|
||||
void MultiQueuePyReader::Shutdown() {
|
||||
for (auto& q : queues_) {
|
||||
q->Close();
|
||||
}
|
||||
read_out_idx_.store(0, std::memory_order::memory_order_seq_cst);
|
||||
}
|
||||
|
||||
void MultiQueuePyReader::Start() {
|
||||
for (auto& q : queues_) {
|
||||
q->ReOpen();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace reader
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,62 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <atomic>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/reader.h"
|
||||
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace reader {
|
||||
|
||||
class PyReader : public framework::FileReader {
|
||||
public:
|
||||
explicit PyReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue);
|
||||
|
||||
void ReadNext(std::vector<framework::LoDTensor>* out) override;
|
||||
|
||||
~PyReader();
|
||||
|
||||
void Shutdown() override;
|
||||
|
||||
void Start() override;
|
||||
|
||||
private:
|
||||
std::shared_ptr<LoDTensorBlockingQueue> queue_;
|
||||
};
|
||||
|
||||
class MultiQueuePyReader : public framework::FileReader {
|
||||
public:
|
||||
explicit MultiQueuePyReader(
|
||||
const std::vector<std::shared_ptr<LoDTensorBlockingQueue>>& queues);
|
||||
|
||||
void ReadNext(std::vector<framework::LoDTensor>* out) override;
|
||||
|
||||
~MultiQueuePyReader();
|
||||
|
||||
void Shutdown() override;
|
||||
|
||||
void Start() override;
|
||||
|
||||
private:
|
||||
std::vector<std::shared_ptr<LoDTensorBlockingQueue>> queues_;
|
||||
std::atomic<size_t> read_out_idx_{0};
|
||||
};
|
||||
|
||||
} // namespace reader
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,132 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/pybind/reader_py.h"
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/reader.h"
|
||||
#include "paddle/fluid/operators/reader/buffered_reader.h"
|
||||
#include "paddle/fluid/operators/reader/compose_reader.h"
|
||||
#include "paddle/fluid/operators/reader/py_reader.h"
|
||||
#include "paddle/fluid/platform/place.h"
|
||||
#include "pybind11/stl.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace pybind {
|
||||
|
||||
class FeedReader {
|
||||
using ResultDictList =
|
||||
std::vector<std::unordered_map<std::string, framework::LoDTensor>>;
|
||||
|
||||
public:
|
||||
FeedReader(std::unique_ptr<framework::ReaderHolder> reader,
|
||||
const std::vector<std::string> &names, size_t num_places,
|
||||
bool drop_last = true)
|
||||
: reader_(std::move(reader)),
|
||||
names_(names),
|
||||
num_places_(num_places),
|
||||
drop_last_(drop_last) {}
|
||||
|
||||
ResultDictList ReadNext() {
|
||||
std::vector<framework::LoDTensor> tensors;
|
||||
reader_->ReadNext(&tensors);
|
||||
if (tensors.empty()) return ResultDictList();
|
||||
|
||||
PADDLE_ENFORCE(tensors.size() % names_.size() == 0,
|
||||
"Tensor size: %d, names size: %d", tensors.size(),
|
||||
names_.size());
|
||||
|
||||
size_t read_place_num = tensors.size() / names_.size();
|
||||
|
||||
if (drop_last_ && read_place_num != num_places_) {
|
||||
return ResultDictList();
|
||||
}
|
||||
|
||||
ResultDictList ret(read_place_num);
|
||||
for (size_t i = 0; i < tensors.size(); ++i) {
|
||||
ret[i / names_.size()].emplace(names_[i % names_.size()],
|
||||
std::move(tensors[i]));
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
void Start() { reader_->Start(); }
|
||||
|
||||
void Reset() { reader_->ResetAll(); }
|
||||
|
||||
private:
|
||||
std::unique_ptr<framework::ReaderHolder> reader_;
|
||||
std::vector<std::string> names_;
|
||||
size_t num_places_;
|
||||
bool drop_last_;
|
||||
};
|
||||
|
||||
static std::unique_ptr<framework::ReaderHolder> CreatePyReader(
|
||||
const std::vector<
|
||||
std::shared_ptr<operators::reader::LoDTensorBlockingQueue>> &queues,
|
||||
const std::vector<platform::Place> &dst_places) {
|
||||
std::shared_ptr<framework::ReaderBase> reader;
|
||||
if (queues.size() == 1) {
|
||||
reader.reset(new operators::reader::PyReader(queues[0]));
|
||||
} else {
|
||||
reader.reset(new operators::reader::MultiQueuePyReader(queues));
|
||||
}
|
||||
std::vector<std::shared_ptr<framework::ReaderBase>> buffered_reader;
|
||||
buffered_reader.reserve(dst_places.size());
|
||||
for (auto &p : dst_places) {
|
||||
buffered_reader.emplace_back(
|
||||
framework::MakeDecoratedReader<operators::reader::BufferedReader>(
|
||||
reader, p, 2));
|
||||
}
|
||||
reader = framework::MakeDecoratedReader<operators::reader::ComposeReader>(
|
||||
buffered_reader);
|
||||
|
||||
auto *holder = new framework::ReaderHolder();
|
||||
holder->Reset(reader);
|
||||
return std::unique_ptr<framework::ReaderHolder>(holder);
|
||||
}
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
void BindReader(py::module *module) {
|
||||
auto &m = *module;
|
||||
|
||||
namespace reader = ::paddle::operators::reader;
|
||||
|
||||
py::class_<framework::ReaderHolder>(m, "Reader", "")
|
||||
.def("start", &framework::ReaderHolder::Start)
|
||||
.def("reset", &framework::ReaderHolder::ResetAll);
|
||||
|
||||
py::class_<FeedReader>(m, "FeedReader", "")
|
||||
.def("read_next", &FeedReader::ReadNext,
|
||||
py::call_guard<py::gil_scoped_release>())
|
||||
.def("start", &FeedReader::Start,
|
||||
py::call_guard<py::gil_scoped_release>())
|
||||
.def("reset", &FeedReader::Reset,
|
||||
py::call_guard<py::gil_scoped_release>());
|
||||
|
||||
m.def("create_py_reader",
|
||||
[](const std::vector<
|
||||
std::shared_ptr<operators::reader::LoDTensorBlockingQueue>>
|
||||
queues,
|
||||
const std::vector<std::string> &names,
|
||||
const std::vector<platform::Place> &dst_places, bool drop_last) {
|
||||
return new FeedReader(CreatePyReader(queues, dst_places), names,
|
||||
dst_places.size(), drop_last);
|
||||
},
|
||||
py::return_value_policy::take_ownership);
|
||||
}
|
||||
|
||||
} // namespace pybind
|
||||
} // namespace paddle
|
@ -0,0 +1,25 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace pybind {
|
||||
|
||||
void BindReader(pybind11::module *module);
|
||||
|
||||
} // namespace pybind
|
||||
} // namespace paddle
|
@ -0,0 +1,141 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import core
|
||||
import six
|
||||
import threading
|
||||
from .framework import Program, Variable, program_guard
|
||||
from .data_feeder import DataFeeder
|
||||
|
||||
__all__ = ['PyReader']
|
||||
|
||||
|
||||
def _convert_places(places):
|
||||
if not isinstance(places, (list, tuple)):
|
||||
places = [places]
|
||||
|
||||
ret = []
|
||||
for p in places:
|
||||
if not isinstance(p, core.Place):
|
||||
tmp = core.Place()
|
||||
tmp.set_place(p)
|
||||
p = tmp
|
||||
|
||||
ret.append(p)
|
||||
return ret
|
||||
|
||||
|
||||
class PyReader(object):
|
||||
def __init__(self, feed_list, places, capacity, multi_queue=True):
|
||||
self._tensor_reader = None
|
||||
self._thread = None
|
||||
|
||||
# TODO(zjl): to support drop_last = False
|
||||
self._drop_last = True
|
||||
|
||||
self._feed_list = feed_list
|
||||
self._var_names = [v.name for v in feed_list]
|
||||
|
||||
self._queues = []
|
||||
|
||||
self._places = _convert_places(places)
|
||||
|
||||
self._queue_capacity = capacity
|
||||
|
||||
queue_num = len(self._places) if multi_queue else 1
|
||||
for _ in six.moves.range(queue_num):
|
||||
self._queues.append(
|
||||
core.init_lod_tensor_blocking_queue(core.Variable(),
|
||||
self._queue_capacity))
|
||||
|
||||
self._reader = core.create_py_reader(self._queues, self._var_names,
|
||||
self._places, self._drop_last)
|
||||
self._exited = True
|
||||
|
||||
def __call__(self):
|
||||
assert self._tensor_reader is not None, \
|
||||
"Data source of PyReader has not set yet"
|
||||
|
||||
class Iterator(object):
|
||||
def __init__(self, reader):
|
||||
self._reader = reader
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def next(self):
|
||||
ret = self._reader._reader.read_next()
|
||||
if len(ret):
|
||||
return ret
|
||||
else:
|
||||
self._reader._restart_reader()
|
||||
self._reader._reader.reset()
|
||||
raise StopIteration
|
||||
|
||||
return Iterator(self)
|
||||
|
||||
def _restart_reader(self):
|
||||
if not self._exited:
|
||||
for q in self._queues:
|
||||
q.close()
|
||||
|
||||
self._thread.join()
|
||||
|
||||
def __thread_main__():
|
||||
queue_num = len(self._queues)
|
||||
idx = 0
|
||||
for tensors in self._tensor_reader():
|
||||
array = core.LoDTensorArray()
|
||||
for item in tensors:
|
||||
if not isinstance(item, core.LoDTensor):
|
||||
tmp = core.LoDTensor()
|
||||
tmp.set(item, core.CPUPlace())
|
||||
item = tmp
|
||||
|
||||
array.append(item)
|
||||
|
||||
if not self._queues[idx].push(array):
|
||||
break
|
||||
|
||||
idx = (idx + 1) % queue_num
|
||||
|
||||
for q in self._queues:
|
||||
q.close()
|
||||
|
||||
self._exited = True
|
||||
|
||||
self._thread = threading.Thread(target=__thread_main__)
|
||||
self._thread.daemon = True
|
||||
self._exited = False
|
||||
self._thread.start()
|
||||
|
||||
def set_numpy_reader(self, reader):
|
||||
assert self._tensor_reader is None, \
|
||||
"Cannot reset the data source of PyReader"
|
||||
with program_guard(Program(), Program()):
|
||||
feeder = DataFeeder(
|
||||
feed_list=self._feed_list, place=core.CPUPlace())
|
||||
paddle_reader = feeder.decorate_reader(reader, multi_devices=False)
|
||||
|
||||
def __tensor_reader_impl__():
|
||||
for slots in paddle_reader():
|
||||
yield [slots[var.name] for var in self._feed_list]
|
||||
|
||||
self.set_tensor_reader(__tensor_reader_impl__)
|
||||
|
||||
def set_tensor_reader(self, reader):
|
||||
assert self._tensor_reader is None, \
|
||||
"Cannot reset the data source of PyReader"
|
||||
self._tensor_reader = reader
|
||||
self._restart_reader()
|
@ -0,0 +1,157 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import paddle
|
||||
import paddle.fluid as fluid
|
||||
import numpy as np
|
||||
import time
|
||||
import six
|
||||
import unittest
|
||||
|
||||
EPOCH_NUM = 60
|
||||
BATCH_SIZE = 32
|
||||
CLASS_NUM = 10
|
||||
|
||||
|
||||
def random_reader():
|
||||
for i in range(BATCH_SIZE * 40):
|
||||
image = np.random.random([784])
|
||||
label = np.random.random_integers(low=0, high=CLASS_NUM - 1)
|
||||
yield image, label
|
||||
|
||||
|
||||
def simple_fc_net(places, use_legacy_py_reader):
|
||||
startup_prog = fluid.Program()
|
||||
main_prog = fluid.Program()
|
||||
startup_prog.random_seed = 1
|
||||
main_prog.random_seed = 1
|
||||
reader = paddle.batch(random_reader, batch_size=BATCH_SIZE)
|
||||
|
||||
with fluid.unique_name.guard():
|
||||
with fluid.program_guard(main_prog, startup_prog):
|
||||
if not use_legacy_py_reader:
|
||||
image = fluid.layers.data(
|
||||
name='image', shape=[784], dtype='float32')
|
||||
label = fluid.layers.data(
|
||||
name='label', shape=[1], dtype='int64')
|
||||
py_reader = fluid.io.PyReader(
|
||||
feed_list=[image, label],
|
||||
places=places,
|
||||
capacity=4,
|
||||
multi_queue=False)
|
||||
py_reader.set_numpy_reader(reader)
|
||||
else:
|
||||
py_reader = fluid.layers.py_reader(
|
||||
capacity=4,
|
||||
shapes=[(-1, 784), (-1, 1)],
|
||||
dtypes=['float32', 'int64'])
|
||||
image, label = fluid.layers.read_file(py_reader)
|
||||
py_reader.decorate_paddle_reader(reader)
|
||||
|
||||
hidden = image
|
||||
for hidden_size in [10, 20, 30]:
|
||||
hidden = fluid.layers.fc(
|
||||
hidden,
|
||||
size=hidden_size,
|
||||
act='tanh',
|
||||
bias_attr=fluid.ParamAttr(
|
||||
initializer=fluid.initializer.Constant(value=1.0)))
|
||||
|
||||
predict_label = fluid.layers.fc(hidden,
|
||||
size=CLASS_NUM,
|
||||
act='softmax')
|
||||
loss = fluid.layers.mean(
|
||||
fluid.layers.cross_entropy(
|
||||
input=predict_label, label=label))
|
||||
|
||||
optimizer = fluid.optimizer.Adam()
|
||||
optimizer.minimize(loss)
|
||||
return startup_prog, main_prog, py_reader, loss
|
||||
|
||||
|
||||
class TestBase(unittest.TestCase):
|
||||
def run_main(self, use_legacy_py_reader, with_data_parallel, places):
|
||||
with fluid.scope_guard(fluid.Scope()):
|
||||
startup_prog, main_prog, py_reader, loss = simple_fc_net(
|
||||
places, use_legacy_py_reader)
|
||||
exe = fluid.Executor(place=places[0])
|
||||
exe.run(startup_prog)
|
||||
|
||||
prog = fluid.CompiledProgram(main_prog)
|
||||
if with_data_parallel:
|
||||
prog = prog.with_data_parallel(
|
||||
loss_name=loss.name, places=places)
|
||||
|
||||
step = 0
|
||||
start_t = time.time()
|
||||
if use_legacy_py_reader:
|
||||
for _ in six.moves.range(EPOCH_NUM):
|
||||
py_reader.start()
|
||||
while True:
|
||||
try:
|
||||
L, = exe.run(program=prog, fetch_list=[loss])
|
||||
step += 1
|
||||
except fluid.core.EOFException:
|
||||
py_reader.reset()
|
||||
break
|
||||
else:
|
||||
for _ in six.moves.range(EPOCH_NUM):
|
||||
for d in py_reader():
|
||||
'''
|
||||
assert len(d) == len(places)
|
||||
for i, item in enumerate(d):
|
||||
image = item['image']
|
||||
label = item['label']
|
||||
assert image.shape() == [BATCH_SIZE, 784]
|
||||
assert label.shape() == [BATCH_SIZE, 1]
|
||||
assert image._place()._equals(places[i])
|
||||
assert label._place()._equals(places[i])
|
||||
'''
|
||||
L, = exe.run(program=prog, feed=d, fetch_list=[loss])
|
||||
step += 1
|
||||
end_t = time.time()
|
||||
return {"time": end_t - start_t, "step": step}
|
||||
|
||||
def prepare_places(self, with_data_parallel):
|
||||
places = [[fluid.CPUPlace()], ]
|
||||
if with_data_parallel:
|
||||
places.append([fluid.CPUPlace()] * 2)
|
||||
|
||||
if fluid.core.is_compiled_with_cuda():
|
||||
tmp = fluid.cuda_places()
|
||||
assert len(tmp) > 0, "no gpu detected"
|
||||
if with_data_parallel:
|
||||
places.append(tmp)
|
||||
places.append([tmp[0]])
|
||||
return places
|
||||
|
||||
def test_main(self):
|
||||
for with_data_parallel in [True, False]:
|
||||
for p in self.prepare_places(with_data_parallel):
|
||||
t = []
|
||||
for use_legacy_py_reader in [False, True]:
|
||||
ret = self.run_main(
|
||||
use_legacy_py_reader=use_legacy_py_reader,
|
||||
with_data_parallel=with_data_parallel,
|
||||
places=p)
|
||||
ret['legacy'] = use_legacy_py_reader
|
||||
ret['data_parallel'] = with_data_parallel
|
||||
ret['places'] = p
|
||||
t.append(ret)
|
||||
|
||||
print(t)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue