@ -20,8 +20,11 @@ namespace reader {
class BatchReader : public framework : : DecoratedReader {
public :
BatchReader ( const std : : shared_ptr < ReaderBase > & reader , int batch_size )
: DecoratedReader ( reader ) , batch_size_ ( batch_size ) {
BatchReader ( const std : : shared_ptr < ReaderBase > & reader , int batch_size ,
bool discard_leftover )
: DecoratedReader ( reader ) ,
batch_size_ ( batch_size ) ,
discard_leftover_ ( discard_leftover ) {
buffer_ . reserve ( batch_size_ ) ;
Start ( ) ;
}
@ -30,6 +33,7 @@ class BatchReader : public framework::DecoratedReader {
private :
int batch_size_ ;
bool discard_leftover_ ;
std : : vector < std : : vector < framework : : LoDTensor > > buffer_ ;
} ;
@ -47,8 +51,8 @@ class CreateBatchReaderOp : public framework::OperatorBase {
}
const auto & underlying_reader = scope . FindVar ( Input ( " UnderlyingReader " ) )
- > Get < framework : : ReaderHolder > ( ) ;
out - > Reset (
new BatchReader ( underlying_reader . Get ( ) , Attr < int > ( " batch_size " ) ) ) ;
out - > Reset ( new BatchReader ( underlying_reader . Get ( ) , Attr < int > ( " batch_size " ) ,
Attr < bool > ( " discard_leftover " ) ) ) ;
}
} ;
@ -58,6 +62,10 @@ class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase {
AddAttr < int > ( " batch_size " ,
" How many instances the batch reader yields each time. " )
. GreaterThan ( 0 ) ;
AddAttr < bool > ( " discard_leftover " ,
" If true, the leftover instances that are not enough for a "
" new batch will be discarded. " )
. SetDefault ( true ) ;
AddComment ( R " DOC(
CreateBatchReader Operator
@ -78,6 +86,9 @@ void BatchReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
break ;
}
}
if ( discard_leftover_ & & buffer_ . size ( ) < batch_size_ ) {
buffer_ . clear ( ) ;
}
// Concat instances
out - > clear ( ) ;
if ( buffer_ . empty ( ) ) {