|
|
|
@ -15,7 +15,7 @@ limitations under the License. */
|
|
|
|
|
#pragma once
|
|
|
|
|
#include <stddef.h> // for size_t
|
|
|
|
|
#include <atomic>
|
|
|
|
|
#include <condition_variable>
|
|
|
|
|
#include <condition_variable> // NOLINT
|
|
|
|
|
#include <deque>
|
|
|
|
|
#include "paddle/fluid/framework/channel.h"
|
|
|
|
|
#include "paddle/fluid/platform/enforce.h"
|
|
|
|
@ -38,7 +38,7 @@ class ChannelImpl : public paddle::framework::Channel<T> {
|
|
|
|
|
virtual void Unlock();
|
|
|
|
|
virtual bool IsClosed();
|
|
|
|
|
virtual void Close();
|
|
|
|
|
ChannelImpl(size_t);
|
|
|
|
|
explicit ChannelImpl(size_t);
|
|
|
|
|
virtual ~ChannelImpl();
|
|
|
|
|
|
|
|
|
|
virtual void AddToSendQ(const void *referrer, T *data,
|
|
|
|
@ -60,7 +60,7 @@ class ChannelImpl : public paddle::framework::Channel<T> {
|
|
|
|
|
const void *referrer; // TODO(thuan): figure out better way to do this
|
|
|
|
|
std::function<bool(ChannelAction)> callback;
|
|
|
|
|
|
|
|
|
|
QueueMessage(T *item)
|
|
|
|
|
explicit QueueMessage(T *item)
|
|
|
|
|
: data(item), cond(std::make_shared<std::condition_variable_any>()) {}
|
|
|
|
|
|
|
|
|
|
QueueMessage(T *item, std::shared_ptr<std::condition_variable_any> cond)
|
|
|
|
@ -88,15 +88,15 @@ class ChannelImpl : public paddle::framework::Channel<T> {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<QueueMessage> get_first_message(
|
|
|
|
|
std::deque<std::shared_ptr<QueueMessage>> &queue, ChannelAction action) {
|
|
|
|
|
while (!queue.empty()) {
|
|
|
|
|
std::deque<std::shared_ptr<QueueMessage>> *queue, ChannelAction action) {
|
|
|
|
|
while (!queue->empty()) {
|
|
|
|
|
// Check whether this message was added by Select
|
|
|
|
|
// If this was added by Select then execute the callback
|
|
|
|
|
// to check if you can execute this message. The callback
|
|
|
|
|
// can return false if some other case was executed in Select.
|
|
|
|
|
// In that case just discard this QueueMessage and process next.
|
|
|
|
|
std::shared_ptr<QueueMessage> m = queue.front();
|
|
|
|
|
queue.pop_front();
|
|
|
|
|
std::shared_ptr<QueueMessage> m = queue->front();
|
|
|
|
|
queue->pop_front();
|
|
|
|
|
if (m->callback == nullptr || m->callback(action)) return m;
|
|
|
|
|
}
|
|
|
|
|
return nullptr;
|
|
|
|
@ -147,7 +147,7 @@ void ChannelImpl<T>::Send(T *item) {
|
|
|
|
|
// to send to the receiver, bypassing the channel buffer if any
|
|
|
|
|
if (!recvq.empty()) {
|
|
|
|
|
std::shared_ptr<QueueMessage> m =
|
|
|
|
|
get_first_message(recvq, ChannelAction::SEND);
|
|
|
|
|
get_first_message(&recvq, ChannelAction::SEND);
|
|
|
|
|
|
|
|
|
|
if (m != nullptr) {
|
|
|
|
|
*(m->data) = std::move(*item);
|
|
|
|
@ -198,7 +198,7 @@ bool ChannelImpl<T>::Receive(T *item) {
|
|
|
|
|
// buffer and move front of send queue to the buffer
|
|
|
|
|
if (!sendq.empty()) {
|
|
|
|
|
std::shared_ptr<QueueMessage> m =
|
|
|
|
|
get_first_message(sendq, ChannelAction::RECEIVE);
|
|
|
|
|
get_first_message(&sendq, ChannelAction::RECEIVE);
|
|
|
|
|
if (buf_.size() > 0) {
|
|
|
|
|
// Case 1 : Channel is Buffered
|
|
|
|
|
// Do Data transfer from front of buffer
|
|
|
|
@ -219,8 +219,9 @@ bool ChannelImpl<T>::Receive(T *item) {
|
|
|
|
|
if (m != nullptr) {
|
|
|
|
|
*item = std::move(*(m->data));
|
|
|
|
|
m->Notify();
|
|
|
|
|
} else
|
|
|
|
|
} else {
|
|
|
|
|
return recv_return(Receive(item));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return recv_return(true);
|
|
|
|
|
}
|
|
|
|
|