@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License . */
# pragma once
# include <atomic>
# include <condition_variable>
# include <deque>
# include <mutex>
@ -42,7 +43,7 @@ class Buffered : public paddle::framework::Channel<T> {
std : : condition_variable empty_cond_var_ ;
std : : condition_variable full_cond_var_ ;
std : : deque < T > channel_ ;
bool closed_ ;
std : : atomic < bool > closed_ { false } ;
Buffered ( size_t cap ) : cap_ ( cap ) , closed_ ( false ) {
PADDLE_ENFORCE_GT ( cap , 0 ) ;
@ -53,10 +54,13 @@ class Buffered : public paddle::framework::Channel<T> {
template < typename T >
bool Buffered < T > : : Send ( T * item ) {
bool ret = false ;
if ( closed_ ) {
return ret ;
}
std : : unique_lock < std : : mutex > lock ( mu_ ) ;
full_cond_var_ . wait ( lock ,
[ this ] ( ) { return channel_ . size ( ) < cap_ | | closed_ ; } ) ;
bool ret = false ;
if ( ! closed_ ) {
channel_ . push_back ( std : : move ( * item ) ) ;
lock . unlock ( ) ;
@ -82,6 +86,9 @@ bool Buffered<T>::Receive(T* item) {
template < typename T >
void Buffered < T > : : Close ( ) {
if ( closed_ ) {
return ;
}
std : : unique_lock < std : : mutex > lock ( mu_ ) ;
closed_ = true ;
NotifyAllParticipants ( & lock ) ;