@ -79,12 +79,16 @@ class TensorAddToFunctor : public boost::static_visitor<> {
} // namespace detail
void AddTo ( std : : shared_ptr < VarBase > src , std : : shared_ptr < VarBase > dst ,
platform : : Place place ) {
if ( ! dst - > IsInitialize ( ) ) {
VLOG ( 2 ) < < " im here1 " ;
platform : : Place place , GradientRef * grad_ref ) {
PADDLE_ENFORCE ( grad_ref - > find ( dst . get ( ) ) ! = grad_ref - > end ( ) ,
" gradient %s are not found in grad_ref " , dst - > Name ( ) ) ;
if ( ( * grad_ref ) [ dst . get ( ) ] . second ) {
PADDLE_ENFORCE ( src - > IsInitialize ( ) , " Using uninitialized VarBase " ) ;
dst - > var_ = std : : move ( src - > var_ ) ;
dst - > SetInitialize ( true ) ;
( * grad_ref ) [ dst . get ( ) ] . second = false ;
if ( ! dst - > IsInitialize ( ) ) {
dst - > SetInitialize ( true ) ;
}
return ;
} else {
framework : : Tensor * dst_tensor =
@ -118,7 +122,8 @@ void ZeroGrads(const std::shared_ptr<imperative::VarBase> vb,
}
void AddGradBySort ( BackwardSumMap * bck_map ,
std : : shared_ptr < imperative : : VarBase > target ) {
std : : shared_ptr < imperative : : VarBase > target ,
GradientRef * grad_ref ) {
PADDLE_ENFORCE ( bck_map - > find ( target . get ( ) ) ! = bck_map - > end ( ) ,
" Can't find %s in backward grad map " , target - > Name ( ) ) ;
std : : pair < platform : : Place ,
@ -133,7 +138,7 @@ void AddGradBySort(BackwardSumMap* bck_map,
VLOG ( 10 ) < < " add origin_grad: " < < target - > Name ( ) ;
VLOG ( 10 ) < < " added grad: " < < var_pair . second - > Name ( )
< < " trace id is: " < < var_pair . first ;
AddTo ( var_pair . second , target , current . first );
AddTo ( var_pair . second , target , current . first , grad_ref );
var_pair . second . reset ( ) ;
}
}
@ -148,7 +153,6 @@ class Autograd {
}
VLOG ( 2 ) < < " start autograd " ;
BackwardSumMap bck_map ;
GradientRef grad_ref ;
std : : deque < OpBase * > ready ;
ready . push_back ( var - > PreOp ( ) ) ;
@ -200,12 +204,14 @@ class Autograd {
while ( ! queue . empty ( ) ) {
OpBase * candidate = queue . front ( ) ;
queue . pop_front ( ) ;
if ( bck_stratedy . sorted_sum_gradient _) {
for ( const auto & map : candidate - > grad_output_vars_ ) {
for ( const auto & it : map ) {
for ( const auto & vb : it . second ) {
+ + ( * grad_ref ) [ vb . get ( ) ] ;
for ( const auto & map : candidate - > grad_output_vars _) {
for ( const auto & it : map ) {
for ( const auto & vb : it . second ) {
if ( bck_stratedy . sorted_sum_gradient_ ) {
+ + ( * grad_ref ) [ vb . get ( ) ] .first ;
}
// init the state of the grad_
( * grad_ref ) [ vb . get ( ) ] . second = true ;
}
}
}
@ -225,6 +231,8 @@ class Autograd {
}
return ret ;
}
GradientRef grad_ref ;
} ;
std : : unique_ptr < VarBase > VarBase : : NewVarBase ( const platform : : Place & dst_place ,
@ -382,21 +390,21 @@ std::vector<VarBasePtrMap> OpBase::ApplyGrad(
grad_ref - > find ( origin_outputs [ i ] . get ( ) ) ! = grad_ref - > end ( ) ,
" Can't find %s in grad_reference count map " ,
origin_outputs [ i ] - > Name ( ) ) ;
PADDLE_ENFORCE ( grad_ref - > at ( origin_outputs [ i ] . get ( ) ) > = 1 ,
PADDLE_ENFORCE ( grad_ref - > at ( origin_outputs [ i ] . get ( ) ) . first > = 1 ,
" Backward error when calculate grad reference " ) ;
if ( grad_ref - > at ( origin_outputs [ i ] . get ( ) ) > 1 ) {
if ( grad_ref - > at ( origin_outputs [ i ] . get ( ) ) . first > 1 ) {
VLOG ( 10 ) < < " remove ref for " < < origin_outputs [ i ] - > Name ( ) ;
grad_ref - > at ( origin_outputs [ i ] . get ( ) ) - - ;
grad_ref - > at ( origin_outputs [ i ] . get ( ) ) . first - - ;
} else {
VLOG ( 10 ) < < " Add grad for: " < < origin_outputs [ i ] - > Name ( ) ;
AddGradBySort ( bck_map , origin_outputs [ i ] );
grad_ref - > at ( origin_outputs [ i ] . get ( ) ) - - ;
AddGradBySort ( bck_map , origin_outputs [ i ] , grad_ref );
grad_ref - > at ( origin_outputs [ i ] . get ( ) ) . first - - ;
}
} else {
VLOG ( 10 ) < < " AddTo Called with orig_grad is: "
< < origin_outputs [ i ] - > name_ < < " Grad to be added is "
< < outputs [ i ] - > name_ ;
AddTo ( outputs [ i ] , origin_outputs [ i ] , place_ );
AddTo ( outputs [ i ] , origin_outputs [ i ] , place_ , grad_ref );
outputs [ i ] . reset ( ) ;
}
}