parent 2c12af76c5
commit 697ba4b13d
@ -0,0 +1,80 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/reader/py_array_feed_queue.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
namespace reader {
|
||||||
|
|
||||||
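// PyArrayReader adapts a PyArrayFeedQueue to the ReaderBase interface:
// each ReadNext() pops one vector of LoDTensors that was fed in from the
// Python side, and an empty vector signals that the queue was closed.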
class PyArrayReader : public framework::ReaderBase {
 public:
  explicit PyArrayReader(const std::shared_ptr<PyArrayFeedQueue>& queue) {
    PADDLE_ENFORCE(queue != nullptr, "PyArrayFeedQueue must not be null");
    queue_ = queue;
  }

  void ReadNext(std::vector<framework::LoDTensor>* out) override {
    *out = queue_->Dequeue();
  }

  void ReInit() override {
    // PADDLE_THROW("PyArrayReader does not support ReInit()");
  }

 private:
  std::shared_ptr<PyArrayFeedQueue> queue_;
};

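// CreatePyArrayReaderOp looks up the PyArrayFeedQueueHolder variable named
// by the "feeder_name" attribute and wraps its shared queue in a
// PyArrayReader stored in the output ReaderHolder.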
class CreatePyArrayReaderOp : public framework::OperatorBase {
 public:
  using framework::OperatorBase::OperatorBase;

 private:
  void RunImpl(const framework::Scope& scope,
               const platform::Place& dev_place) const override {
    const std::string& feeder_name = Attr<std::string>("feeder_name");
    auto* feeder_holder_var = scope.FindVar(feeder_name);
    PADDLE_ENFORCE(feeder_holder_var != nullptr,
                   "No PyArrayFeedQueue variable with name %s found",
                   feeder_name);
    auto* feeder_holder =
        feeder_holder_var->template GetMutable<PyArrayFeedQueueHolder>();
    auto* out = scope.FindVar(Output("Out"))
                    ->template GetMutable<framework::ReaderHolder>();
    out->Reset(new PyArrayReader(feeder_holder->GetFeeder()));
  }
};

class CreatePyArrayReaderOpMaker : public FileReaderMakerBase {
 protected:
  void Apply() override {
    AddAttr<std::string>("feeder_name",
                         "Name of the `PyArrayFeedQueueHolder` variable");

    AddComment(R"DOC(
Create PyArrayReader to accept Python data feeding.
)DOC");
  }
};

}  // namespace reader
}  // namespace operators
}  // namespace paddle

namespace reader = ::paddle::operators::reader;

REGISTER_FILE_READER_OPERATOR(create_py_array_reader,
                              reader::CreatePyArrayReaderOp,
                              reader::CreatePyArrayReaderOpMaker);
@ -0,0 +1,207 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
|
||||||
|
#include <condition_variable> //NOLINT
|
||||||
|
#include <memory>
|
||||||
|
#include <mutex> // NOLINT
|
||||||
|
#include <vector>
|
||||||
|
#include "glog/logging.h"
|
||||||
|
#include "paddle/fluid/framework/lod_tensor.h"
|
||||||
|
#include "paddle/fluid/operators/reader/py_blocking_queue.h"
|
||||||
|
#include "paddle/fluid/operators/reader/reader_op_registry.h"
|
||||||
|
#include "paddle/fluid/pybind/tensor_py.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
namespace reader {
|
||||||
|
|
||||||
|
using PyTuple = ::pybind11::tuple;
using PyArray = ::pybind11::array;

template <typename T>
using PyArrayT = ::pybind11::array_t<T, ::pybind11::array::c_style |
                                            ::pybind11::array::forcecast>;

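// PyArrayToTensorVisitor copies a numpy array into a framework::Tensor.
// boost::apply_visitor dispatches on the placement (CPU / CUDA / CUDA
// pinned), and the macros below dispatch on the numpy dtype.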
class PyArrayToTensorVisitor : public boost::static_visitor<void> {
 public:
#define PY_ARRAY_TO_TENSOR_WITH_TYPE(dtype, func_name)                       \
  pybind::func_name(tensor_, static_cast<const PyArrayT<dtype>&>(py_array_), \
                    place)

#define PY_ARRAY_TO_TENSOR(func_name)                   \
  if (IsType<size_t>()) {                               \
    PY_ARRAY_TO_TENSOR_WITH_TYPE(size_t, func_name);    \
  } else if (IsType<int64_t>()) {                       \
    PY_ARRAY_TO_TENSOR_WITH_TYPE(int64_t, func_name);   \
  } else if (IsType<int32_t>()) {                       \
    PY_ARRAY_TO_TENSOR_WITH_TYPE(int32_t, func_name);   \
  } else if (IsType<int16_t>()) {                       \
    PY_ARRAY_TO_TENSOR_WITH_TYPE(int16_t, func_name);   \
  } else if (IsType<uint8_t>()) {                       \
    PY_ARRAY_TO_TENSOR_WITH_TYPE(uint8_t, func_name);   \
  } else if (IsType<float>()) {                         \
    PY_ARRAY_TO_TENSOR_WITH_TYPE(float, func_name);     \
  } else if (IsType<double>()) {                        \
    PY_ARRAY_TO_TENSOR_WITH_TYPE(double, func_name);    \
  } else {                                              \
    PADDLE_THROW("unsupported dtype of python array");  \
  }

  PyArrayToTensorVisitor(const PyArray& py_array, framework::Tensor* tensor)
      : py_array_(py_array), tensor_(tensor) {}

  void operator()(const platform::CPUPlace& place) {
    PY_ARRAY_TO_TENSOR(PyCPUTensorSetFromArray);
  }

  void operator()(const platform::CUDAPlace& place) {
#ifdef PADDLE_WITH_CUDA
    PY_ARRAY_TO_TENSOR(PyCUDATensorSetFromArray);
#else
    PADDLE_THROW("CUDAPlace is not supported in CPU only version");
#endif
  }

  void operator()(const platform::CUDAPinnedPlace& place) {
#ifdef PADDLE_WITH_CUDA
    PY_ARRAY_TO_TENSOR(PyCUDAPinnedTensorSetFromArray);
#else
    PADDLE_THROW("CUDAPinnedPlace is not supported in CPU only version");
#endif
  }

#undef PY_ARRAY_TO_TENSOR
#undef PY_ARRAY_TO_TENSOR_WITH_TYPE

 private:
  template <typename T>
  inline bool IsType() const {
    return ::pybind11::isinstance<PyArrayT<T>>(py_array_);
  }

 private:
  const PyArray& py_array_;
  framework::Tensor* tensor_;
};

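// A PyArrayFeedQueue can only be constructed through PyArrayFeedQueueHolder
// (see InitOnce() below): the constructor is private and the holder is a
// friend, so every consumer shares the holder's single queue instance.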
class PyArrayFeedQueueHolder;

// PyArrayFeedQueue must be thread-safe
class PyArrayFeedQueue {
  friend class PyArrayFeedQueueHolder;

 private:
  PyArrayFeedQueue(size_t capacity, const std::vector<framework::DDim>& dims,
                   const platform::Place& place)
      : dims_(dims), place_(place) {
    queue_.reset(
        new PyBlockingQueue<std::vector<framework::LoDTensor>>(capacity));
  }

 public:
  ~PyArrayFeedQueue() { Close(); }

  bool Enqueue(const std::vector<PyArray>& py_array_vec) {
    auto lod_tensor_vec = PyArrayVecToLoDTensorVec(py_array_vec);
    VLOG(5) << "Enqueue at address " << reinterpret_cast<void*>(this);
    return queue_->Send(std::move(lod_tensor_vec));
  }

  bool Enqueue(const std::vector<framework::LoDTensor>& tensor_vec) {
    VLOG(5) << "Enqueue at address " << reinterpret_cast<void*>(this);
    return queue_->Send(tensor_vec);
  }

  std::vector<framework::LoDTensor> Dequeue() {
    VLOG(5) << "Dequeue at address " << reinterpret_cast<void*>(this);
    std::vector<framework::LoDTensor> ret;
    return queue_->Receive(&ret) ? ret : std::vector<framework::LoDTensor>();
  }

  inline size_t Size() const { return queue_->Size(); }

  inline size_t Cap() const { return queue_->Cap(); }

  inline bool IsClosed() const { return queue_->IsClosed(); }

  inline void Close() { queue_->Close(); }

 private:
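  // Converts the fed numpy arrays into LoDTensors, checking that the batch
  // size (0-dim) matches across all inputs and that the remaining dims match
  // the expected dims_ before copying each array onto place_.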
std::vector<framework::LoDTensor> PyArrayVecToLoDTensorVec(
|
||||||
|
const std::vector<PyArray>& py_array_vec) {
|
||||||
|
PADDLE_ENFORCE(dims_.size() == py_array_vec.size(),
|
||||||
|
"expected input tensor number %d but found %d", dims_.size(),
|
||||||
|
py_array_vec.size());
|
||||||
|
|
||||||
|
size_t i = 0;
|
||||||
|
if (py_array_vec.size() > 1) {
|
||||||
|
size_t dim0 = py_array_vec[0].shape()[0];
|
||||||
|
for (size_t j = 1; j < py_array_vec.size(); ++j) {
|
||||||
|
PADDLE_ENFORCE(dim0 == py_array_vec[j].shape()[0],
|
||||||
|
"0-dim of the %d-th input tensor is %d, but 0-dim of "
|
||||||
|
"the 0-th input tensor is %d",
|
||||||
|
j, py_array_vec[j].shape()[0], dim0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<framework::LoDTensor> lod_tensor_vec;
|
||||||
|
lod_tensor_vec.reserve(py_array_vec.size());
|
||||||
|
|
||||||
|
std::for_each(
|
||||||
|
py_array_vec.begin(), py_array_vec.end(), [&](const PyArray& py_array) {
|
||||||
|
for (int64_t j = 1; j < dims_[i].size(); ++j) {
|
||||||
|
PADDLE_ENFORCE(
|
||||||
|
dims_[i][j] == static_cast<int64_t>(py_array.shape()[j]),
|
||||||
|
"expected %d-dim of %d-th input tensor is %d but found %d", j,
|
||||||
|
i, dims_[i][j], py_array.shape()[j]);
|
||||||
|
}
|
||||||
|
|
||||||
|
lod_tensor_vec.emplace_back(framework::LoDTensor());
|
||||||
|
PyArrayToTensorVisitor visitor(py_array, &(lod_tensor_vec.back()));
|
||||||
|
boost::apply_visitor(visitor, place_);
|
||||||
|
++i;
|
||||||
|
});
|
||||||
|
return lod_tensor_vec;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<PyBlockingQueue<std::vector<framework::LoDTensor>>> queue_;
|
||||||
|
std::vector<framework::DDim> dims_;
|
||||||
|
platform::Place place_;
|
||||||
|
};
|
||||||
|
|
||||||
|
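// PyArrayFeedQueueHolder is the scope variable that owns the queue; both
// the reader op and the Python feeding side obtain the shared queue
// through GetFeeder().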
class PyArrayFeedQueueHolder {
 public:
  PyArrayFeedQueueHolder() {}

  void InitOnce(size_t capacity, const std::vector<framework::DDim>& dims,
                const platform::Place& place) {
    PADDLE_ENFORCE(
        feeder_ == nullptr,
        "PyArrayFeedQueueHolder::InitOnce() can only be called once");
    feeder_.reset(new PyArrayFeedQueue(capacity, dims, place));
  }

  std::shared_ptr<PyArrayFeedQueue> GetFeeder() { return feeder_; }
  const std::shared_ptr<PyArrayFeedQueue>& GetFeeder() const {
    return feeder_;
  }

 private:
  std::shared_ptr<PyArrayFeedQueue> feeder_;
};

}  // namespace reader
}  // namespace operators
}  // namespace paddle
@ -0,0 +1,125 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <condition_variable> // NOLINT
|
||||||
|
#include <deque>
|
||||||
|
|
||||||
|
#include "Python.h"
|
||||||
|
#include "paddle/fluid/platform/enforce.h"
|
||||||
|
#include "pybind11/pybind11.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
namespace reader {
|
||||||
|
|
||||||
|
// PyBlockingQueue is designed for PyArrayFeedQueue.
// PyBlockingQueue releases the Python GIL while waiting when the queue is
// full, so that a blocked Send() from the Python side cannot deadlock.
template <typename T>
class PyBlockingQueue {
 public:
  explicit PyBlockingQueue(size_t capacity)
      : capacity_(capacity), closed_(false) {
    PADDLE_ENFORCE_GT(
        capacity_, 0,
        "The capacity of a reader::PyBlockingQueue must be greater than 0.");
  }

  ~PyBlockingQueue() { Close(); }

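  // Send() blocks while the queue is full; the GIL is dropped during the
  // wait so a consumer that needs the GIL can still make progress. The
  // early notify_one() is harmless: receivers always re-check their wait
  // predicate under the lock.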
bool Send(const T& elem) {
|
||||||
|
std::unique_lock<std::mutex> lock(mutex_);
|
||||||
|
receive_cv_.notify_one();
|
||||||
|
if (queue_.size() >= capacity_ && (!closed_)) {
|
||||||
|
pybind11::gil_scoped_release release;
|
||||||
|
send_cv_.wait(lock, [&] { return queue_.size() < capacity_ || closed_; });
|
||||||
|
}
|
||||||
|
if (closed_) {
|
||||||
|
VLOG(5)
|
||||||
|
<< "WARNING: Sending an element to a closed reader::BlockingQueue.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
PADDLE_ENFORCE_LT(queue_.size(), capacity_);
|
||||||
|
queue_.push_back(elem);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Send(T&& elem) {
|
||||||
|
std::unique_lock<std::mutex> lock(mutex_);
|
||||||
|
receive_cv_.notify_one();
|
||||||
|
if (queue_.size() >= capacity_ && (!closed_)) {
|
||||||
|
pybind11::gil_scoped_release release;
|
||||||
|
send_cv_.wait(lock, [&] { return queue_.size() < capacity_ || closed_; });
|
||||||
|
}
|
||||||
|
if (closed_) {
|
||||||
|
VLOG(5)
|
||||||
|
<< "WARNING: Sending an element to a closed reader::BlokcingQueue.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
PADDLE_ENFORCE_LT(queue_.size(), capacity_);
|
||||||
|
queue_.emplace_back(std::move(elem));
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
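  // Receive() blocks until an element arrives or the queue is closed; it
  // returns false only once the queue is closed and fully drained.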
  bool Receive(T* elem) {
    std::unique_lock<std::mutex> lock(mutex_);
    send_cv_.notify_one();
    receive_cv_.wait(lock, [&] { return !queue_.empty() || closed_; });
    if (!queue_.empty()) {
      PADDLE_ENFORCE_NOT_NULL(elem);
      *elem = queue_.front();
      queue_.pop_front();
      return true;
    } else {
      PADDLE_ENFORCE(closed_);
      return false;
    }
  }

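  // Close() wakes every blocked sender and receiver. Afterwards Send()
  // fails with a warning, while Receive() keeps draining the remaining
  // elements before returning false.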
  void Close() {
    std::lock_guard<std::mutex> lock(mutex_);
    closed_ = true;
    send_cv_.notify_all();
    receive_cv_.notify_all();
  }

  bool IsClosed() const {
    std::lock_guard<std::mutex> lock(mutex_);
    return closed_;
  }

  size_t Cap() const {
    std::lock_guard<std::mutex> lock(mutex_);
    return capacity_;
  }

  size_t Size() const {
    std::lock_guard<std::mutex> lock(mutex_);
    return queue_.size();
  }

 private:
  size_t capacity_;
  bool closed_;
  std::deque<T> queue_;

  mutable std::mutex mutex_;
  mutable std::condition_variable receive_cv_;
  mutable std::condition_variable send_cv_;
};

}  // namespace reader
}  // namespace operators
}  // namespace paddle