@ -696,10 +696,9 @@ void BindImperative(py::module *m_ptr) {
x = linear ( data )
print ( x . numpy ( ) )
) DOC " )
. def (
" detach " ,
[ ] ( const imperative : : VarBase & self )
- > std : : shared_ptr < imperative : : VarBase > {
. def ( " detach " ,
[ ] ( const imperative : : VarBase
& self ) - > std : : shared_ptr < imperative : : VarBase > {
PADDLE_ENFORCE_EQ (
self . Var ( ) . IsInitialized ( ) , true ,
platform : : errors : : InvalidArgument (
@ -720,15 +719,41 @@ void BindImperative(py::module *m_ptr) {
detach_var - > SetType ( self . Type ( ) ) ;
detach_var - > SetDataType ( self . DataType ( ) ) ;
// NOTE(liym27):
// Call Variable::SharePlaceholderWith but not
// Tensor::ShareDataWith or Tensor::ShareBufferWith, because
// `detach_var` should share the same TensorInplaceVersion with
// `self`, and only SharePlaceholderWith can also share the same
// TensorInplaceVersion, which is used to check whether inplace
if ( self . Var ( ) . IsType < framework : : LoDTensor > ( ) ) {
const auto & origin_tensor =
self . Var ( ) . Get < framework : : LoDTensor > ( ) ;
PADDLE_ENFORCE_EQ (
origin_tensor . IsInitialized ( ) , true ,
platform : : errors : : InvalidArgument (
" Tensor %s has not been initialized! " , self . Name ( ) ) ) ;
auto * detach_tensor =
detach_var - > MutableVar ( ) - > GetMutable < framework : : LoDTensor > ( ) ;
detach_tensor - > ShareDataWith ( origin_tensor ) ;
// NOTE(liym27): Call ShareInplaceVersionCounterWith to share the
// same TensorInplaceVersion, which is used to check whether
// inplace
// operations are correct.
detach_var - > MutableVar ( ) - > SharePlaceholderWith ( self . Var ( ) ) ;
detach_tensor - > ShareInplaceVersionCounterWith ( origin_tensor ) ;
} else {
const auto & origin_selected_rows =
self . Var ( ) . Get < framework : : SelectedRows > ( ) ;
PADDLE_ENFORCE_EQ (
origin_selected_rows . value ( ) . IsInitialized ( ) , true ,
platform : : errors : : InvalidArgument (
" Tensor %s has not been initialized! " , self . Name ( ) ) ) ;
auto * detach_selected_rows =
detach_var - > MutableVar ( )
- > GetMutable < framework : : SelectedRows > ( ) ;
detach_selected_rows - > set_height ( origin_selected_rows . height ( ) ) ;
detach_selected_rows - > set_rows ( origin_selected_rows . rows ( ) ) ;
detach_selected_rows - > mutable_value ( ) - > ShareDataWith (
origin_selected_rows . value ( ) ) ;
detach_selected_rows - > mutable_value ( )
- > ShareInplaceVersionCounterWith (
origin_selected_rows . value ( ) ) ;
}
VLOG ( 3 ) < < " The detached Tensor( " < < detach_var - > Name ( )
< < " ) share data with " < < self . Name ( ) ;
return detach_var ;