@ -83,8 +83,8 @@ static void CheckTensorNANOrInf(const std::string& name,
if ( tensor . memory_size ( ) = = 0 ) {
return ;
}
if ( tensor . type ( ) . hash_code ( ) ! = typeid ( float ) . hash_code ( ) & &
tensor . type ( ) . hash_code ( ) ! = typeid ( double ) . hash_code ( ) ) {
if ( tensor . type ( ) . hash_code ( ) ! = typeid ( float ) . hash_code ( ) & & // NOLINT
tensor . type ( ) . hash_code ( ) ! = typeid ( double ) . hash_code ( ) ) { // NOLINT
return ;
}
PADDLE_ENFORCE ( ! framework : : TensorContainsInf ( tensor ) ,
@ -145,12 +145,13 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
// Return true if the block has feed operators and holder of matching info.
static bool has_feed_operators (
const BlockDesc & block ,
std : : map < std : : string , const LoDTensor * > & feed_targets ,
const std : : map < std : : string , const LoDTensor * > & feed_targets ,
const std : : string & feed_holder_name ) {
size_t feed_count = 0 ;
for ( auto * op : block . AllOps ( ) ) {
if ( op - > Type ( ) = = kFeedOpType ) {
feed_count + + ;
// The input variable's name of feed_op should be feed_holder_name.
PADDLE_ENFORCE_EQ ( op - > Input ( " X " ) [ 0 ] , feed_holder_name ,
" Input to feed op should be '%s' " , feed_holder_name ) ;
std : : string feed_target_name = op - > Output ( " Out " ) [ 0 ] ;
@ -166,7 +167,8 @@ static bool has_feed_operators(
feed_count , feed_targets . size ( ) ,
" The number of feed operators should match 'feed_targets' " ) ;
// When feed operator are present, so should be feed_holder
if ( ! feed_holder_name . empty ( ) ) {
// When feed operator are present, so should be feed_holder.
auto var = block . FindVar ( feed_holder_name ) ;
PADDLE_ENFORCE_NOT_NULL ( var , " Block should already have a '%s' variable " ,
feed_holder_name ) ;
@ -174,6 +176,7 @@ static bool has_feed_operators(
" '%s' variable should be 'FEED_MINIBATCH' type " ,
feed_holder_name ) ;
}
}
return feed_count > 0 ;
}
@ -185,12 +188,14 @@ static bool has_feed_operators(
// and fetch_holder_name. Raise exception when any mismatch is found.
// Return true if the block has fetch operators and holder of matching info.
static bool has_fetch_operators (
const BlockDesc & block , std : : map < std : : string , LoDTensor * > & fetch_targets ,
const BlockDesc & block ,
const std : : map < std : : string , LoDTensor * > & fetch_targets ,
const std : : string & fetch_holder_name ) {
size_t fetch_count = 0 ;
for ( auto * op : block . AllOps ( ) ) {
if ( op - > Type ( ) = = kFetchOpType ) {
fetch_count + + ;
// The output variable's name of fetch_op should be fetch_holder_name.
PADDLE_ENFORCE_EQ ( op - > Output ( " Out " ) [ 0 ] , fetch_holder_name ,
" Output of fetch op should be '%s' " , fetch_holder_name ) ;
std : : string fetch_target_name = op - > Input ( " X " ) [ 0 ] ;
@ -206,7 +211,8 @@ static bool has_fetch_operators(
fetch_count , fetch_targets . size ( ) ,
" The number of fetch operators should match 'fetch_targets' " ) ;
// When fetch operator are present, so should be fetch_holder
if ( ! fetch_holder_name . empty ( ) ) {
// When fetch operator are present, so should be fetch_holder.
auto var = block . FindVar ( fetch_holder_name ) ;
PADDLE_ENFORCE_NOT_NULL ( var , " Block should already have a '%s' variable " ,
fetch_holder_name ) ;
@ -214,6 +220,7 @@ static bool has_fetch_operators(
" '%s' variable should be 'FETCH_LIST' type " ,
fetch_holder_name ) ;
}
}
return fetch_count > 0 ;
}
@ -259,16 +266,6 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
}
}
// map the data of feed_targets to feed_holder
for ( auto * op : global_block - > AllOps ( ) ) {
if ( op - > Type ( ) = = kFeedOpType ) {
std : : string feed_target_name = op - > Output ( " Out " ) [ 0 ] ;
int idx = boost : : get < int > ( op - > GetAttr ( " col " ) ) ;
SetFeedVariable ( scope , * feed_targets [ feed_target_name ] , feed_holder_name ,
idx ) ;
}
}
if ( ! has_fetch_ops ) {
// create fetch_holder variable
auto * fetch_holder = global_block - > Var ( fetch_holder_name ) ;
@ -292,17 +289,9 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
}
}
Run ( * copy_program , scope , 0 , create_vars , create_vars ) ;
// obtain the data of fetch_targets from fetch_holder
for ( auto * op : global_block - > AllOps ( ) ) {
if ( op - > Type ( ) = = kFetchOpType ) {
std : : string fetch_target_name = op - > Input ( " X " ) [ 0 ] ;
int idx = boost : : get < int > ( op - > GetAttr ( " col " ) ) ;
* fetch_targets [ fetch_target_name ] =
GetFetchVariable ( * scope , fetch_holder_name , idx ) ;
}
}
auto ctx = Prepare ( * copy_program , 0 ) ;
RunPreparedContext ( ctx . get ( ) , scope , feed_targets , fetch_targets , create_vars ,
feed_holder_name , fetch_holder_name ) ;
}
std : : unique_ptr < ExecutorPrepareContext > Executor : : Prepare (
@ -370,5 +359,42 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
}
}
void Executor : : RunPreparedContext (
ExecutorPrepareContext * ctx , Scope * scope ,
std : : map < std : : string , const LoDTensor * > & feed_targets ,
std : : map < std : : string , LoDTensor * > & fetch_targets , bool create_vars ,
const std : : string & feed_holder_name , const std : : string & fetch_holder_name ) {
auto & global_block = ctx - > prog_ . Block ( ctx - > block_id_ ) ;
PADDLE_ENFORCE (
has_feed_operators ( global_block , feed_targets , feed_holder_name ) ,
" Program in ExecutorPrepareContext should has feed_ops. " ) ;
PADDLE_ENFORCE (
has_fetch_operators ( global_block , fetch_targets , fetch_holder_name ) ,
" Program in the prepared context should has fetch_ops. " ) ;
// map the data of feed_targets to feed_holder
for ( auto * op : global_block . AllOps ( ) ) {
if ( op - > Type ( ) = = kFeedOpType ) {
std : : string feed_target_name = op - > Output ( " Out " ) [ 0 ] ;
int idx = boost : : get < int > ( op - > GetAttr ( " col " ) ) ;
SetFeedVariable ( scope , * feed_targets [ feed_target_name ] , feed_holder_name ,
idx ) ;
}
}
RunPreparedContext ( ctx , scope , create_vars , create_vars ) ;
// obtain the data of fetch_targets from fetch_holder
for ( auto * op : global_block . AllOps ( ) ) {
if ( op - > Type ( ) = = kFetchOpType ) {
std : : string fetch_target_name = op - > Input ( " X " ) [ 0 ] ;
int idx = boost : : get < int > ( op - > GetAttr ( " col " ) ) ;
* fetch_targets [ fetch_target_name ] =
GetFetchVariable ( * scope , fetch_holder_name , idx ) ;
}
}
}
} // namespace framework
} // namespace paddle