|
|
|
@ -29,32 +29,50 @@ class ChannelImpl : public paddle::framework::Channel<T> {
|
|
|
|
|
friend void paddle::framework::CloseChannel<T>(Channel<T> *);
|
|
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
virtual bool CanSend();
|
|
|
|
|
virtual bool CanReceive();
|
|
|
|
|
virtual bool Send(T *);
|
|
|
|
|
virtual bool Receive(T *);
|
|
|
|
|
virtual size_t Cap() { return cap_; }
|
|
|
|
|
virtual void Lock();
|
|
|
|
|
virtual void Unlock();
|
|
|
|
|
virtual bool IsClosed();
|
|
|
|
|
virtual void Close();
|
|
|
|
|
|
|
|
|
|
ChannelImpl(size_t);
|
|
|
|
|
virtual ~ChannelImpl();
|
|
|
|
|
|
|
|
|
|
virtual void AddToSendQ(const void *referrer, T *data,
|
|
|
|
|
std::shared_ptr<std::condition_variable_any> cond,
|
|
|
|
|
std::function<bool(ChannelAction)> cb);
|
|
|
|
|
virtual void AddToReceiveQ(const void *referrer, T *data,
|
|
|
|
|
std::shared_ptr<std::condition_variable_any> cond,
|
|
|
|
|
std::function<bool(ChannelAction)> cb);
|
|
|
|
|
|
|
|
|
|
virtual void RemoveFromSendQ(const void *referrer);
|
|
|
|
|
virtual void RemoveFromReceiveQ(const void *referrer);
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
struct QueueMessage {
|
|
|
|
|
T *data;
|
|
|
|
|
std::condition_variable_any cond;
|
|
|
|
|
std::shared_ptr<std::condition_variable_any> cond;
|
|
|
|
|
bool chan_closed = false;
|
|
|
|
|
bool completed = false;
|
|
|
|
|
const void *referrer; // TODO(thuan): figure out better way to do this
|
|
|
|
|
std::function<bool(ChannelAction)> callback;
|
|
|
|
|
|
|
|
|
|
QueueMessage(T *item) : data(item) {}
|
|
|
|
|
QueueMessage(T *item)
|
|
|
|
|
: data(item), cond(std::make_shared<std::condition_variable_any>()) {}
|
|
|
|
|
|
|
|
|
|
QueueMessage(T *item, std::shared_ptr<std::condition_variable_any> cond)
|
|
|
|
|
: data(item), cond(cond) {}
|
|
|
|
|
|
|
|
|
|
void Wait(std::unique_lock<std::recursive_mutex> &lock) {
|
|
|
|
|
cond.wait(lock, [this]() { return completed; });
|
|
|
|
|
cond->wait(lock, [this]() { return completed; });
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Notify() {
|
|
|
|
|
completed = true;
|
|
|
|
|
cond.notify_all();
|
|
|
|
|
cond->notify_all();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -87,6 +105,18 @@ ChannelImpl<T>::ChannelImpl(size_t capacity)
|
|
|
|
|
PADDLE_ENFORCE_GE(capacity, 0);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
bool ChannelImpl<T>::CanSend() {
|
|
|
|
|
std::lock_guard<std::recursive_mutex> lock{mu_};
|
|
|
|
|
return !closed_ && (!recvq.empty() || buf_.size() < cap_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
bool ChannelImpl<T>::CanReceive() {
|
|
|
|
|
std::lock_guard<std::recursive_mutex> lock{mu_};
|
|
|
|
|
return !(closed_ && buf_.empty()) && (!sendq.empty() || buf_.size() > 0);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
bool ChannelImpl<T>::Send(T *item) {
|
|
|
|
|
send_ctr++;
|
|
|
|
@ -105,7 +135,24 @@ bool ChannelImpl<T>::Send(T *item) {
|
|
|
|
|
std::shared_ptr<QueueMessage> m = recvq.front();
|
|
|
|
|
recvq.pop_front();
|
|
|
|
|
// Do the data transfer
|
|
|
|
|
*(m->data) = std::move(*item);
|
|
|
|
|
// We will do this data transfer if either of the following
|
|
|
|
|
// cases are true
|
|
|
|
|
// 1. callback == nullptr // This means it was a regular channel send
|
|
|
|
|
// 2. callback returns true
|
|
|
|
|
bool do_send = true;
|
|
|
|
|
if (m->callback != nullptr) do_send = m->callback(ChannelAction::SEND);
|
|
|
|
|
if (do_send)
|
|
|
|
|
*(m->data) = std::move(*item);
|
|
|
|
|
else
|
|
|
|
|
// We cannot do the data transfer because
|
|
|
|
|
// this QueueMessage was added by Select
|
|
|
|
|
// and some other case was executed.
|
|
|
|
|
// So call the Send function again.
|
|
|
|
|
// We do not care about notifying other
|
|
|
|
|
// because they would have been notified
|
|
|
|
|
// by the executed select case.
|
|
|
|
|
return Send(item);
|
|
|
|
|
|
|
|
|
|
// Wake up the blocked process and unlock
|
|
|
|
|
m->Notify();
|
|
|
|
|
lock.unlock();
|
|
|
|
@ -150,7 +197,25 @@ bool ChannelImpl<T>::Receive(T *item) {
|
|
|
|
|
std::shared_ptr<QueueMessage> m = sendq.front();
|
|
|
|
|
sendq.pop_front();
|
|
|
|
|
// Do the data transfer
|
|
|
|
|
*item = std::move(*(m->data));
|
|
|
|
|
// We will do this data transfer if either of the following
|
|
|
|
|
// cases are true
|
|
|
|
|
// 1. callback == nullptr // This means it was a regular channel send
|
|
|
|
|
// 2. callback returns true
|
|
|
|
|
bool do_receive = true;
|
|
|
|
|
if (m->callback != nullptr)
|
|
|
|
|
do_receive = m->callback(ChannelAction::RECEIVE);
|
|
|
|
|
if (do_receive)
|
|
|
|
|
*item = std::move(*(m->data));
|
|
|
|
|
else
|
|
|
|
|
// We cannot do the data transfer because
|
|
|
|
|
// this QueueMessage was added by Select
|
|
|
|
|
// and some other case was executed.
|
|
|
|
|
// So call the Receive function again.
|
|
|
|
|
// We do not care about notifying other
|
|
|
|
|
// because they would have been notified
|
|
|
|
|
// by the executed select case.
|
|
|
|
|
return Receive(item);
|
|
|
|
|
|
|
|
|
|
// Wake up the blocked process and unlock
|
|
|
|
|
m->Notify();
|
|
|
|
|
lock.unlock();
|
|
|
|
@ -186,6 +251,12 @@ void ChannelImpl<T>::Unlock() {
|
|
|
|
|
mu_.unlock();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
bool ChannelImpl<T>::IsClosed() {
|
|
|
|
|
std::lock_guard<std::recursive_mutex> lock{mu_};
|
|
|
|
|
return closed_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void ChannelImpl<T>::Close() {
|
|
|
|
|
std::unique_lock<std::recursive_mutex> lock{mu_};
|
|
|
|
@ -203,6 +274,12 @@ void ChannelImpl<T>::Close() {
|
|
|
|
|
std::shared_ptr<QueueMessage> m = recvq.front();
|
|
|
|
|
recvq.pop_front();
|
|
|
|
|
m->chan_closed = true;
|
|
|
|
|
|
|
|
|
|
// Execute callback function (if any)
|
|
|
|
|
if (m->callback != nullptr) {
|
|
|
|
|
m->callback(ChannelAction::CLOSE);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
m->Notify();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -211,10 +288,72 @@ void ChannelImpl<T>::Close() {
|
|
|
|
|
std::shared_ptr<QueueMessage> m = sendq.front();
|
|
|
|
|
sendq.pop_front();
|
|
|
|
|
m->chan_closed = true;
|
|
|
|
|
|
|
|
|
|
// Execute callback function (if any)
|
|
|
|
|
if (m->callback != nullptr) {
|
|
|
|
|
m->callback(ChannelAction::CLOSE);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
m->Notify();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void ChannelImpl<T>::AddToSendQ(
|
|
|
|
|
const void *referrer, T *data,
|
|
|
|
|
std::shared_ptr<std::condition_variable_any> cond,
|
|
|
|
|
std::function<bool(ChannelAction)> cb) {
|
|
|
|
|
std::lock_guard<std::recursive_mutex> lock{mu_};
|
|
|
|
|
auto m = std::make_shared<QueueMessage>(data, cond);
|
|
|
|
|
m->referrer = referrer;
|
|
|
|
|
m->callback = cb;
|
|
|
|
|
sendq.push_back(m);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void ChannelImpl<T>::AddToReceiveQ(
|
|
|
|
|
const void *referrer, T *data,
|
|
|
|
|
std::shared_ptr<std::condition_variable_any> cond,
|
|
|
|
|
std::function<bool(ChannelAction)> cb) {
|
|
|
|
|
std::lock_guard<std::recursive_mutex> lock{mu_};
|
|
|
|
|
auto m = std::make_shared<QueueMessage>(data, cond);
|
|
|
|
|
m->referrer = referrer;
|
|
|
|
|
m->callback = cb;
|
|
|
|
|
recvq.push_back(m);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void ChannelImpl<T>::RemoveFromSendQ(const void *referrer) {
|
|
|
|
|
std::lock_guard<std::recursive_mutex> lock{mu_};
|
|
|
|
|
|
|
|
|
|
for (auto it = sendq.begin(); it != sendq.end();) {
|
|
|
|
|
std::shared_ptr<QueueMessage> sendMsg = (std::shared_ptr<QueueMessage>)*it;
|
|
|
|
|
|
|
|
|
|
if (sendMsg->referrer == referrer) {
|
|
|
|
|
it = sendq.erase(it);
|
|
|
|
|
send_ctr--;
|
|
|
|
|
} else {
|
|
|
|
|
++it;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void ChannelImpl<T>::RemoveFromReceiveQ(const void *referrer) {
|
|
|
|
|
std::lock_guard<std::recursive_mutex> lock{mu_};
|
|
|
|
|
|
|
|
|
|
for (auto it = recvq.begin(); it != recvq.end();) {
|
|
|
|
|
std::shared_ptr<QueueMessage> recvMsg = (std::shared_ptr<QueueMessage>)*it;
|
|
|
|
|
|
|
|
|
|
if (recvMsg->referrer == referrer) {
|
|
|
|
|
it = recvq.erase(it);
|
|
|
|
|
recv_ctr--;
|
|
|
|
|
} else {
|
|
|
|
|
++it;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
ChannelImpl<T>::~ChannelImpl() {
|
|
|
|
|
Close();
|
|
|
|
|