@ -257,7 +257,8 @@ framework::VarDesc *OpTester::Var(const std::string &name) {
template < typename T >
void OpTester : : SetupTensor ( framework : : LoDTensor * tensor ,
const std : : vector < int64_t > & shape , T lower , T upper ,
const std : : string & initializer ) {
const std : : string & initializer ,
const std : : string & filename ) {
static unsigned int seed = 100 ;
std : : mt19937 rng ( seed + + ) ;
std : : uniform_real_distribution < double > uniform_dist ( 0 , 1 ) ;
@ -280,12 +281,20 @@ void OpTester::SetupTensor(framework::LoDTensor *tensor,
}
} else if ( initializer = = " natural " ) {
for ( int i = 0 ; i < cpu_tensor . numel ( ) ; + + i ) {
cpu_ptr [ i ] = lower + i ;
cpu_ptr [ i ] = static_cast < T > ( lower + i ) ;
}
} else if ( initializer = = " zeros " ) {
for ( int i = 0 ; i < cpu_tensor . numel ( ) ; + + i ) {
cpu_ptr [ i ] = 0 ;
cpu_ptr [ i ] = static_cast < T > ( 0 ) ;
}
} else if ( initializer = = " file " ) {
std : : ifstream is ( filename ) ;
for ( size_t i = 0 ; i < cpu_tensor . numel ( ) ; + + i ) {
T value ;
is > > value ;
cpu_ptr [ i ] = static_cast < T > ( value ) ;
}
is . close ( ) ;
} else {
PADDLE_THROW ( " Unsupported initializer %s. " , initializer . c_str ( ) ) ;
}
@ -325,15 +334,19 @@ void OpTester::CreateVariables(framework::Scope *scope) {
auto * tensor = var - > GetMutable < framework : : LoDTensor > ( ) ;
const auto & data_type = var_desc - > GetDataType ( ) ;
if ( data_type = = framework : : proto : : VarType : : INT32 ) {
SetupTensor < int > ( tensor , shape , 0 , 1 , item . second . initializer ) ;
SetupTensor < int > ( tensor , shape , 0 , 1 , item . second . initializer ,
item . second . filename ) ;
} else if ( data_type = = framework : : proto : : VarType : : INT64 ) {
SetupTensor < int64_t > ( tensor , shape , 0 , 1 , item . second . initializer ) ;
SetupTensor < int64_t > ( tensor , shape , 0 , 1 , item . second . initializer ,
item . second . filename ) ;
} else if ( data_type = = framework : : proto : : VarType : : FP32 ) {
SetupTensor < float > ( tensor , shape , static_cast < float > ( 0.0 ) ,
static_cast < float > ( 1.0 ) , item . second . initializer ) ;
static_cast < float > ( 1.0 ) , item . second . initializer ,
item . second . filename ) ;
} else if ( data_type = = framework : : proto : : VarType : : FP64 ) {
SetupTensor < double > ( tensor , shape , static_cast < double > ( 0.0 ) ,
static_cast < double > ( 1.0 ) , item . second . initializer ) ;
static_cast < double > ( 1.0 ) , item . second . initializer ,
item . second . filename ) ;
} else {
PADDLE_THROW ( " Unsupported dtype %d. " , data_type ) ;
}