@ -161,7 +161,7 @@ static void InitVarBaseFromNumpyWithArg(imperative::VarBase *self,
}
}
VLOG ( 5 ) < < " Init Tensor as: / name: " < < name
VLOG ( 5 ) < < " Init Tensor as: / name: " < < name
< < " / persistable: " < < persistable < < " / zero_copy: " < < zero_copy
< < " / persistable: " < < persistable < < " / zero_copy: " < < zero_copy
< < " / stop_gradient: " < < stop_gradient ;
< < " / stop_gradient: " < < stop_gradient < < " / at " < < place ;
new ( self ) imperative : : VarBase ( name ) ;
new ( self ) imperative : : VarBase ( name ) ;
self - > SetPersistable ( persistable ) ;
self - > SetPersistable ( persistable ) ;
auto * tensor = self - > MutableVar ( ) - > GetMutable < framework : : LoDTensor > ( ) ;
auto * tensor = self - > MutableVar ( ) - > GetMutable < framework : : LoDTensor > ( ) ;
@ -175,8 +175,8 @@ static void InitVarBaseFromNumpyWithArg(imperative::VarBase *self,
static void InitVarBaseFromNumpyWithArgDefault ( imperative : : VarBase * self ,
static void InitVarBaseFromNumpyWithArgDefault ( imperative : : VarBase * self ,
const py : : array & array ) {
const py : : array & array ) {
VLOG ( 4 ) < < " Init VarBase from numpy: " ;
auto place = imperative : : GetCurrentTracer ( ) - > ExpectedPlace ( ) ;
auto place = imperative : : GetCurrentTracer ( ) - > ExpectedPlace ( ) ;
VLOG ( 4 ) < < " Init VarBase from numpy at " < < place ;
InitTensorForVarBase ( self , array , place ) ;
InitTensorForVarBase ( self , array , place ) ;
}
}
@ -1206,15 +1206,44 @@ void BindImperative(py::module *m_ptr) {
if ( py : : isinstance < platform : : CUDAPlace > ( obj ) ) {
if ( py : : isinstance < platform : : CUDAPlace > ( obj ) ) {
auto p = obj . cast < platform : : CUDAPlace * > ( ) ;
auto p = obj . cast < platform : : CUDAPlace * > ( ) ;
self . SetExpectedPlace ( * p ) ;
self . SetExpectedPlace ( * p ) ;
// NOTE(zhiqiu): When switching cuda place, we need to set the
// cuda device id.
// Otherwise, some cuda API may be launched at other cuda place,
// which may cost hundreds of MB of GPU memory due to the cuda
// lib.
# ifdef PADDLE_WITH_CUDA
platform : : SetDeviceId ( p - > device ) ;
# endif
VLOG ( 4 ) < < " Tracer( " < < & self < < " ) "
< < " set expected place " < < * p ;
} else if ( py : : isinstance < platform : : XPUPlace > ( obj ) ) {
} else if ( py : : isinstance < platform : : XPUPlace > ( obj ) ) {
auto p = obj . cast < platform : : XPUPlace * > ( ) ;
auto p = obj . cast < platform : : XPUPlace * > ( ) ;
self . SetExpectedPlace ( * p ) ;
self . SetExpectedPlace ( * p ) ;
VLOG ( 4 ) < < " Tracer( " < < & self < < " ) "
< < " set expected place " < < * p ;
} else if ( py : : isinstance < platform : : CPUPlace > ( obj ) ) {
} else if ( py : : isinstance < platform : : CPUPlace > ( obj ) ) {
auto p = obj . cast < platform : : CPUPlace * > ( ) ;
auto p = obj . cast < platform : : CPUPlace * > ( ) ;
self . SetExpectedPlace ( * p ) ;
self . SetExpectedPlace ( * p ) ;
VLOG ( 4 ) < < " Tracer( " < < & self < < " ) "
< < " set expected place " < < * p ;
} else if ( py : : isinstance < platform : : CUDAPinnedPlace > ( obj ) ) {
} else if ( py : : isinstance < platform : : CUDAPinnedPlace > ( obj ) ) {
auto p = obj . cast < platform : : CUDAPinnedPlace * > ( ) ;
auto p = obj . cast < platform : : CUDAPinnedPlace * > ( ) ;
self . SetExpectedPlace ( * p ) ;
self . SetExpectedPlace ( * p ) ;
VLOG ( 4 ) < < " Tracer( " < < & self < < " ) "
< < " set expected place " < < * p ;
} else if ( py : : isinstance < platform : : Place > ( obj ) ) {
auto p = obj . cast < platform : : Place * > ( ) ;
self . SetExpectedPlace ( * p ) ;
if ( platform : : is_gpu_place ( * p ) ) {
// NOTE(zhiqu): same as obj is CUDAPlace.
# ifdef PADDLE_WITH_CUDA
platform : : SetDeviceId (
BOOST_GET_CONST ( platform : : CUDAPlace , * p ) . device ) ;
# endif
}
VLOG ( 4 ) < < " Tracer( " < < & self < < " ) "
< < " set expected place " < < * p ;
} else {
} else {
PADDLE_THROW ( platform : : errors : : InvalidArgument (
PADDLE_THROW ( platform : : errors : : InvalidArgument (
" Incompatible Place Type: supports XPUPlace, CUDAPlace, "
" Incompatible Place Type: supports XPUPlace, CUDAPlace, "