@@ -38,7 +38,20 @@ namespace imperative {
void BasicEngine::Init(VarBase* var, bool retain_graph) {
  retain_graph_ = retain_graph;
  init_node_ = var->GradVarBase()->GradNode();
  var->GradVarBase()->ClearGradNode();
  PADDLE_ENFORCE_EQ(var->GradVarBase()->GraphIsFreed(), false,
                    platform::errors::Unavailable(
                        "%s trying to backward through the same graph a second "
                        "time, but this graph has already been freed. Please "
                        "specify Tensor.backward(retain_graph=True) when "
                        "calling backward the first time.",
                        var->Name()));
  if (!retain_graph) {
    VLOG(5) << "Clear the auto-grad graph from grad var " << var->Name()
            << " because of retain_graph=False when calling backward";
    var->GradVarBase()->SetGraphIsFreed(true);
    var->GradVarBase()->ClearGradNode();
  }
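  // Editorial sketch (not part of the change): how the guard above surfaces.
  // Assuming `loss` is a VarBase* with a traced forward graph, calling
  // backward twice without retain_graph now fails fast:
  //   BasicEngine engine;
  //   engine.Init(loss, /*retain_graph=*/false);  // marks the graph as freed
  //   engine.Execute();
  //   engine.Init(loss, /*retain_graph=*/false);  // GraphIsFreed() == true -> error above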
  if (init_node_ == nullptr || var->OverridedStopGradient()) {
    VLOG(3) << "Skip auto grad since there is no grad op for var or loss is "
@@ -47,7 +60,7 @@ void BasicEngine::Init(VarBase* var, bool retain_graph) {
    return;
  }
  VLOG(3) << "start backward";
  VLOG(3) << "Init first node of backward";
  PADDLE_ENFORCE_EQ(
      var->HasGradVar(), true,
@@ -114,6 +127,10 @@ void BasicEngine::PrepareGradAccumulators(const OpBase& op) {
      accumulator->IncreaseRefCnt();
      VLOG(3) << "Prepare to accumulate variable grad " << var->Name() << "("
              << var.get() << ") with reference count "
              << accumulator->RefCnt();
      if (var->HasLeafHooks()) {
        VLOG(3) << "Grad variable wrapper (" << var->Name()
                << ") has leaf grad hooks.";
@@ -123,10 +140,6 @@ void BasicEngine::PrepareGradAccumulators(const OpBase& op) {
                              "GradientAccumulator."));
        accumulator->SetPostHooks(var->GetLeafHooks());
}
      VLOG(3) << "Prepare to accumulate variable grad " << var->Name() << "("
              << var.get() << ") with reference count "
              << accumulator->RefCnt();
}
}
}
@@ -190,13 +203,14 @@ void BasicEngine::Execute() {
    // CheckBackWardInput
    CheckBackwardInputs(cur_op);
// Step 1: Run Backward
// Step 1: Run Backward OP
    auto& bwd_ins = cur_op.GetInsMap();
    auto& bwd_outs = cur_op.GetOutsMap();
    NameVarMap<VariableWrapper> tmp_outs(bwd_outs);
    // 1. construct the output map 2. replace the element in the map
    // A var may correspond to several grad vars in one op
    // 1. construct the temporary output map, to avoid disrupting the graph
    // 2. replace each element in the map with a temporary var, because one
    //    var may correspond to several grad vars in one op
    for (auto& pair : tmp_outs) {
      if (!pair.second.IsGrad()) {
        continue;
@@ -213,15 +227,23 @@ void BasicEngine::Execute() {
            platform::errors::NotFound("Cannot find gradient of variable %s",
                                       var->Name()));
        if (!var->OverridedStopGradient() && iter->second->RefCnt() == 1) {
          no_need_run_accumulators_.emplace_back(iter->second.get());
          continue;
        // leaf_accumulators_ : hooks and accumulate-grad for leaf tensor
        if (var->IsLeafGrad()) {
          leaf_accumulators_.insert(iter->second.get());
          if (iter->second->HasInnerVar()) {
            var = iter->second->InnerVar();
          }
        }
        auto tmp_var = std::make_shared<VariableWrapper>(var->Name());
        tmp_var->SetType(var->Type());
        var = tmp_var;
        need_accu_var_list_.emplace_back(iter->second.get(), var);
        if (var->OverridedStopGradient() || iter->second->RefCnt() > 1) {
          auto tmp_var = std::make_shared<VariableWrapper>(var->Name());
          tmp_var->SetType(var->Type());
          var = tmp_var;
          need_accu_var_list_.emplace_back(iter->second.get(), var);
          VLOG(10) << "create temporary var of " << var->Name()
                   << " for sum gradient within this graph!";
        }
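        // Editorial sketch (not part of the change): why the temporary var is
        // needed. When one forward var feeds several grad slots, each slot
        // writes into its own temporary and the accumulator later sums them,
        // rather than the slots overwriting a single shared grad buffer.
        // Hypothetical bookkeeping (acc_x / tmp_x_* are illustrative names):
        //   need_accu_var_list_ = { {acc_x, tmp_x_0}, {acc_x, tmp_x_1} };
        //   acc_x->SumGrad(tmp_x_0, op_id);  // consumed in Step 2 below
        //   acc_x->SumGrad(tmp_x_1, op_id);  // within-graph sum completes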
}
}
@@ -256,22 +278,32 @@ void BasicEngine::Execute() {
                                            cur_op.place());
      }
      // Step 2: Sum Gradient & Call Accumulator Hooks
      for (auto* accumulator : no_need_run_accumulators_) {
// Step 2: Sum Gradient of This graph
      for (auto& pair : need_accu_var_list_) {
        pair.first->SumGrad(std::move(pair.second), cur_op.id());
      }
// Step 3: Call Hooks && Sum Gradient with Pre-Graph && Call BackwardHooks
      for (auto* accumulator : leaf_accumulators_) {
        if (!accumulator->SumGradCompleted()) {
          continue;
        }
        // 1. Call Hooks for **inner_var_**
        // 2. Sum Gradient with Previous Graph
        accumulator->AccumulateGrad();
        // 3. Call backward Hooks for **var_**
        if (accumulator->HasPostHooks()) {
          accumulator->CallBackwardPostHooks();
        }
      }
      for (auto& pair : need_accu_var_list_) {
        pair.first->Add(std::move(pair.second), cur_op.id());
      }
      need_accu_var_list_.clear();
      no_need_run_accumulators_.clear();
      leaf_accumulators_.clear();
      VLOG(3) << "Remove op after op " << cur_op.Type() << " runs";
      if (!retain_graph_) {
        VLOG(3) << "Remove op after op " << cur_op.Type() << " runs";
        cur_op.ClearBackwardTrace();
      }
}
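      // Editorial sketch (not part of the change): per-op backward flow after
      // this change, using the names above (the ordering is the point):
      //   run the backward op into tmp_outs;                          // Step 1
      //   pair.first->SumGrad(std::move(pair.second), cur_op.id());   // Step 2
      //   accumulator->AccumulateGrad();          // Step 3: inner hooks + merge
      //   accumulator->CallBackwardPostHooks();   // Step 3: leaf post hooks
      //   cur_op.ClearBackwardTrace();            // only when !retain_graph_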
@@ -301,7 +333,7 @@ void BasicEngine::Clear() {
  node_deps_.clear();
  accumulators_.clear();
  need_accu_var_list_.clear();
  no_need_run_accumulators_.clear();
  leaf_accumulators_.clear();
}
} // namespace imperative