@@ -583,26 +583,82 @@ void BindImperative(py::module *m_ptr) {
py::object &value_obj) {
  auto self_tensor =
      self->MutableVar()->GetMutable<framework::LoDTensor>();
  PyObject *index_ptr = !PyTuple_Check(_index.ptr())
                            ? PyTuple_Pack(1, _index.ptr())
                            : _index.ptr();
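  // NOTE: a single, non-tuple index such as `x[1]` is packed into a
  // 1-tuple here, so the parsing loop below can treat `x[1]` and
  // `x[1, 2:3]` uniformly.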
  // 1. Check arguments
  // 1.1 Check whether _index can be parsed.
  bool parse_index = true;
  const int size = PyTuple_GET_SIZE(index_ptr);
  for (int dim = 0; dim < size; ++dim) {
    PyObject *slice_item = PyTuple_GetItem(index_ptr, dim);
    if (!(PyCheckInteger(slice_item) || PySlice_Check(slice_item))) {
      parse_index = false;
      break;
    }
  }
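  // NOTE: only indices built from integers and slices pass this check;
  // other index kinds (e.g. a bool mask or a list of indices) take the
  // numpy fallback path below.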
  // 1.2 Check whether stride is 1.
  std::vector<int> axes, starts, ends, strides, decrease_axis,
      infer_flags;

  bool stride_is_1 = true;
  if (parse_index) {
    ParseIndexingSlice(self_tensor, index_ptr, &axes, &starts, &ends,
                       &strides, &decrease_axis, &infer_flags);
    stride_is_1 =
        std::all_of(strides.cbegin(), strides.cend(),
                    [](int64_t stride) { return stride == 1; });
  }
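  // NOTE: the fast path is only taken when every stride is 1, so a
  // strided assignment such as `x[::2] = v` still goes through the
  // numpy fallback; presumably the set_value op does not handle
  // non-unit strides here.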
  // 1.3 Check whether value obj is a tensor.
  bool value_is_tensor = true;
  if (py::isinstance<py::array>(value_obj) ||
      py::isinstance<py::int_>(value_obj) ||
      py::isinstance<py::float_>(value_obj)) {
    value_is_tensor = false;
  }
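  // NOTE: numpy arrays and plain Python scalars count as non-tensor
  // values, so only a VarBase value can reach the set_value fast path.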
  // 2. Call op set_value to speed up if the condition is met,
  //    otherwise fall back to TensorToPyArray.
  // TODO(liym27): Try not to call TensorToPyArray because it always
  // copies data to cpu place, which reduces performance.
  if (parse_index && stride_is_1 && value_is_tensor) {
    framework::AttributeMap attrs = {
        {"axes", axes}, {"starts", starts}, {"ends", ends}};

    imperative::NameVarBaseMap ins = {{"Input", {self}}};
    imperative::NameVarBaseMap outs = {{"Out", {self}}};
    auto value_tensor =
        value_obj.cast<std::shared_ptr<imperative::VarBase>>();
    ins.insert({"ValueTensor", {value_tensor}});
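    // NOTE: "Input" and "Out" both refer to `self`, so set_value updates
    // the tensor in place on its own device, avoiding the cpu round-trip
    // of the numpy fallback.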
    const auto &tracer = imperative::GetCurrentTracer();
    {
      // Release gil and do tracing
      py::gil_scoped_release release;
      tracer->TraceOp("set_value", ins, outs, std::move(attrs));
    }
  } else {
    auto self_numpy = TensorToPyArray(*self_tensor);

    if (value_is_tensor) {
      auto value =
          value_obj.cast<std::shared_ptr<imperative::VarBase>>();
      auto value_tensor =
          value->MutableVar()->GetMutable<framework::LoDTensor>();
      auto value_numpy = TensorToPyArray(*value_tensor);

      self_numpy[_index] = value_numpy;
      SetTensorFromPyArray(self_tensor, self_numpy,
                           self_tensor->place(), true);
    } else {
      auto value_numpy = value_obj;
      self_numpy[_index] = value_numpy;
      SetTensorFromPyArray(self_tensor, self_numpy,
                           self_tensor->place(), true);
    }
  }
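  // Illustrative dispatch, assuming `x` is a dygraph Tensor and `v` a
  // Tensor value (names are hypothetical):
  //   x[0, 2:4] = v   -> fast path: traced set_value op
  //   x[::2]    = v   -> fallback:  stride != 1
  //   x[0]      = 3.0 -> fallback:  value is a Python scalar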
// NOTE(liym27):
// Increase the version of VarBase self because __setitem__ is an