@ -361,7 +361,7 @@ static bool VarIsTensor(const Variable& var) {
return var . IsType < LoDTensor > ( ) | | var . IsType < SelectedRows > ( ) ;
}
const Tensor * Get TensorFromVar( const Variable & var ) {
const Tensor * Get LoD TensorOrSelectedRowsValue FromVar( const Variable & var ) {
if ( var . IsType < LoDTensor > ( ) ) {
return static_cast < const Tensor * > ( & ( var . Get < LoDTensor > ( ) ) ) ;
} else if ( var . IsType < SelectedRows > ( ) ) {
@ -372,7 +372,7 @@ const Tensor* GetTensorFromVar(const Variable& var) {
}
}
static Tensor * GetMutable TensorFromVar( Variable * var ) {
Tensor * GetMutable LoD TensorOrSelectedRowsValue FromVar( Variable * var ) {
if ( var - > IsType < LoDTensor > ( ) ) {
return var - > GetMutable < LoDTensor > ( ) ;
} else if ( var - > IsType < SelectedRows > ( ) ) {
@ -417,8 +417,7 @@ bool ExecutionContext::HasOutput(const std::string& name) const {
template < >
const Tensor * ExecutionContext : : Input < Tensor > ( const std : : string & name ) const {
auto * var = InputVar ( name ) ;
return var = = nullptr ? nullptr : GetTensorFromVar ( * var ) ;
return Input < LoDTensor > ( name ) ;
}
template < >
@ -428,17 +427,21 @@ const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
std : : vector < const Tensor * > res ;
res . reserve ( names . size ( ) ) ;
std : : transform ( names . begin ( ) , names . end ( ) , std : : back_inserter ( res ) ,
[ & ] ( const std : : string & sub_name ) {
[ & ] ( const std : : string & sub_name ) - > const Tensor * {
auto var = scope_ . FindVar ( sub_name ) ;
return var = = nullptr ? nullptr : GetTensorFromVar ( * var ) ;
if ( var = = nullptr ) return nullptr ;
PADDLE_ENFORCE (
var - > IsType < LoDTensor > ( ) ,
" %s should be LoDTensor, but the received type is %s " ,
sub_name , var - > Type ( ) . name ( ) ) ;
return & ( var - > Get < LoDTensor > ( ) ) ;
} ) ;
return res ;
}
template < >
Tensor * ExecutionContext : : Output < Tensor > ( const std : : string & name ) const {
auto var = OutputVar ( name ) ;
return var = = nullptr ? nullptr : GetMutableTensorFromVar ( var ) ;
return Output < LoDTensor > ( name ) ;
}
template < >
@ -448,10 +451,14 @@ std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
std : : vector < Tensor * > res ;
res . reserve ( names . size ( ) ) ;
std : : transform ( names . begin ( ) , names . end ( ) , std : : back_inserter ( res ) ,
[ & ] ( const std : : string & sub_name ) {
[ & ] ( const std : : string & sub_name ) - > Tensor * {
auto var = scope_ . FindVar ( sub_name ) ;
return var = = nullptr ? nullptr
: GetMutableTensorFromVar ( var ) ;
if ( var = = nullptr ) return nullptr ;
PADDLE_ENFORCE (
var - > IsType < LoDTensor > ( ) ,
" %s should be LoDTensor, but the received type is %s " ,
sub_name , var - > Type ( ) . name ( ) ) ;
return var - > GetMutable < LoDTensor > ( ) ;
} ) ;
return res ;
}
@ -771,11 +778,12 @@ void OperatorWithKernel::TransferInplaceVarsBack(
const Scope & transfer_scope ) const {
for ( auto & var_name : inplace_vars ) {
VLOG ( 3 ) < < " share inplace var " + var_name + " back to it's original scope " ;
auto * original_tensor = GetMutableTensorFromVar ( scope . FindVar ( var_name ) ) ;
auto * original_tensor =
GetMutableLoDTensorOrSelectedRowsValueFromVar ( scope . FindVar ( var_name ) ) ;
auto * var = transfer_scope . FindVar ( var_name ) ;
PADDLE_ENFORCE ( var ! = nullptr , " The var[%s] should not be nullptr " ,
var_name ) ;
auto * transformed_tensor = Get TensorFromVar( * var ) ;
auto * transformed_tensor = Get LoD TensorOrSelectedRowsValue FromVar( * var ) ;
original_tensor - > ShareDataWith ( * transformed_tensor ) ;
}
}
@ -792,7 +800,7 @@ Scope* OperatorWithKernel::TryTransferData(
continue ;
}
auto * tensor_in = Get TensorFromVar( * var ) ;
auto * tensor_in = Get LoD TensorOrSelectedRowsValue FromVar( * var ) ;
if ( ! tensor_in - > IsInitialized ( ) ) {
continue ;
}