@ -24,6 +24,27 @@ limitations under the License. */
namespace paddle {
namespace paddle {
namespace unittest {
static std : : unique_ptr < std : : function < void ( size_t /*poolActualSize */ ) > >
OnPoolFilled ;
namespace pydp2 {
void setOnPoolFilledHook ( const std : : function < void ( size_t ) > & callback ) {
OnPoolFilled . reset ( new std : : function < void ( size_t ) > ( ) ) ;
* OnPoolFilled = callback ;
}
void clearOnPoolFilledHook ( ) {
OnPoolFilled . reset ( ) ;
}
} // namespace pydp2
} // namespace unittest
/**
/**
* Slot type
* Slot type
*/
*/
@ -179,6 +200,7 @@ public:
* Ctor
* Ctor
*/
*/
PyDataProvider2 ( const DataConfig & config ,
PyDataProvider2 ( const DataConfig & config ,
const ModelConfig & modelConfig ,
bool useGpu )
bool useGpu )
: DataProvider ( config , useGpu ) , callingContextCreated_ ( 2 ) {
: DataProvider ( config , useGpu ) , callingContextCreated_ ( 2 ) {
auto & args = config . load_data_args ( ) ;
auto & args = config . load_data_args ( ) ;
@ -192,6 +214,12 @@ public:
py : : DictHelper kwargsDict ( kwargs ) ;
py : : DictHelper kwargsDict ( kwargs ) ;
kwargsDict . setBool ( " is_train " , ! config . for_test ( ) ) ;
kwargsDict . setBool ( " is_train " , ! config . for_test ( ) ) ;
std : : vector < std : : string > inputs ;
inputs . reserve ( modelConfig . input_layer_names ( ) . size ( ) ) ;
std : : copy ( modelConfig . input_layer_names ( ) . begin ( ) ,
modelConfig . input_layer_names ( ) . end ( ) ,
std : : back_inserter ( inputs ) ) ;
kwargsDict . setStringList ( " input_order " , inputs ) ;
// kwargs is keyword arguemts to create object.
// kwargs is keyword arguemts to create object.
this - > createPyDataObj ( config . load_data_module ( ) ,
this - > createPyDataObj ( config . load_data_module ( ) ,
@ -199,7 +227,7 @@ public:
config . files ( ) ,
config . files ( ) ,
std : : move ( kwargs ) ) ;
std : : move ( kwargs ) ) ;
DBG < < " Instance " < < instance_ . get ( ) < < " loaded. " ;
DBG < < " Instance " < < instance_ . get ( ) < < " loaded. " ;
this - > readPyFields ( ) ;
this - > readPyFields ( config . for_test ( ) ) ;
DBG < < " Py Field Done " ;
DBG < < " Py Field Done " ;
}
}
@ -253,14 +281,28 @@ private:
CHECK_PY ( instance_ ) < < " Cannot Create instance " ;
CHECK_PY ( instance_ ) < < " Cannot Create instance " ;
}
}
void readPyFields ( ) {
void readPyFields ( bool testing ) {
py : : ObjectHelper self ( this - > instance_ ) ;
py : : ObjectHelper self ( this - > instance_ ) ;
this - > skipShuffle_ = ! self . getBoolAttr ( " should_shuffle " ) ;
bool ok ;
bool ok ;
this - > skipShuffle_ = ! self . getBoolAttr ( " should_shuffle " ,
& ok /*isBoolType*/ ) ;
if ( ! ok ) {
this - > skipShuffle_ = testing ; // shuffle when is training, skip shuffle
// when is testing.
}
DBG < < " Provider Skip Shuffle " < < this - > skipShuffle_ ;
this - > poolSize_ = self . getIntAttr < size_t > ( " pool_size " , & ok ) ;
this - > poolSize_ = self . getIntAttr < size_t > ( " pool_size " , & ok ) ;
if ( ! ok ) {
if ( ! ok ) {
this - > poolSize_ = - 1UL ;
this - > poolSize_ = - 1UL ;
}
}
this - > minPoolSize_ = self . getIntAttr < size_t > ( " min_pool_size " , & ok ) ;
if ( ! ok ) {
this - > minPoolSize_ = - 1UL ;
}
this - > minPoolSize_ = std : : min ( this - > poolSize_ , this - > minPoolSize_ ) ;
this - > canOverBatchSize_ = self . getBoolAttr ( " can_over_batch_size " ) ;
this - > canOverBatchSize_ = self . getBoolAttr ( " can_over_batch_size " ) ;
calcBatchSize_ . reset ( self . getAttr ( " calc_batch_size " ) ) ;
calcBatchSize_ . reset ( self . getAttr ( " calc_batch_size " ) ) ;
@ -307,7 +349,6 @@ private:
}
}
void loadThread ( ) {
void loadThread ( ) {
callingContexts_ . reserve ( fileLists_ . size ( ) ) ;
DBG < < " Creating context " ;
DBG < < " Creating context " ;
for ( auto & filename : fileLists_ ) {
for ( auto & filename : fileLists_ ) {
PyGuard g ;
PyGuard g ;
@ -332,7 +373,14 @@ private:
bool atEnd ;
bool atEnd ;
data = py : : iterNext ( callingContexts_ [ cid ] , & atEnd ) ;
data = py : : iterNext ( callingContexts_ [ cid ] , & atEnd ) ;
if ( atEnd | | data = = nullptr ) {
if ( atEnd | | data = = nullptr ) {
callingContexts_ . erase ( callingContexts_ . begin ( ) + cid ) ;
if ( cid ! = 0 ) {
std : : swap ( callingContexts_ [ cid ] , callingContexts_ [ 0 ] ) ;
cid = 0 ;
}
{
PyGuard g ;
callingContexts_ . pop_front ( ) ;
}
this - > pullCV_ . notify_all ( ) ;
this - > pullCV_ . notify_all ( ) ;
continue ;
continue ;
}
}
@ -354,11 +402,7 @@ private:
if ( this - > loadThread_ ) { // wait poolActualSize < poolSize;
if ( this - > loadThread_ ) { // wait poolActualSize < poolSize;
std : : unique_lock < std : : mutex > l ( mtx_ ) ;
std : : unique_lock < std : : mutex > l ( mtx_ ) ;
pushCV_ . wait ( l , [ this , additionalBatchSize ] {
pushCV_ . wait ( l , [ this , additionalBatchSize ] {
if ( this - > canOverBatchSize_ ) {
return this - > poolActualSize_ < poolSize_ ;
return this - > poolActualSize_ < poolSize_ ;
} else {
return this - > poolActualSize_ + additionalBatchSize < poolSize_ ;
}
} ) ;
} ) ;
}
}
@ -402,7 +446,7 @@ private:
private :
private :
std : : unique_ptr < std : : thread > loadThread_ ;
std : : unique_ptr < std : : thread > loadThread_ ;
std : : atomic < bool > exit_ ;
std : : atomic < bool > exit_ ;
std : : vector < PyObjectPtr > callingContexts_ ;
std : : deque < PyObjectPtr > callingContexts_ ;
std : : deque < PyObjectPtr > dataPool_ ;
std : : deque < PyObjectPtr > dataPool_ ;
size_t poolActualSize_ ;
size_t poolActualSize_ ;
std : : condition_variable pushCV_ ;
std : : condition_variable pushCV_ ;
@ -413,6 +457,7 @@ private:
PyObjectPtr instance_ ;
PyObjectPtr instance_ ;
size_t poolSize_ ;
size_t poolSize_ ;
size_t minPoolSize_ ;
bool canOverBatchSize_ ;
bool canOverBatchSize_ ;
PyObjectPtr calcBatchSize_ ;
PyObjectPtr calcBatchSize_ ;
PyObjectPtr generator_ ;
PyObjectPtr generator_ ;
@ -478,8 +523,13 @@ public:
// data pool ready.
// data pool ready.
std : : unique_lock < std : : mutex > l ( mtx_ ) ;
std : : unique_lock < std : : mutex > l ( mtx_ ) ;
pullCV_ . wait ( l , [ this , & size ] {
pullCV_ . wait ( l , [ this , & size ] {
return this - > poolActualSize_ > = size | | callingContexts_ . empty ( ) ;
return this - > poolActualSize_ > = std : : max ( size , this - > minPoolSize_ )
| | callingContexts_ . empty ( ) ;
} ) ;
} ) ;
if ( unittest : : OnPoolFilled ) {
( * unittest : : OnPoolFilled ) ( this - > poolActualSize_ ) ;
}
}
}
std : : deque < PyObjectPtr > data ;
std : : deque < PyObjectPtr > data ;
size_t bsize = 0 ;
size_t bsize = 0 ;
@ -495,7 +545,8 @@ public:
std : : deque < PyObjectPtr > & pool = * poolPtr ;
std : : deque < PyObjectPtr > & pool = * poolPtr ;
while ( bsize < size & & ! pool . empty ( ) ) {
while ( bsize < size & & ! pool . empty ( ) ) {
{ // move data from pool to data
{
// move data from pool to data
std : : lock_guard < std : : mutex > guard ( mtx_ ) ;
std : : lock_guard < std : : mutex > guard ( mtx_ ) ;
if ( skipShuffle_ ) {
if ( skipShuffle_ ) {
size_t i = 0 ;
size_t i = 0 ;
@ -505,14 +556,13 @@ public:
} else { // when shuffle, use swap to drop only last pool element.
} else { // when shuffle, use swap to drop only last pool element.
size_t i = ThreadLocalRand : : rand ( ) % pool . size ( ) ;
size_t i = ThreadLocalRand : : rand ( ) % pool . size ( ) ;
CHECK ( pool [ i ] ! = nullptr ) ;
CHECK ( pool [ i ] ! = nullptr ) ;
if ( i ! = pool . size ( ) - 1 ) {
if ( i ! = 0 ) {
std : : swap ( pool [ i ] , pool . back ( ) ) ;
std : : swap ( pool [ i ] , pool . front ( ) ) ;
}
data . emplace_back ( std : : move ( pool . back ( ) ) ) ;
pool . pop_back ( ) ;
}
}
data . emplace_back ( std : : move ( pool . front ( ) ) ) ;
pool . pop_front ( ) ;
}
}
{
if ( calcBatchSize_ ) { // custom calc batch size.
if ( calcBatchSize_ ) { // custom calc batch size.
PyGuard guard ;
PyGuard guard ;
Py_INCREF ( data . back ( ) . get ( ) ) ;
Py_INCREF ( data . back ( ) . get ( ) ) ;
@ -521,8 +571,17 @@ public:
calcBatchSize . getArgs ( ) . set ( 0 , data . back ( ) ) ;
calcBatchSize . getArgs ( ) . set ( 0 , data . back ( ) ) ;
PyObjectPtr customBatchSize ( calcBatchSize ( ) ) ;
PyObjectPtr customBatchSize ( calcBatchSize ( ) ) ;
bool ok ;
bool ok ;
bsize + = py : : castInt < size_t > ( customBatchSize . get ( ) , & ok ) ;
size_t tmp = py : : castInt < size_t > ( customBatchSize . get ( ) , & ok ) ;
CHECK ( ok ) < < " calc_batch_size must return int " ;
CHECK ( ok ) < < " calc_batch_size must return int " ;
if ( bsize + tmp > size & & ! canOverBatchSize_ ) {
// Put data back.
pool . push_front ( std : : move ( data . back ( ) ) ) ;
data . pop_back ( ) ;
break ;
} else {
bsize + = tmp ;
}
} else {
} else {
bsize + = 1 ;
bsize + = 1 ;
}
}
@ -598,7 +657,6 @@ public:
} else {
} else {
* batch = cpuBatch ;
* batch = cpuBatch ;
}
}
return bsize ;
return bsize ;
}
}
} ;
} ;
@ -606,7 +664,8 @@ public:
std : : unordered_set < uintptr_t > PyDataProvider2 : : gModuleClsPtrs_ ;
std : : unordered_set < uintptr_t > PyDataProvider2 : : gModuleClsPtrs_ ;
PyObjectPtr PyDataProvider2 : : zeroTuple_ ( PyTuple_New ( 0 ) ) ;
PyObjectPtr PyDataProvider2 : : zeroTuple_ ( PyTuple_New ( 0 ) ) ;
REGISTER_DATA_PROVIDER ( py2 , PyDataProvider2 ) ;
REGISTER_DATA_PROVIDER_EX ( py2 , PyDataProvider2 ) ;
/**
/**
* Scanner for dense slot .
* Scanner for dense slot .