@ -106,6 +106,7 @@ void PrivateQueueDataFeed<T>::SetQueueSize(int queue_size) {
PADDLE_ENFORCE ( queue_size > 0 , " Illegal queue size: %d. " , queue_size ) ;
PADDLE_ENFORCE ( queue_size > 0 , " Illegal queue size: %d. " , queue_size ) ;
queue_size_ = queue_size ;
queue_size_ = queue_size ;
queue_ = paddle : : framework : : MakeChannel < T > ( ) ;
queue_ = paddle : : framework : : MakeChannel < T > ( ) ;
queue_ - > SetCapacity ( queue_size ) ;
}
}
template < typename T >
template < typename T >
@ -301,7 +302,8 @@ void MultiSlotDataFeed::Init(
paddle : : framework : : MultiSlotDesc multi_slot_desc =
paddle : : framework : : MultiSlotDesc multi_slot_desc =
data_feed_desc . multi_slot_desc ( ) ;
data_feed_desc . multi_slot_desc ( ) ;
SetBatchSize ( data_feed_desc . batch_size ( ) ) ;
SetBatchSize ( data_feed_desc . batch_size ( ) ) ;
SetQueueSize ( data_feed_desc . batch_size ( ) ) ;
// temporarily set queue size = batch size * 100
SetQueueSize ( data_feed_desc . batch_size ( ) * 100 ) ;
size_t all_slot_num = multi_slot_desc . slots_size ( ) ;
size_t all_slot_num = multi_slot_desc . slots_size ( ) ;
all_slots_ . resize ( all_slot_num ) ;
all_slots_ . resize ( all_slot_num ) ;
all_slots_type_ . resize ( all_slot_num ) ;
all_slots_type_ . resize ( all_slot_num ) ;