@@ -114,7 +114,9 @@ void BasicEngine::CheckBackwardInputs(const OpBase& op) {
   }
 }
 
-void BasicEngine::PrepareGradAccumulators(const OpBase& op) {
+void BasicEngine::PrepareGradAccumulators(
+    const OpBase& op,
+    const std::vector<std::shared_ptr<GradOpNode>>& grad_pending_nodes) {
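+  // grad_pending_nodes is used to locate the grad node of a grad_var whose
+  // own grad_node has been overwritten by an inplace op (see below).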
   for (const auto& pair : op.GetOutsMap()) {
     if (!pair.second.IsGrad()) {
       continue;
@@ -123,6 +125,7 @@ void BasicEngine::PrepareGradAccumulators(const OpBase& op) {
     for (const auto& var : pair.second) {
       if (!var) continue;
 
+      if (!var->HasGradNode()) {
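+        // The var carries no grad node, so its accumulator is stored in
+        // accumulators_, keyed directly by the var pointer.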
         auto& accumulator = accumulators_[var.get()];
         if (!accumulator) {
           if (FLAGS_sort_sum_gradient) {
@@ -135,18 +138,82 @@ void BasicEngine::PrepareGradAccumulators(const OpBase& op) {
 
         accumulator->IncreaseRefCnt();
 
         VLOG(3) << "Prepare to accumulate variable grad " << var->Name() << "("
-                << var.get() << ") with reference count "
+                << var.get()
+                << ") that doesn't have a grad node, with reference count "
                 << accumulator->RefCnt();
 
         if (var->HasLeafHooks()) {
           VLOG(3) << "Grad variable wrapper (" << var->Name()
                   << ") has leaf grad hooks.";
-          PADDLE_ENFORCE_NE(var->HasGradNode(), true,
+          PADDLE_ENFORCE_NE(
+              var->HasGradNode(), true,
               platform::errors::PermissionDenied(
                   "Only leaf Tensor's gradient can append hook to "
                   "GradientAccumulator."));
           accumulator->SetPostHooks(var->GetLeafHooks());
         }
+      } else {
+        // Because an inplace op overwrites the grad_node of its input
+        // grad_var, only the grad_pending_node information can be used to
+        // find the grad_node of the grad_var.
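+        // Walk every op of every pending grad node and check whether var is
+        // one of that op's grad inputs; the accumulator is then keyed by the
+        // pending node that consumes var.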
+        bool find_grad_node_of_var = false;
+        for (auto& grad_pending_node : grad_pending_nodes) {
+          PADDLE_ENFORCE_NOT_NULL(
+              grad_pending_node,
+              platform::errors::NotFound("Grad pending node is nullptr."));
+          for (auto& grad_pending_op : *grad_pending_node) {
+            VLOG(6) << "Determine whether var (" << var->Name()
+                    << ") is the input var of grad_pending_op ("
+                    << grad_pending_op.Type() << ").";
+            grad_pending_op.EnforceHasInOut();
+            for (const auto& grad_pending_op_ins_pair :
+                 grad_pending_op.GetInsMap()) {
+              if (!grad_pending_op_ins_pair.second.IsGrad()) {
+                continue;
+              }
+              for (const auto& pending_in_var :
+                   grad_pending_op_ins_pair.second) {
+                if (var == pending_in_var) {
+                  VLOG(6) << "Var (" << var->Name()
+                          << ") is the input var of grad_pending_op ("
+                          << grad_pending_op.Type() << ").";
+                  find_grad_node_of_var = true;
+                  break;
+                }
+              }
+              if (find_grad_node_of_var) {
+                break;
+              }
+            }
+          }
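+          // var was found under this pending node: create (or reuse) the
+          // accumulator in the two-level map
+          // accumulators_with_grad_node_[grad_pending_node][var].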
+          if (find_grad_node_of_var) {
+            auto& accumulator =
+                accumulators_with_grad_node_[grad_pending_node][var.get()];
+
+            if (!accumulator) {
+              if (FLAGS_sort_sum_gradient) {
+                accumulator.reset(new SortedGradientAccumulator(var.get()));
+              } else {
+                accumulator.reset(new EagerGradientAccumulator(var.get()));
+              }
+            }
+
+            accumulator->IncreaseRefCnt();
+
+            VLOG(3) << "Prepare to accumulate variable grad " << var->Name()
+                    << "(" << var.get()
+                    << ") that has a grad node, with reference count "
+                    << accumulator->RefCnt();
+            break;
+          }
+        }
+        PADDLE_ENFORCE_EQ(
+            find_grad_node_of_var, true,
+            platform::errors::NotFound(
+                "No grad node corresponding to grad Tensor (%s) was found.",
+                var->Name()));
+      }
     }
   }
 }
@@ -154,10 +221,13 @@ void BasicEngine::PrepareGradAccumulators(const OpBase& op) {
 
 void BasicEngine::PrepareDeps() {
   PADDLE_ENFORCE_EQ(
       node_deps_.empty(), true,
-      platform::errors::AlreadyExists("Op deps must be initialized here"));
+      platform::errors::AlreadyExists("Op deps must be initialized."));
   PADDLE_ENFORCE_EQ(
       accumulators_.empty(), true,
-      platform::errors::AlreadyExists("Accumulators must be initialized here"));
+      platform::errors::AlreadyExists("Accumulators must be initialized."));
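+  // The accumulator map for inplace-rewritten grad vars must start empty too.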
+  PADDLE_ENFORCE_EQ(
+      accumulators_with_grad_node_.empty(), true,
+      platform::errors::AlreadyExists("Accumulators must be initialized."));
 
   std::queue<GradOpNode*> q;
   std::unordered_set<GradOpNode*> visited;
@@ -169,16 +239,17 @@ void BasicEngine::PrepareDeps() {
     auto* cur_node = q.front();
     q.pop();
 
+    const auto& grad_pending_nodes = cur_node->GradPendingNodes();
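+    // Fetched once here: the pending nodes are needed both for preparing
+    // accumulators and for counting dependencies below.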
+
     for (auto& cur_op : *cur_node) {
       cur_op.EnforceHasInOut();
-      PrepareGradAccumulators(cur_op);
+      PrepareGradAccumulators(cur_op, grad_pending_nodes);
     }
 
-    const auto& grad_pending_nodes = cur_node->GradPendingNodes();
     for (auto& grad_pending_node : grad_pending_nodes) {
       PADDLE_ENFORCE_NOT_NULL(
           grad_pending_node,
-          platform::errors::NotFound("Grad pending node should not be null"));
+          platform::errors::NotFound("Grad pending node is nullptr."));
       ++node_deps_[grad_pending_node.get()];
       if (visited.count(grad_pending_node.get()) == 0) {
         visited.insert(grad_pending_node.get());
@@ -204,6 +275,8 @@ void BasicEngine::Execute() {
     auto shared_cur_node = std::move(q.front());
     q.pop();
 
+    auto& inplace_grad_name_map = shared_cur_node->InplaceGradNameMap();
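+    // Maps an inplace grad op's output slot name to the input slot name it
+    // reuses; consulted below when deciding whether a temporary output var
+    // must be created.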
+
     for (auto& cur_op : *shared_cur_node) {
       ++op_num;
@@ -228,11 +301,38 @@ void BasicEngine::Execute() {
           continue;
         }
 
-        auto iter = accumulators_.find(var.get());
-        PADDLE_ENFORCE_EQ(
-            iter != accumulators_.end(), true,
-            platform::errors::NotFound("Cannot find gradient of variable %s",
-                                       var->Name()));
+        std::unordered_map<VariableWrapper*,
+                           std::unique_ptr<GradientAccumulator>>::iterator
+            iter;
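+        // Vars without a grad node are looked up in accumulators_; vars
+        // with one are looked up in accumulators_with_grad_node_, keyed by
+        // the pending node that consumes them.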
+        if (!var->HasGradNode()) {
+          VLOG(10) << "Find gradient of var (" << var->Name()
+                   << ") with no grad_node.";
+          iter = accumulators_.find(var.get());
+          PADDLE_ENFORCE_EQ(
+              iter != accumulators_.end(), true,
+              platform::errors::NotFound(
+                  "Cannot find gradient of variable %s", var->Name()));
+        } else {
+          bool flag_find_grad = false;
+          VLOG(10) << "Find gradient of var (" << var->Name()
+                   << ") with grad_node.";
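+          // Probe each pending node's accumulator table until the one
+          // holding var is found.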
+          for (auto& grad_pending_node :
+               shared_cur_node->GradPendingNodes()) {
+            const auto& iter_grad_node =
+                accumulators_with_grad_node_.find(grad_pending_node);
+            if (iter_grad_node != accumulators_with_grad_node_.end()) {
+              iter = iter_grad_node->second.find(var.get());
+              if (iter != iter_grad_node->second.end()) {
+                flag_find_grad = true;
+                break;
+              }
+            }
+          }
+          PADDLE_ENFORCE_EQ(
+              flag_find_grad, true,
+              platform::errors::NotFound(
+                  "Cannot find gradient of variable %s", var->Name()));
+        }
 
         // leaf_accumulators_ : hooks and accumulate-grad for leaf tensor
         if (var->IsLeafGrad()) {
@@ -251,6 +351,25 @@ void BasicEngine::Execute() {
             need_accu_var_list_.emplace_back(iter->second.get(), var);
             VLOG(10) << "create temporary var of " << var->Name()
                      << " for sum gradient within this graph!";
-          }
+          } else if (!inplace_grad_name_map.empty() &&
+                     inplace_grad_name_map.count(pair.first)) {
+            // When calculating an inplace grad op, create a new output var.
+            // If a tmp var has already been created, there is no need to
+            // create it again.
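+            // Replace var only if it is actually reused as an input of this
+            // grad op; the temporary's result is moved back into the
+            // original var after the op runs.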
+            for (auto& in_var :
+                 bwd_ins.at(inplace_grad_name_map.at(pair.first))) {
+              if (in_var == var) {
+                auto tmp_var = std::make_shared<VariableWrapper>(var->Name());
+                tmp_var->SetType(var->Type());
+                tmp_var->SetForwardDataType(var->ForwardDataType());
+                inplace_output_grad_var_list_.emplace_back(var, tmp_var);
+                var = tmp_var;
+                VLOG(10) << "Inplace grad op does not use the inplace "
+                            "strategy; a temporary output var ("
+                         << var->Name() << ") will be created.";
+                break;
+              }
+            }
+          }
         }
       }
@@ -286,6 +405,10 @@ void BasicEngine::Execute() {
                                cur_op.place());
       }
 
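+      // Move the results held in the temporary output vars back into the
+      // real grad vars before summing the gradients of this graph.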
+      for (auto& pair : inplace_output_grad_var_list_) {
+        *pair.first = std::move(*pair.second);
+      }
+
       // Step 2: Sum Gradient of This graph
       for (auto& pair : need_accu_var_list_) {
         pair.first->SumGrad(std::move(pair.second), cur_op.id());
@@ -308,6 +431,7 @@ void BasicEngine::Execute() {
       }
 
       need_accu_var_list_.clear();
+      inplace_output_grad_var_list_.clear();
       leaf_accumulators_.clear();
 
       if (!retain_graph_) {
@@ -318,9 +442,9 @@ void BasicEngine::Execute() {
 
     // Step 3: Collect ready ops
     for (auto& grad_pending_node : shared_cur_node->GradPendingNodes()) {
-      PADDLE_ENFORCE_NOT_NULL(grad_pending_node,
-                              platform::errors::NotFound(
-                                  "Grad pending node should not be nullptr"));
+      PADDLE_ENFORCE_NOT_NULL(
+          grad_pending_node,
+          platform::errors::NotFound("Grad pending node is nullptr."));
       auto iter = node_deps_.find(grad_pending_node.get());
       if (iter == node_deps_.end()) {
         continue;
@@ -340,6 +464,7 @@ void BasicEngine::Clear() {
   init_node_.reset();
   node_deps_.clear();
   accumulators_.clear();
+  accumulators_with_grad_node_.clear();
   need_accu_var_list_.clear();
   leaf_accumulators_.clear();
 }