|
|
|
@ -34,7 +34,7 @@ class Channel {
|
|
|
|
|
public:
|
|
|
|
|
virtual bool CanSend() = 0;
|
|
|
|
|
virtual bool CanReceive() = 0;
|
|
|
|
|
virtual bool Send(T*) = 0;
|
|
|
|
|
virtual void Send(T*) = 0;
|
|
|
|
|
virtual bool Receive(T*) = 0;
|
|
|
|
|
virtual size_t Cap() = 0;
|
|
|
|
|
virtual void Lock() = 0;
|
|
|
|
@ -84,69 +84,81 @@ class ChannelHolder {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
bool Send(T* data) {
|
|
|
|
|
if (!IsInitialized()) return false;
|
|
|
|
|
PADDLE_ENFORCE_EQ(holder_->Type(), std::type_index(typeid(T)));
|
|
|
|
|
void Send(T* data) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(IsInitialized(), true,
|
|
|
|
|
"The Channel hasn't been initialized");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
holder_->Type(), std::type_index(typeid(T)),
|
|
|
|
|
"Channel type is not same as the type of the data being sent");
|
|
|
|
|
// Static cast should be safe because we have ensured that types are same
|
|
|
|
|
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
|
|
|
|
|
return channel != nullptr ? channel->Send(data) : false;
|
|
|
|
|
PADDLE_ENFORCE_EQ(channel != nullptr, true, "Channel should not be null.");
|
|
|
|
|
channel->Send(data);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
bool Receive(T* data) {
|
|
|
|
|
if (!IsInitialized()) return false;
|
|
|
|
|
PADDLE_ENFORCE_EQ(holder_->Type(), std::type_index(typeid(T)));
|
|
|
|
|
PADDLE_ENFORCE_EQ(IsInitialized(), true,
|
|
|
|
|
"The Channel hasn't been initialized");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
holder_->Type(), std::type_index(typeid(T)),
|
|
|
|
|
"Channel type is not same as the type of the data being sent");
|
|
|
|
|
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
|
|
|
|
|
return channel != nullptr ? channel->Receive(data) : false;
|
|
|
|
|
PADDLE_ENFORCE_EQ(channel != nullptr, true, "Channel should not be null.");
|
|
|
|
|
return channel->Receive(data);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool IsClosed() {
|
|
|
|
|
if (IsInitialized()) {
|
|
|
|
|
return holder_->IsClosed();
|
|
|
|
|
}
|
|
|
|
|
return false;
|
|
|
|
|
PADDLE_ENFORCE_EQ(IsInitialized(), true,
|
|
|
|
|
"The Channel hasn't been initialized");
|
|
|
|
|
return holder_->IsClosed();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool CanSend() {
|
|
|
|
|
if (IsInitialized()) {
|
|
|
|
|
return holder_->CanSend();
|
|
|
|
|
}
|
|
|
|
|
return false;
|
|
|
|
|
PADDLE_ENFORCE_EQ(IsInitialized(), true,
|
|
|
|
|
"The Channel hasn't been initialized");
|
|
|
|
|
return holder_->CanSend();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool CanReceive() {
|
|
|
|
|
if (IsInitialized()) {
|
|
|
|
|
return holder_->CanReceive();
|
|
|
|
|
}
|
|
|
|
|
return false;
|
|
|
|
|
PADDLE_ENFORCE_EQ(IsInitialized(), true,
|
|
|
|
|
"The Channel hasn't been initialized");
|
|
|
|
|
return holder_->CanReceive();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void close() {
|
|
|
|
|
if (IsInitialized()) holder_->Close();
|
|
|
|
|
PADDLE_ENFORCE_EQ(IsInitialized(), true,
|
|
|
|
|
"The Channel hasn't been initialized");
|
|
|
|
|
holder_->Close();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t Cap() {
|
|
|
|
|
if (IsInitialized()) return holder_->Cap();
|
|
|
|
|
return -1;
|
|
|
|
|
PADDLE_ENFORCE_EQ(IsInitialized(), true,
|
|
|
|
|
"The Channel hasn't been initialized");
|
|
|
|
|
return holder_->Cap();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Lock() {
|
|
|
|
|
if (IsInitialized()) holder_->Lock();
|
|
|
|
|
PADDLE_ENFORCE_EQ(IsInitialized(), true,
|
|
|
|
|
"The Channel hasn't been initialized");
|
|
|
|
|
holder_->Lock();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Unlock() {
|
|
|
|
|
if (IsInitialized()) holder_->Unlock();
|
|
|
|
|
PADDLE_ENFORCE_EQ(IsInitialized(), true,
|
|
|
|
|
"The Channel hasn't been initialized");
|
|
|
|
|
holder_->Unlock();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void AddToSendQ(const void* referrer, T* data,
|
|
|
|
|
std::shared_ptr<std::condition_variable_any> cond,
|
|
|
|
|
std::function<bool(ChannelAction)> cb) {
|
|
|
|
|
if (IsInitialized()) {
|
|
|
|
|
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
|
|
|
|
|
if (channel != nullptr) {
|
|
|
|
|
channel->AddToSendQ(referrer, data, cond, cb);
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE_EQ(IsInitialized(), true,
|
|
|
|
|
"The Channel hasn't been initialized");
|
|
|
|
|
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
|
|
|
|
|
if (channel != nullptr) {
|
|
|
|
|
channel->AddToSendQ(referrer, data, cond, cb);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -154,26 +166,31 @@ class ChannelHolder {
|
|
|
|
|
void AddToReceiveQ(const void* referrer, T* data,
|
|
|
|
|
std::shared_ptr<std::condition_variable_any> cond,
|
|
|
|
|
std::function<bool(ChannelAction)> cb) {
|
|
|
|
|
if (IsInitialized()) {
|
|
|
|
|
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
|
|
|
|
|
if (channel != nullptr) {
|
|
|
|
|
channel->AddToReceiveQ(referrer, data, cond, cb);
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE_EQ(IsInitialized(), true,
|
|
|
|
|
"The Channel hasn't been initialized");
|
|
|
|
|
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
|
|
|
|
|
if (channel != nullptr) {
|
|
|
|
|
channel->AddToReceiveQ(referrer, data, cond, cb);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void RemoveFromSendQ(const void* referrer) {
|
|
|
|
|
if (IsInitialized()) holder_->RemoveFromSendQ(referrer);
|
|
|
|
|
PADDLE_ENFORCE_EQ(IsInitialized(), true,
|
|
|
|
|
"The Channel hasn't been initialized");
|
|
|
|
|
holder_->RemoveFromSendQ(referrer);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void RemoveFromReceiveQ(const void* referrer) {
|
|
|
|
|
if (IsInitialized()) holder_->RemoveFromReceiveQ(referrer);
|
|
|
|
|
PADDLE_ENFORCE_EQ(IsInitialized(), true,
|
|
|
|
|
"The Channel hasn't been initialized");
|
|
|
|
|
holder_->RemoveFromReceiveQ(referrer);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
inline bool IsInitialized() const { return holder_ != nullptr; }
|
|
|
|
|
|
|
|
|
|
inline const std::type_index Type() {
|
|
|
|
|
PADDLE_ENFORCE_EQ(IsInitialized(), true);
|
|
|
|
|
PADDLE_ENFORCE_EQ(IsInitialized(), true,
|
|
|
|
|
"The Channel hasn't been initialized");
|
|
|
|
|
return holder_->Type();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|