You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Paddle/paddle/fluid/framework/channel.h

462 lines
11 KiB

// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#if defined _WIN32 || defined __APPLE__
#else
#define _LINUX
#endif
#include <glog/logging.h>
#include <algorithm>
#include <condition_variable> // NOLINT
#include <deque>
#include <limits>
#include <memory>
#include <mutex> // NOLINT
#include <utility>
#include <vector>
#include "paddle/fluid/framework/expect.h"
namespace paddle {
namespace framework {
template <class T>
class ChannelObject {
public:
ChannelObject() {}
// capacity can be zero
explicit ChannelObject(size_t capacity) {
capacity_ = (std::min)(MaxCapacity(), capacity);
}
const std::deque<T>& GetData() const { return data_; }
void Clear() {
std::unique_lock<std::mutex> lock(mutex_);
data_.clear();
data_.shrink_to_fit();
}
size_t Capacity() {
return capacity_; // atomic
}
void SetCapacity(size_t x) { // capacity can be zero
std::lock_guard<std::mutex> lock(mutex_);
capacity_ = std::min(MaxCapacity(), x);
Notify();
}
size_t BlockSize() {
return block_size_; // atomic
}
void SetBlockSize(size_t x) {
CHECK(x >= 1) << "block size must be >= 1";
std::lock_guard<std::mutex> lock(mutex_);
block_size_ = x;
}
template <class U>
void InheritFrom(const std::shared_ptr<ChannelObject<U>>& other) {
std::lock_guard<std::mutex> lock(mutex_);
capacity_ = other->Capacity();
block_size_ = other->BlockSize();
}
bool Closed() {
return closed_; // atomic
}
// open channel, then data can be write() to channel
void Open() {
std::lock_guard<std::mutex> lock(mutex_);
closed_ = false;
Notify();
}
// close channel, then no more data can be write() to channel
void Close() {
std::lock_guard<std::mutex> lock(mutex_);
closed_ = true;
Notify();
}
size_t Size() {
std::lock_guard<std::mutex> lock(mutex_);
return data_.size();
}
bool Empty() {
std::lock_guard<std::mutex> lock(mutex_);
return EmptyUnlocked();
}
// blocking operation
bool Get(T& val) { return Read(1, &val) != 0; } // NOLINT
// blocking operation
// returns 0 if the channel is closed and empty
size_t Read(size_t n, T* p) {
if (n == 0) {
return 0;
}
std::unique_lock<std::mutex> lock(mutex_);
size_t finished = Read(n, p, lock);
Notify();
return finished;
}
// blocking operation
bool Put(T&& val) { return WriteMove(1, &val) != 0; }
// blocking operation
bool Put(const T& val) { return Write(1, &val) != 0; }
// blocking operation
// returns value less than n if the channel is closed
size_t Write(size_t n, const T* p) {
if (n == 0) {
return 0;
}
std::unique_lock<std::mutex> lock(mutex_);
size_t finished = Write(n, p, lock);
Notify();
return finished;
}
// WriteMove() will clear original contents of input array
size_t WriteMove(size_t n, T* p) {
if (n == 0) {
return 0;
}
std::unique_lock<std::mutex> lock(mutex_);
size_t finished = WriteMove(n, p, lock);
Notify();
return finished;
}
// read data of block size from channel to vector
size_t Read(std::vector<T>& p) { // NOLINT
p.resize(block_size_);
size_t finished = Read(p.size(), &p[0]);
p.resize(finished);
return finished;
}
size_t ReadAll(std::vector<T>& p) { // NOLINT
p.clear();
size_t finished = 0;
size_t n = 0;
do {
// _block_size may change anytime
n = block_size_;
p.resize(finished + n);
n = Read(n, &p[finished]);
finished += n;
} while (n != 0);
p.resize(finished);
return finished;
}
// write data from vector to channel
size_t Write(const std::vector<T>& p) { return Write(p.size(), &p[0]); }
// write data from vector to channel
size_t Write(std::vector<T>&& p) { return WriteMove(p.size(), &p[0]); }
private:
size_t capacity_ = MaxCapacity();
size_t block_size_ = 1024;
bool closed_ = false;
std::mutex mutex_;
// use deque to store data
std::deque<T> data_;
size_t reading_count_ = 0;
int empty_waiters_ = 0;
int full_waiters_ = 0;
std::condition_variable empty_cond_;
std::condition_variable full_cond_;
static constexpr size_t MaxCapacity() {
return (std::numeric_limits<size_t>::max)() / 2;
}
void Notify() {
if (empty_waiters_ != 0 && (!EmptyUnlocked() || closed_)) {
empty_cond_.notify_one();
}
if (full_waiters_ != 0 && (!FullUnlocked() || closed_)) {
full_cond_.notify_one();
}
}
bool EmptyUnlocked() { return data_.empty(); }
bool FullUnlocked() { return data_.size() >= capacity_ + reading_count_; }
bool WaitForRead(std::unique_lock<std::mutex>& lock) { // NOLINT
#ifdef _LINUX
while (unlikely(EmptyUnlocked() && !closed_)) {
#else
while (EmptyUnlocked() && !closed_) {
#endif
if (full_waiters_ != 0) {
full_cond_.notify_one();
}
empty_waiters_++;
empty_cond_.wait(lock);
empty_waiters_--;
}
return !EmptyUnlocked();
}
bool WaitForWrite(std::unique_lock<std::mutex>& lock) { // NOLINT
#ifdef _LINUX
while (unlikely(FullUnlocked() && !closed_)) {
#else
while (FullUnlocked() && !closed_) {
#endif
if (empty_waiters_ != 0) {
empty_cond_.notify_one();
}
full_waiters_++;
full_cond_.wait(lock);
full_waiters_--;
}
return !closed_;
}
size_t Read(size_t n, T* p, std::unique_lock<std::mutex>& lock) { // NOLINT
size_t finished = 0;
CHECK(n <= MaxCapacity() - reading_count_);
reading_count_ += n;
while (finished < n && WaitForRead(lock)) {
size_t m = std::min(n - finished, data_.size());
for (size_t i = 0; i < m; i++) {
p[finished++] = std::move(data_.front());
data_.pop_front();
}
reading_count_ -= m;
}
reading_count_ -= n - finished;
return finished;
}
size_t Write(size_t n,
const T* p, // NOLINT
std::unique_lock<std::mutex>& lock) { // NOLINT
size_t finished = 0;
while (finished < n && WaitForWrite(lock)) {
size_t m =
std::min(n - finished, capacity_ + reading_count_ - data_.size());
for (size_t i = 0; i < m; i++) {
data_.push_back(p[finished++]);
}
}
return finished;
}
size_t WriteMove(size_t n,
T* p, // NOLINT
std::unique_lock<std::mutex>& lock) { // NOLINT
size_t finished = 0;
while (finished < n && WaitForWrite(lock)) {
size_t m =
std::min(n - finished, capacity_ + reading_count_ - data_.size());
for (size_t i = 0; i < m; i++) {
data_.push_back(std::move(p[finished++]));
}
}
return finished;
}
}; // NOLINT
template <class T>
using Channel = std::shared_ptr<ChannelObject<T>>;
template <class T>
Channel<T> MakeChannel(size_t capacity = (std::numeric_limits<size_t>::max)()) {
return std::make_shared<ChannelObject<T>>(capacity);
}
template <class T, class U>
Channel<T> MakeChannel(const Channel<U>& other) {
CHECK(other != nullptr) << "channel can not be NULL";
Channel<T> chan = std::make_shared<ChannelObject<T>>();
chan->InheritFrom(other);
return chan;
}
// NOTE: ChannelReader is a wrapper for quick read channel with a buffer. It
// will read a block data from channel, but user can get data one by one. So it
// is important to notice that user must call operator>> until false, or call
// get_buffer_remain until false to make sure the buffered data all readed.
template <class T>
class ChannelReader {
public:
explicit ChannelReader(ChannelObject<T>* channel = nullptr) {
Reset(channel);
}
~ChannelReader() { CHECK(cursor_ == 0) << "Forgot to read buffer data"; }
ChannelObject<T>* channel() { return channel_; }
void Reset(ChannelObject<T>* channel) {
CHECK(channel != nullptr) << "Channel can not be nullptr";
channel_ = channel;
cursor_ = 0;
failed_ = !channel;
}
// whether there were read failed
operator bool() { return !failed_; }
ChannelReader<T>& operator>>(T& val) {
if (failed_) {
return *this;
}
if (cursor_ >= buffer_.size()) {
cursor_ = 0;
if (channel_->read(buffer_) == 0) {
failed_ = true;
return *this;
}
}
val = std::move(buffer_[cursor_++]);
return *this;
}
bool GetBufferRemain(T& val) { // NOLINT
if (cursor_ >= buffer_.size()) {
cursor_ = 0;
return false;
}
val = std::move(buffer_[cursor_++]);
return true;
}
private:
ChannelObject<T>* channel_ = nullptr;
std::vector<T> buffer_;
size_t cursor_ = 0;
bool failed_ = true;
}; // NOLINT
template <class T>
class ChannelWriter {
public:
explicit ChannelWriter(ChannelObject<T>* channel = nullptr) {
Reset(channel);
}
~ChannelWriter() { CHECK(buffer_.empty()) << "Forgot to flush"; }
ChannelObject<T>* channel() { return channel_; }
void Reset(ChannelObject<T>* channel) {
CHECK(buffer_.empty()) << "Forgot to flush";
// CHECK(channel != nullptr) << "Channel can not be nullptr";
channel_ = channel;
buffer_.clear();
failed_ = !channel;
}
// whether there were write failed
operator bool() { return !failed_; }
ChannelWriter<T>& operator<<(T&& val) {
if (failed_) {
return *this;
}
buffer_.push_back(std::move(val));
if (buffer_.size() >= channel_->BlockSize()) {
Flush();
}
return *this;
}
ChannelWriter<T>& operator<<(const T& val) {
if (failed_) {
return *this;
}
buffer_.push_back(val);
if (buffer_.size() >= channel_->BlockSize()) {
Flush();
}
return *this;
}
void Flush() {
if (failed_ || buffer_.empty()) {
buffer_.clear();
return;
}
failed_ |=
channel_->WriteMove(buffer_.size(), &buffer_[0]) != buffer_.size();
buffer_.clear();
}
private:
ChannelObject<T>* channel_ = nullptr;
std::vector<T> buffer_;
bool failed_ = true;
}; // NOLINT
// only used for range-for loop
// for (auto& x : chan) {...}
template <class T>
struct ChannelIterator {
std::shared_ptr<ChannelReader<T>> reader_;
T data_;
void operator++() {
CHECK(reader_ != nullptr) << "reader can not be NULL";
if (!(*reader_ >> data_)) {
reader_ = nullptr;
}
}
T& operator*() { return data_; }
friend bool operator==(const ChannelIterator<T>& a,
const ChannelIterator<T>& b) {
return a.reader_ == b.reader_;
}
friend bool operator!=(const ChannelIterator<T>& a,
const ChannelIterator<T>& b) {
return a.reader_ != b.reader_;
}
}; // NOLINT
template <class T>
ChannelIterator<T> begin(ChannelObject<T>* chan) {
ChannelIterator<T> it{std::make_shared<ChannelReader<T>>(chan), T()};
++it;
return it;
}
template <class T>
ChannelIterator<T> end(ChannelObject<T>* chan) {
return {nullptr, T()};
}
} // namespace framework
} // namespace paddle