@@ -32,6 +32,16 @@ static std::vector<std::unique_ptr<framework::OpDesc>> CreateGradOpDescs(
   }
 }
 
+static void PassStopGradient(const NameVarBaseMap& outs, bool generate_grad) {
+  for (const auto& name_pair : outs) {
+    for (const auto& vb : name_pair.second) {
+      VLOG(6) << "Set output: " << vb->Name() << "'s OverridedStopGradient as "
+              << generate_grad;
+      vb->InnerSetOverridedStopGradient(generate_grad);
+    }
+  }
+}
+
 void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
                      const NameVarBaseMap& outs, framework::AttributeMap attrs,
                      const platform::Place& place, bool trace_backward) {
@@ -45,16 +55,27 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
     TraceBackward(op, framework::OpDesc(op->Type(), op->InputNameMap(),
                                         op->OutputNameMap(), op->Attrs()),
                   ins, outs);
     VLOG(6) << "Finish tracking Backward of op: " << type;
   } else {
     VLOG(3) << "No Grad to track for Op: " << type;
   }
   VLOG(6) << "Finish tracing fwd op: " << type;
 }
 
 bool Tracer::ComputeRequiredGrad(const NameVarBaseMap& ins,
                                  const NameVarBaseMap& outs,
                                  bool trace_backward) {
-  // TODO(jiabin): Implement auto prune here
-  return trace_backward;
+  if (!trace_backward) return false;
+
+  for (const auto& name_pair : ins) {
+    for (const auto& var_base : name_pair.second) {
+      if (!var_base->OverridedStopGradient()) {
+        VLOG(6) << "Find out input: " << var_base->Name()
+                << "'s GeneratedGrad is True";
+        PassStopGradient(outs, var_base->OverridedStopGradient());
+        return true;
+      }
+    }
+  }
+  return false;
 }
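
ComputeRequiredGrad now implements the auto prune promised by the deleted TODO: an op is recorded for backward only when tracing is on and at least one input has OverridedStopGradient() == false, and in that case PassStopGradient stamps that same (false) flag onto every output, so "requires grad" propagates forward through the traced program. Below is a minimal standalone sketch of that rule; the simplified Var struct and function bodies are hypothetical stand-ins for illustration, not the real VarBase API:

#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in for imperative::VarBase; stop_gradient == true
// means "do not build a gradient for this variable".
struct Var {
  std::string name;
  bool stop_gradient;
};

// Same rule as PassStopGradient: one decision stamped on all outputs.
void PassStopGradient(std::vector<Var*>* outs, bool stop) {
  for (auto* v : *outs) v->stop_gradient = stop;
}

// Same rule as Tracer::ComputeRequiredGrad: any differentiable input
// makes the op differentiable and un-stops all of its outputs.
bool ComputeRequiredGrad(const std::vector<Var*>& ins,
                         std::vector<Var*>* outs, bool trace_backward) {
  if (!trace_backward) return false;
  for (const auto* in : ins) {
    if (!in->stop_gradient) {
      PassStopGradient(outs, /*stop=*/false);
      return true;
    }
  }
  return false;  // every input is stopped, so prune this op from backward
}

int main() {
  Var x{"x", false}, w{"w", true}, y{"y", true};
  std::vector<Var*> ins{&x, &w}, outs{&y};
  std::cout << ComputeRequiredGrad(ins, &outs, true)  // 1: x requires grad
            << " " << y.stop_gradient << "\n";        // 0: y now requires grad
  return 0;
}

Note the inverted convention: because the flag means "stop", the tracer passes the found input's (false) OverridedStopGradient() value rather than a literal, which is why the parameter is named generate_grad in one place and behaves like a stop flag in the other.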
@@ -133,14 +154,25 @@ void Tracer::TraceBackward(const std::shared_ptr<OpBase>& fwd_op,
         PADDLE_ENFORCE_EQ(fwd_var_iter != name_to_var.end(), true,
                           "Cannot find forward variable named %s",
                           fwd_var_name);
+        const auto& tmp = (*(fwd_var_iter->second))->GradVarBase();
         PADDLE_ENFORCE_NOT_NULL(
-            (*(fwd_var_iter->second))->GradVarBase(),
+            tmp.get(),
             "Grad of %s should "
             "not be NULL when we Track_Backward Input of %s",
             (*(fwd_var_iter->second))->Name(), grad_op->Type());
-        (*(fwd_var_iter->second))->GradVarBase()->AddGradOps(grad_op);
+        // Create grad_in's dim in tensor for Grad Dependency compute
+        auto* tensor = tmp->MutableVar()->GetMutable<framework::LoDTensor>();
+        tensor->Resize((*(fwd_var_iter->second))
+                           ->Var()
+                           .Get<framework::LoDTensor>()
+                           .dims());
+        // Add Grad Op for grad_in
+        tmp->AddGradOps(grad_op);
         VLOG(3) << "Add Grad Op " << grad_op->Type() << " for :"
                 << (*(fwd_var_iter->second))->GradVarBase()->Name();
+        // Add Grad var input to engine set
+        engine_->InsertGradVar(tmp.get());
+        VLOG(3) << "Add Grad: " << tmp->Name() << " in to Engine";
         bwd_in.emplace_back((*(fwd_var_iter->second))->GradVarBase());
       } else {
         // If it is a forward var, just add it
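
Two things are new for each backward input in the hunk above: the grad VarBase gets the forward variable's dims stamped onto its still-unallocated LoDTensor (per the in-diff comment, so the engine can reason about gradient dependencies before any kernel runs), and the grad var is registered with the backward engine up front via InsertGradVar. The self-contained toy below illustrates that Resize in this position is metadata-only; the Tensor struct is a hypothetical stand-in for framework::LoDTensor, not the real class:

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-in for framework::LoDTensor metadata: Resize only
// records dims; it never allocates or touches memory.
struct Tensor {
  std::vector<int64_t> dims;
  bool allocated;
  void Resize(const std::vector<int64_t>& d) { dims = d; }
};

int main() {
  Tensor fwd{{32, 128}, true};   // forward var: shape known, data present
  Tensor grad{{}, false};        // grad var: no shape, no memory yet
  grad.Resize(fwd.dims);         // what tensor->Resize(... .dims()) does above
  std::cout << grad.dims[0] << "x" << grad.dims[1]
            << " allocated=" << grad.allocated << "\n";  // 32x128 allocated=0
  return 0;
}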
@@ -150,8 +182,7 @@ void Tracer::TraceBackward(const std::shared_ptr<OpBase>& fwd_op,
                           grad_in_var_name);
         bwd_in.emplace_back(*(fwd_var_iter->second));
       }
-
-      VLOG(3) << "Set backward input " << grad_ins.first << " of "
+      VLOG(3) << "Set backward input from fwd var " << grad_ins.first << " of "
               << grad_op->Type() << " to be "
               << (bwd_in.back() ? bwd_in.back()->Name() : "nullptr");
     }
@@ -173,40 +204,44 @@ void Tracer::TraceBackward(const std::shared_ptr<OpBase>& fwd_op,
         PADDLE_ENFORCE_EQ(fwd_var_iter != name_to_var.end(), true,
                           "Cannot find forward variable named %s",
                           iter->second);
 
-        PADDLE_ENFORCE_NOT_NULL(
-            (*(fwd_var_iter->second))->GradVarBase(),
-            "Grad of %s should "
-            "not be NULL when we Track_Backward Output of %s",
-            (*(fwd_var_iter->second))->Name(), grad_op->Type());
-        bwd_out.emplace_back((*(fwd_var_iter->second))->GradVarBase());
-        VLOG(3) << "Set backward output " << grad_outs.first << " of "
-                << grad_op->Type() << " to be "
-                << (bwd_out.back() ? bwd_out.back()->Name() : "nullptr");
-
-        auto preceding_ops =
-            (*(fwd_var_iter->second))->GradVarBase()->GradOps();
-
-        if (VLOG_IS_ON(3) && !preceding_ops.empty()) {
-          VLOG(3) << "Add preceding Op of :"
-                  << (*(fwd_var_iter->second))->GradVarBase()->Name()
-                  << " It's preceding Op are: ";
-          for (const auto& op : preceding_ops) {
-            VLOG(3) << op->Type();
-          }
-        }
-
-        if (!preceding_ops.empty()) {
-          for (const auto& op : preceding_ops) {
-            PADDLE_ENFORCE_NOT_NULL(op, "No nullptr should be preceding_op");
-            if (visited_preceding_ops.count(op) == 0) {
-              visited_preceding_ops.insert(op);
-              grad_op->InsertGradPendingOps(op);
-            }
-          }
-        } else {
-          VLOG(5) << "Hit leaf VarBase";
-        }
-
+        const auto& tmp = (*(fwd_var_iter->second))->GradVarBase();
+        PADDLE_ENFORCE_NOT_NULL(tmp.get(),
+                                "Grad output: %s of op: %s should not be NULL",
+                                tmp->Name(), grad_op->Type());
+
+        if ((!tmp->OverridedStopGradient()) || (grad_outs.second.size() > 1)) {
+          VLOG(3) << "Set backward output " << grad_outs.first << " of "
+                  << grad_op->Type() << " to be " << tmp->Name()
+                  << ". Its Overrided Stop_Gradient is: False";
+          bwd_out.emplace_back(tmp);
+          auto grad_pending_ops =
+              (*(fwd_var_iter->second))->GradVarBase()->GradOps();
+          if (VLOG_IS_ON(3) && !grad_pending_ops.empty()) {
+            VLOG(3) << "Add grad_pending Op of :"
+                    << (*(fwd_var_iter->second))->GradVarBase()->Name()
+                    << " It's grad_pending Op are: ";
+            for (const auto& op : grad_pending_ops) {
+              VLOG(3) << op->Type();
+            }
+          }
+          if (!grad_pending_ops.empty()) {
+            for (const auto& op : grad_pending_ops) {
+              PADDLE_ENFORCE_NOT_NULL(op,
+                                      "No nullptr should be grad_pending op");
+              if (visited_preceding_ops.count(op) == 0) {
+                visited_preceding_ops.insert(op);
+                grad_op->InsertGradPendingOps(op);
+              }
+            }
+          } else {
+            VLOG(5) << "Hit leaf VarBase "
+                    << (*(fwd_var_iter->second))->GradVarBase()->Name();
+          }
+        } else {
+          VLOG(3) << "Skip backward output " << grad_outs.first << " of "
+                  << grad_op->Type() << " Named: " << tmp->Name()
+                  << ", since its Overrided Stop_Gradient is: True";
+        }
       }
     }
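
This last hunk prunes the output side: a grad output whose OverridedStopGradient() is true is skipped entirely, so no backward buffer and no dependency edge is created for it (the grad_outs.second.size() > 1 escape hatch presumably keeps multi-variable output slots positionally intact). For outputs that are kept, the grad ops already attached to the variable become grad_pending ops of the new grad_op, deduplicated through visited_preceding_ops. A toy model of that wiring step, with hypothetical names standing in for OpBase and its containers:

#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// Toy stand-in for imperative::OpBase: each grad op depends on the grad
// ops already attached to its output vars ("grad pending" ops).
struct GradOp {
  std::string type;
  std::vector<GradOp*> grad_pending_ops;
};

// Mirrors the visited_preceding_ops check: every dependency edge is
// inserted at most once, even if several outputs share a producer.
void WirePendingOps(GradOp* grad_op, const std::vector<GradOp*>& producers,
                    std::unordered_set<GradOp*>* visited) {
  for (auto* op : producers) {
    if (visited->insert(op).second) {  // first time seen: add the edge
      grad_op->grad_pending_ops.push_back(op);
    }
  }
}

int main() {
  GradOp relu_grad{"relu_grad", {}}, mul_grad{"mul_grad", {}};
  std::unordered_set<GradOp*> visited;
  // Two outputs produced by the same op must not add mul_grad twice.
  WirePendingOps(&relu_grad, {&mul_grad}, &visited);
  WirePendingOps(&relu_grad, {&mul_grad}, &visited);
  std::cout << relu_grad.grad_pending_ops.size() << "\n";  // prints 1
  return 0;
}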