@ -563,6 +563,33 @@ void BindImperative(py::module *m_ptr) {
. def ( " __init__ " , & InitVarBaseFromNumpyWithArgDefault , py : : arg ( " value " ) )
. def ( " __init__ " , & InitVarBaseFromNumpyWithArgDefault , py : : arg ( " value " ) )
. def ( " __init__ " , & InitVarBaseFromTensorWithArgDefault , py : : arg ( " tensor " ) )
. def ( " __init__ " , & InitVarBaseFromTensorWithArgDefault , py : : arg ( " tensor " ) )
. def ( " __init__ " , & InitVarBaseFromNumpyWithKwargs )
. def ( " __init__ " , & InitVarBaseFromNumpyWithKwargs )
. def ( " __setitem__ " ,
[ ] ( std : : shared_ptr < imperative : : VarBase > & self , py : : handle _index ,
py : : object & value_obj ) {
auto self_tensor =
self - > MutableVar ( ) - > GetMutable < framework : : LoDTensor > ( ) ;
auto self_numpy = TensorToPyArray ( * self_tensor ) ;
if ( py : : isinstance < py : : array > ( value_obj ) | |
py : : isinstance < py : : int_ > ( value_obj ) | |
py : : isinstance < py : : float_ > ( value_obj ) ) {
auto value_numpy = value_obj ;
self_numpy [ _index ] = value_numpy ;
SetTensorFromPyArray ( self_tensor , self_numpy ,
self_tensor - > place ( ) , true ) ;
} else {
auto value =
value_obj . cast < std : : shared_ptr < imperative : : VarBase > > ( ) ;
auto value_tensor =
value - > MutableVar ( ) - > GetMutable < framework : : LoDTensor > ( ) ;
auto value_numpy = TensorToPyArray ( * value_tensor ) ;
self_numpy [ _index ] = value_numpy ;
SetTensorFromPyArray ( self_tensor , self_numpy ,
self_tensor - > place ( ) , true ) ;
}
} )
. def ( " __getitem__ " ,
. def ( " __getitem__ " ,
[ ] ( std : : shared_ptr < imperative : : VarBase > & self , py : : handle _index ) {
[ ] ( std : : shared_ptr < imperative : : VarBase > & self , py : : handle _index ) {
std : : vector < int > slice_axes , slice_starts , slice_ends ,
std : : vector < int > slice_axes , slice_starts , slice_ends ,
@ -797,7 +824,8 @@ void BindImperative(py::module *m_ptr) {
return framework : : vectorize < int > (
return framework : : vectorize < int > (
self . Var ( ) . Get < framework : : SelectedRows > ( ) . value ( ) . dims ( ) ) ;
self . Var ( ) . Get < framework : : SelectedRows > ( ) . value ( ) . dims ( ) ) ;
} else {
} else {
VLOG ( 2 ) < < " It is meaningless to get shape of variable type "
VLOG ( 2 ) < < " It is meaningless to get shape of "
" variable type "
< < GetTypeName ( self ) ;
< < GetTypeName ( self ) ;
return std : : vector < int > ( ) ;
return std : : vector < int > ( ) ;
}
}