@@ -64,6 +64,11 @@ void SetForwardDataTypeOfGradVar<VarBase>(const std::shared_ptr<VarBase>& var) {
   }
 }
 
+extern const std::shared_ptr<VariableWrapper>& GetVariableWrapper(
+    const std::shared_ptr<paddle::imperative::VarBase>& var);
+extern const std::shared_ptr<VariableWrapper>& GetVariableWrapper(
+    const std::shared_ptr<VariableWrapper>& var);
+
 template <typename VarType>
 std::shared_ptr<NameVarMap<VarType>> PrepareData(
     const framework::OperatorWithKernel& op, const NameVarMap<VarType>& ins,
@@ -82,23 +87,50 @@ std::shared_ptr<NameVarMap<VarType>> PrepareData(
         } else {
           VLOG(3) << "Transform Variable " << var_base->Name() << " from "
                   << kernel_type_for_var << " to " << expected_kernel_key;
-          framework::Tensor out;
-          TransformData(expected_kernel_key, kernel_type_for_var, *tensor,
-                        &out);
-          if (NeedTransformDataType(kernel_type_for_var, expected_kernel_key)) {
-            // To avoid NameVarMap copy construction overhead in general
-            // scenarios, if inplace transformed, return original input directly
+
+          if (GetVariableWrapper(var_base)->hasCacheKey(expected_kernel_key)) {
+            VLOG(3) << "Hit variable_wrapper cache: key="
+                    << expected_kernel_key;
+            std::shared_ptr<VariableWrapper> cache_var =
+                GetVariableWrapper(var_base)->getCacheValue(
+                    expected_kernel_key);
             if (tmp_ins_ptr == nullptr) {
               tmp_ins_ptr = std::make_shared<NameVarMap<VarType>>(ins);
             }
+
+            const auto* tensor = GetTensorFromVar(cache_var->Var());
             auto tmp_var = std::make_shared<VarType>(var_base->Name());
             tmp_var->SetType(var_base->Type());
-            SetTensorToVariable(var_base->Var(), out, tmp_var->MutableVar());
+            SetTensorToVariable(cache_var->Var(), *tensor,
+                                tmp_var->MutableVar());
             (*tmp_ins_ptr)[name_pair.first][i] = tmp_var;
           } else {
-            // if dtype is same, transform inplace will not change the original
-            // value, transform inplace to avoid multiple copy
-            SetTensorToVariable(var_base->Var(), out, var_base->MutableVar());
+            framework::Tensor out;
+            TransformData(expected_kernel_key, kernel_type_for_var, *tensor,
+                          &out);
+            if (NeedTransformDataType(kernel_type_for_var,
+                                      expected_kernel_key)) {
+              // To avoid NameVarMap copy construction overhead in general
+              // scenarios, if inplace transformed, return original input
+              // directly
+              if (tmp_ins_ptr == nullptr) {
+                tmp_ins_ptr = std::make_shared<NameVarMap<VarType>>(ins);
+              }
+              auto tmp_var = std::make_shared<VarType>(var_base->Name());
+              tmp_var->SetType(var_base->Type());
+              SetTensorToVariable(var_base->Var(), out, tmp_var->MutableVar());
+              (*tmp_ins_ptr)[name_pair.first][i] = tmp_var;
+
+              GetVariableWrapper(var_base)->setCacheValue(
+                  expected_kernel_key, GetVariableWrapper(tmp_var));
+              VLOG(3) << "Set cache to variable_wrapper: key="
+                      << expected_kernel_key;
+            } else {
+              // if dtype is same, transform inplace will not change the
+              // original
+              // value, transform inplace to avoid multiple copy
+              SetTensorToVariable(var_base->Var(), out, var_base->MutableVar());
+            }
           }
         }
       }
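
Note on the cache accessors: hasCacheKey / getCacheValue / setCacheValue are declared on VariableWrapper outside this hunk. As a rough, hedged illustration of the pattern (not the actual Paddle declarations), the cache is essentially a per-variable map from the kernel type an input was transformed for to the already-transformed wrapper. KernelKey and Wrapper below are hypothetical stand-ins for framework::OpKernelType and imperative::VariableWrapper:

// Hypothetical stand-ins: KernelKey ~ framework::OpKernelType,
// Wrapper ~ imperative::VariableWrapper. Illustrative only.
#include <map>
#include <memory>
#include <string>
#include <tuple>

struct KernelKey {
  std::string place;  // e.g. "CPU" / "GPU"
  std::string dtype;  // e.g. "float16" / "float32"
  bool operator<(const KernelKey& o) const {
    return std::tie(place, dtype) < std::tie(o.place, o.dtype);
  }
};

struct Wrapper {
  bool hasCacheKey(const KernelKey& key) const {
    return cache_.count(key) > 0;
  }
  std::shared_ptr<Wrapper> getCacheValue(const KernelKey& key) const {
    return cache_.at(key);
  }
  void setCacheValue(const KernelKey& key, std::shared_ptr<Wrapper> value) {
    cache_[key] = value;
  }

 private:
  // key: kernel type the input was transformed for;
  // value: the already-transformed wrapper, reused on subsequent hits.
  std::map<KernelKey, std::shared_ptr<Wrapper>> cache_;
};

On a cache hit, PrepareData reuses the cached wrapper's tensor instead of calling TransformData again; on a miss, the freshly transformed variable is stored via setCacheValue so later calls with the same expected_kernel_key skip the copy.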