Redesign channel implementation for Select Op (#8814)

* Redesign channel implementation for Select Op

* Remove unecessary header

* Remove unnecessary comments
shanyi15-patch-2
Abhinav Arora 7 years ago committed by GitHub
parent 351795e892
commit 78c884d7a7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -28,24 +28,19 @@ class Channel {
virtual bool Send(T*) = 0;
virtual bool Receive(T*) = 0;
virtual size_t Cap() = 0;
virtual void Lock() = 0;
virtual void Unlock() = 0;
virtual void Close() = 0;
virtual ~Channel() {}
};
// Forward declaration of channel implementations.
namespace details {
template <typename T>
class Buffered;
template <typename T>
class UnBuffered;
} // namespace details
class ChannelImpl;
template <typename T>
Channel<T>* MakeChannel(size_t buffer_size) {
if (buffer_size > 0) {
return new details::Buffered<T>(buffer_size);
}
return new details::UnBuffered<T>();
return new ChannelImpl<T>(buffer_size);
}
template <typename T>
@ -89,6 +84,19 @@ class ChannelHolder {
if (IsInitialized()) holder_->Close();
}
size_t Cap() {
if (IsInitialized()) return holder_->Cap();
return -1;
}
void Lock() {
if (IsInitialized()) holder_->Lock();
}
void Unlock() {
if (IsInitialized()) holder_->Unlock();
}
inline bool IsInitialized() const { return holder_ != nullptr; }
inline const std::type_index Type() {
@ -106,6 +114,9 @@ class ChannelHolder {
virtual const std::type_index Type() const = 0;
virtual void* Ptr() const = 0;
virtual void Close() = 0;
virtual void Lock() = 0;
virtual void Unlock() = 0;
virtual size_t Cap() = 0;
};
template <typename T>
@ -115,11 +126,28 @@ class ChannelHolder {
}
virtual const std::type_index Type() const { return type_; }
virtual void* Ptr() const { return static_cast<void*>(channel_.get()); }
virtual void Close() {
if (channel_) channel_->Close();
}
virtual size_t Cap() {
if (channel_)
return channel_->Cap();
else
return -1;
}
virtual void Lock() {
if (channel_) channel_->Lock();
}
virtual void Unlock() {
if (channel_) channel_->Unlock();
}
std::unique_ptr<Channel<T>> channel_;
const std::type_index type_;
};
@ -131,5 +159,4 @@ class ChannelHolder {
} // namespace framework
} // namespace paddle
#include "paddle/fluid/framework/details/buffered_channel.h"
#include "paddle/fluid/framework/details/unbuffered_channel.h"
#include "paddle/fluid/framework/channel_impl.h"

@ -0,0 +1,229 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <stddef.h> // for size_t
#include <atomic>
#include <condition_variable>
#include <deque>
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
template <typename T>
class ChannelImpl : public paddle::framework::Channel<T> {
friend Channel<T> *paddle::framework::MakeChannel<T>(size_t);
friend void paddle::framework::CloseChannel<T>(Channel<T> *);
public:
virtual bool Send(T *);
virtual bool Receive(T *);
virtual size_t Cap() { return cap_; }
virtual void Lock();
virtual void Unlock();
virtual void Close();
ChannelImpl(size_t);
virtual ~ChannelImpl();
private:
struct QueueMessage {
T *data;
std::condition_variable_any cond;
bool chan_closed = false;
bool completed = false;
QueueMessage(T *item) : data(item) {}
void Wait(std::unique_lock<std::recursive_mutex> &lock) {
cond.wait(lock, [this]() { return completed; });
}
void Notify() {
completed = true;
cond.notify_all();
}
};
bool send_return(bool value) {
send_ctr--;
destructor_cond_.notify_all();
return value;
}
bool recv_return(bool value) {
recv_ctr--;
destructor_cond_.notify_all();
return value;
}
size_t cap_;
std::recursive_mutex mu_;
bool closed_;
std::deque<T> buf_;
std::deque<std::shared_ptr<QueueMessage>> recvq;
std::deque<std::shared_ptr<QueueMessage>> sendq;
std::atomic<unsigned> send_ctr{0};
std::atomic<unsigned> recv_ctr{0};
std::condition_variable_any destructor_cond_;
};
template <typename T>
ChannelImpl<T>::ChannelImpl(size_t capacity)
: cap_(capacity), closed_(false), send_ctr(0), recv_ctr(0) {
PADDLE_ENFORCE_GE(capacity, 0);
}
template <typename T>
bool ChannelImpl<T>::Send(T *item) {
send_ctr++;
std::unique_lock<std::recursive_mutex> lock{mu_};
// If channel is closed, do nothing
if (closed_) {
lock.unlock();
// TODO(abhinavarora) Should panic on closed channel
return send_return(false);
}
// If there is a receiver, directly pass the value we want
// to send to the receiver, bypassing the channel buffer if any
if (!recvq.empty()) {
std::shared_ptr<QueueMessage> m = recvq.front();
recvq.pop_front();
// Do the data transfer
*(m->data) = std::move(*item);
// Wake up the blocked process and unlock
m->Notify();
lock.unlock();
return send_return(true);
}
// Unbuffered channel will always bypass this
// If buffered channel has space in buffer,
// write the element to the buffer.
if (buf_.size() < cap_) {
// Copy to buffer
buf_.push_back(std::move(*item));
// Release lock and return true
lock.unlock();
return send_return(true);
}
// Block on channel, because some receiver will complete
// the operation for us
auto m = std::make_shared<QueueMessage>(item);
sendq.push_back(m);
m->Wait(lock);
// TODO(abhinavarora) Should panic on closed channel
return send_return(!m->chan_closed);
}
template <typename T>
bool ChannelImpl<T>::Receive(T *item) {
recv_ctr++;
std::unique_lock<std::recursive_mutex> lock{mu_};
// If channel is closed and buffer is empty or
// channel is unbuffered
if (closed_ && buf_.empty()) {
lock.unlock();
return recv_return(false);
}
// If there is a sender, directly receive the value we want
// from the sender, bypassing the channel buffer if any
if (!sendq.empty()) {
std::shared_ptr<QueueMessage> m = sendq.front();
sendq.pop_front();
// Do the data transfer
*item = std::move(*(m->data));
// Wake up the blocked process and unlock
m->Notify();
lock.unlock();
return recv_return(true);
}
// If this is a buffered channel and there are items in buffer
if (buf_.size() > 0) {
// Directly read from buffer
*item = std::move(buf_.front());
buf_.pop_front();
// Release lock and return true
lock.unlock();
return recv_return(true);
}
// No sender available, block on this channel
// Some receiver will complete the option for us
auto m = std::make_shared<QueueMessage>(item);
recvq.push_back(m);
m->Wait(lock);
return recv_return(!m->chan_closed);
}
template <typename T>
void ChannelImpl<T>::Lock() {
mu_.lock();
}
template <typename T>
void ChannelImpl<T>::Unlock() {
mu_.unlock();
}
template <typename T>
void ChannelImpl<T>::Close() {
std::unique_lock<std::recursive_mutex> lock{mu_};
if (closed_) {
// TODO(abhinavarora): closing an already closed channel should panic
lock.unlock();
return;
}
closed_ = true;
// Empty the readers
while (!recvq.empty()) {
std::shared_ptr<QueueMessage> m = recvq.front();
recvq.pop_front();
m->chan_closed = true;
m->Notify();
}
// Empty the senders
while (!sendq.empty()) {
std::shared_ptr<QueueMessage> m = sendq.front();
sendq.pop_front();
m->chan_closed = true;
m->Notify();
}
}
template <typename T>
ChannelImpl<T>::~ChannelImpl() {
Close();
// The destructor must wait for all readers and writers to complete their task
// The channel has been closed, so we will not accept new readers and writers
std::unique_lock<std::recursive_mutex> lock{mu_};
destructor_cond_.wait(lock,
[this]() { return send_ctr == 0 && recv_ctr == 0; });
}
} // namespace framework
} // namespace paddle

File diff suppressed because it is too large Load Diff

@ -1,142 +0,0 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <atomic>
#include <condition_variable>
#include <deque>
#include <mutex>
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace details {
// Four of the properties of Buffered Channel:
// - A send to a full channel blocks temporarily until a receive from the
// channel or the channel is closed.
// - A receive from an empty channel blocks temporarily until a send to the
// channel or the channel is closed.
// - A send to a closed channel returns false immediately.
// - A receive from a closed channel returns false immediately.
template <typename T>
class Buffered : public paddle::framework::Channel<T> {
friend Channel<T>* paddle::framework::MakeChannel<T>(size_t);
friend void paddle::framework::CloseChannel<T>(Channel<T>*);
public:
virtual bool Send(T*);
virtual bool Receive(T*);
virtual size_t Cap() { return cap_; }
virtual void Close();
virtual ~Buffered();
private:
size_t cap_;
std::mutex mu_;
std::condition_variable empty_cond_var_;
std::condition_variable full_cond_var_;
std::condition_variable destructor_cond_var_;
std::deque<T> channel_;
std::atomic<bool> closed_{false};
std::atomic<unsigned> send_ctr{0};
std::atomic<unsigned> recv_ctr{0};
Buffered(size_t cap) : cap_(cap), closed_(false) {
PADDLE_ENFORCE_GT(cap, 0);
}
void NotifyAllParticipants(std::unique_lock<std::mutex>*);
};
template <typename T>
bool Buffered<T>::Send(T* item) {
bool ret = false;
if (closed_) {
return ret;
}
send_ctr++;
std::unique_lock<std::mutex> lock(mu_);
full_cond_var_.wait(lock,
[this]() { return channel_.size() < cap_ || closed_; });
if (!closed_) {
channel_.push_back(std::move(*item));
lock.unlock();
empty_cond_var_.notify_one();
ret = true;
}
send_ctr--;
destructor_cond_var_.notify_one();
return ret;
}
template <typename T>
bool Buffered<T>::Receive(T* item) {
bool ret = false;
// Once the channel has been closed and all data has been consumed,
// just return false. Don't even try acquiring the mutex.
if (closed_ && channel_.empty()) {
return false;
}
recv_ctr++;
std::unique_lock<std::mutex> lock(mu_);
empty_cond_var_.wait(lock, [this]() { return !channel_.empty() || closed_; });
if (!channel_.empty()) {
*item = std::move(channel_.front());
channel_.pop_front();
full_cond_var_.notify_one();
ret = true;
}
recv_ctr--;
destructor_cond_var_.notify_one();
return ret;
}
template <typename T>
void Buffered<T>::Close() {
if (closed_) {
return;
}
std::unique_lock<std::mutex> lock(mu_);
closed_ = true;
NotifyAllParticipants(&lock);
}
template <typename T>
Buffered<T>::~Buffered() {
std::unique_lock<std::mutex> lock(mu_);
closed_ = true;
channel_.clear();
NotifyAllParticipants(&lock);
// The destructor must wait for all readers and writers to complete their task
// The channel has been closed, so we will not accept new readers and writers
lock.lock();
destructor_cond_var_.wait(
lock, [this]() { return send_ctr == 0 && recv_ctr == 0; });
}
template <typename T>
void Buffered<T>::NotifyAllParticipants(std::unique_lock<std::mutex>* lock) {
lock->unlock();
full_cond_var_.notify_all();
empty_cond_var_.notify_all();
}
} // namespace details
} // namespace framework
} // namespace paddle

@ -1,174 +0,0 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <atomic>
#include <condition_variable>
#include <mutex>
#include "paddle/fluid/framework/channel.h"
namespace paddle {
namespace framework {
namespace details {
// Four of the properties of UnBuffered Channel:
// - A send to a channel blocks temporarily until a receive from the
// channel or the channel is closed.
// - A receive from a channel blocks temporarily until a send to the
// channel or the channel is closed.
// - A send to a closed channel returns false immediately.
// - A receive from a closed channel returns false immediately.
template <typename T>
class UnBuffered : public paddle::framework::Channel<T> {
friend Channel<T>* paddle::framework::MakeChannel<T>(size_t);
friend void paddle::framework::CloseChannel<T>(Channel<T>*);
public:
virtual bool Send(T*);
virtual bool Receive(T*);
virtual size_t Cap() { return 0; }
virtual void Close();
virtual ~UnBuffered();
private:
std::mutex mu_ch_;
// Mutex for readers and writers who are waiting for other reader
// and writer to complete execution
std::recursive_mutex mu_read_, mu_write_;
// reader_found_ is set true when a reader is ready to accept data
// writer_found_ is set true when a writer is ready to send data
// A transaction occurs only when both are true
std::atomic<bool> reader_found_{false}, writer_found_{false};
std::condition_variable cv_channel_;
std::condition_variable_any cv_reader_, cv_writer_, cv_destructor_;
T* item{nullptr};
std::atomic<bool> closed_{false};
std::atomic<unsigned> send_ctr{0};
std::atomic<unsigned> recv_ctr{0};
UnBuffered() : closed_(false) {}
void NotifyAllParticipants(std::unique_lock<std::mutex>*);
};
// This function implements the concept of how data should
// be sent from a writer to a reader.
template <typename T>
bool UnBuffered<T>::Send(T* data) {
bool ret = false;
if (closed_) {
return ret;
}
send_ctr++;
// Prevent other writers from entering
std::unique_lock<std::recursive_mutex> writer_lock(mu_write_);
writer_found_ = true;
std::unique_lock<std::recursive_mutex> cv_lock(mu_write_);
// If writer comes first, it should wait till a reader arrives
cv_writer_.wait(cv_lock,
[this]() { return reader_found_ == true || closed_; });
cv_reader_.notify_one();
if (!closed_) {
std::unique_lock<std::mutex> channel_lock(mu_ch_);
item = data;
channel_lock.unlock();
cv_channel_.notify_one();
channel_lock.lock();
cv_channel_.wait(channel_lock,
[this]() { return item == nullptr || closed_; });
ret = true;
}
writer_found_ = false;
send_ctr--;
cv_destructor_.notify_one();
return ret;
}
// This function implements the concept of how
// data that was sent by a writer is read from a reader.
template <typename T>
bool UnBuffered<T>::Receive(T* data) {
bool ret = false;
// If channel is closed, we don't even want any reader to enter.
// Unlike a buffered channel, an unbuffered channel does not allow
// readers to read after closing because there is no buffer to be consumed.
if (closed_) return ret;
recv_ctr++;
// Prevent other readers from entering
std::unique_lock<std::recursive_mutex> read_lock{mu_read_};
reader_found_ = true;
std::unique_lock<std::recursive_mutex> cv_lock{mu_read_};
// If reader comes first, it should wait till a writer arrives
cv_reader_.wait(cv_lock,
[this]() { return writer_found_ == true || closed_; });
cv_writer_.notify_one();
if (!closed_) {
std::unique_lock<std::mutex> lock_ch{mu_ch_};
// Reader should wait for the writer to first write its data
cv_channel_.wait(lock_ch, [this]() { return item != nullptr || closed_; });
if (!closed_) {
*data = std::move(*item);
item = nullptr;
lock_ch.unlock();
ret = true;
}
cv_channel_.notify_one();
}
reader_found_ = false;
recv_ctr--;
cv_destructor_.notify_one();
return ret;
}
// This function implements the sequence of events
// that take place once the channel is closed.
template <typename T>
void UnBuffered<T>::Close() {
if (closed_) {
return;
}
std::unique_lock<std::mutex> lock(mu_ch_);
item = nullptr;
closed_ = true;
NotifyAllParticipants(&lock);
}
// This function implements the sequence of events
// that are executed once the object of an UnBuffered
// channel is destroyed.
template <typename T>
UnBuffered<T>::~UnBuffered() {
std::unique_lock<std::mutex> lock(mu_ch_);
item = nullptr;
closed_ = true;
NotifyAllParticipants(&lock);
lock.lock();
cv_destructor_.wait(lock,
[this]() { return send_ctr == 0 && recv_ctr == 0; });
}
// This function notifies all the readers, writers and
// the channel condition variables.
template <typename T>
void UnBuffered<T>::NotifyAllParticipants(std::unique_lock<std::mutex>* lock) {
lock->unlock();
cv_writer_.notify_all();
cv_channel_.notify_all();
cv_reader_.notify_all();
}
} // namespace details
} // namespace framework
} // namespace paddle
Loading…
Cancel
Save