@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 
			
		
	
		
			
				
					limitations  under  the  License .  */  
			
		
	
		
			
				
					
 
			
		
	
		
			
				
					# include  "paddle/fluid/operators/controlflow/conditional_block_op.h"  
			
		
	
		
			
				
					# include  "paddle/fluid/operators/assign_op.h"  
			
		
	
		
			
				
					
 
			
		
	
		
			
				
					namespace  paddle  {  
			
		
	
		
			
				
					namespace  operators  {  
			
		
	
	
		
			
				
					
						
							
								 
						
						
							
								 
						
						
					 
				
				@ -58,13 +59,15 @@ class ConditionalBlockOp : public ConditionalOp {
 
			
		
	
		
			
				
					      scopes - > resize ( 1 ) ; 
 
			
		
	
		
			
				
					      scopes - > front ( )  =  & scope . NewScope ( ) ; 
 
			
		
	
		
			
				
					      auto  & cur_scope  =  * scopes - > front ( ) ; 
 
			
		
	
		
			
				
					
 
			
		
	
		
			
				
					      framework : : Executor  exec ( dev_place ) ; 
 
			
		
	
		
			
				
					      auto  * block  =  Attr < framework : : BlockDesc  * > ( " sub_block " ) ; 
 
			
		
	
		
			
				
					      VLOG ( 3 )  < <  " Conditional block.idx =  "  < <  block - > ID ( ) 
 
			
		
	
		
			
				
					              < <  " , scope =  "  < <  & cur_scope ; 
 
			
		
	
		
			
				
					      auto  & skip_vars  = 
 
			
		
	
		
			
				
					          Attr < std : : vector < std : : string > > ( ConditionalOp : : kSkipEagerDeletionVars ) ; 
 
			
		
	
		
			
				
					      exec . Run ( * block - > Program ( ) ,  & cur_scope ,  block - > ID ( ) ,  false ,  true , 
 
			
		
	
		
			
				
					               skip_vars ) ; 
 
			
		
	
		
			
				
					               skip_vars ,  /* force_disable_gc */  false , 
 
			
		
	
		
			
				
					               /* keep_kid_scopes */  true ) ; 
 
			
		
	
		
			
				
					    } 
 
			
		
	
		
			
				
					  } 
 
			
		
	
		
			
				
					} ;  
			
		
	
	
		
			
				
					
						
							
								 
						
						
							
								 
						
						
					 
				
				@ -92,60 +95,65 @@ class ConditionalBlockGradOp : public ConditionalOp {
 
			
		
	
		
			
				
					    } 
 
			
		
	
		
			
				
					
 
			
		
	
		
			
				
					    if  ( need_run )  { 
 
			
		
	
		
			
				
					      const  auto  & inputs  =  Inputs ( ConditionalOp : : kInputs ) ; 
 
			
		
	
		
			
				
					      const  auto  & outside_grads  = 
 
			
		
	
		
			
				
					          Outputs ( framework : : GradVarName ( ConditionalOp : : kInputs ) ) ; 
 
			
		
	
		
			
				
					
 
			
		
	
		
			
				
					      std : : vector < std : : string >  inside_grads ; 
 
			
		
	
		
			
				
					      inside_grads . reserve ( inputs . size ( ) ) ; 
 
			
		
	
		
			
				
					      for  ( auto  & in  :  inputs )  { 
 
			
		
	
		
			
				
					        inside_grads . emplace_back ( framework : : GradVarName ( in ) ) ; 
 
			
		
	
		
			
				
					      } 
 
			
		
	
		
			
				
					
 
			
		
	
		
			
				
					      auto  * scope_var  =  scope . FindVar ( Input ( ConditionalOp : : kScope ) ) ; 
 
			
		
	
		
			
				
					      PADDLE_ENFORCE ( scope_var  ! =  nullptr ,  " Must set scope " ) ; 
 
			
		
	
		
			
				
					      PADDLE_ENFORCE_NE ( scope_var ,  nullptr , 
 
			
		
	
		
			
				
					                        platform : : errors : : InvalidArgument ( 
 
			
		
	
		
			
				
					                            " Scope must be set in conditional block op " ) ) ; 
 
			
		
	
		
			
				
					      auto  & scopes  =  scope_var - > Get < std : : vector < framework : : Scope  * > > ( ) ; 
 
			
		
	
		
			
				
					      PADDLE_ENFORCE_GT ( scopes . size ( ) ,  0 , 
 
			
		
	
		
			
				
					                        platform : : errors : : InvalidArgument ( 
 
			
		
	
		
			
				
					                            " Scope must be set in conditional block op " ) ) ; 
 
			
		
	
		
			
				
					      framework : : Scope  & cur_scope  =  * scopes [ 0 ] ; 
 
			
		
	
		
			
				
					
 
			
		
	
		
			
				
					      framework : : Executor  exec ( dev_place ) ; 
 
			
		
	
		
			
				
					      auto  * block  =  Attr < framework : : BlockDesc  * > ( " sub_block " ) ; 
 
			
		
	
		
			
				
					
 
			
		
	
		
			
				
					      const  auto  & ins  =  Inputs ( ConditionalOp : : kInputs ) ; 
 
			
		
	
		
			
				
					      const  auto  & d_ins  = 
 
			
		
	
		
			
				
					          Outputs ( framework : : GradVarName ( ConditionalOp : : kInputs ) ) ; 
 
			
		
	
		
			
				
					      const  auto  & conds  =  Inputs ( ConditionalOp : : kCondition ) ; 
 
			
		
	
		
			
				
					      const  auto  & d_conds  = 
 
			
		
	
		
			
				
					          Outputs ( framework : : GradVarName ( ConditionalOp : : kCondition ) ) ; 
 
			
		
	
		
			
				
					
 
			
		
	
		
			
				
					      std : : vector < std : : string >  ins_conds_grads ; 
 
			
		
	
		
			
				
					      ins_conds_grads . reserve ( ins . size ( )  +  conds . size ( ) ) ; 
 
			
		
	
		
			
				
					      for  ( auto  & in  :  ins )  { 
 
			
		
	
		
			
				
					        ins_conds_grads . emplace_back ( framework : : GradVarName ( in ) ) ; 
 
			
		
	
		
			
				
					      } 
 
			
		
	
		
			
				
					      for  ( auto  & cond  :  conds )  { 
 
			
		
	
		
			
				
					        ins_conds_grads . emplace_back ( framework : : GradVarName ( cond ) ) ; 
 
			
		
	
		
			
				
					      } 
 
			
		
	
		
			
				
					
 
			
		
	
		
			
				
					      VLOG ( 3 )  < <  " Conditional Grad block.idx =  "  < <  block - > ID ( ) 
 
			
		
	
		
			
				
					              < <  " , scope =  "  < <  & cur_scope ; 
 
			
		
	
		
			
				
					      exec . Run ( * block - > Program ( ) ,  & cur_scope ,  block - > ID ( ) ,  false ,  true , 
 
			
		
	
		
			
				
					               ins_conds_grads ) ; 
 
			
		
	
		
			
				
					
 
			
		
	
		
			
				
					      AssignLocalGradientToGlobal ( dev_place ,  cur_scope ,  ins_conds_grads . data ( ) , 
 
			
		
	
		
			
				
					                                  ins . size ( ) ,  d_ins ) ; 
 
			
		
	
		
			
				
					               inside_grads ,  /* force_disable_gc */  false , 
 
			
		
	
		
			
				
					               /* keep_kid_scopes */  false ) ; 
 
			
		
	
		
			
				
					
 
			
		
	
		
			
				
					      AssignLocalGradientToGlobal ( dev_place ,  cur_scope , 
 
			
		
	
		
			
				
					                                  ins_conds_grads . data ( )  +  ins . size ( ) , 
 
			
		
	
		
			
				
					                                  conds . size ( ) ,  d_conds ) ; 
 
			
		
	
		
			
				
					      AssignLocalGradientToParentScope ( dev_place ,  cur_scope ,  scope , 
 
			
		
	
		
			
				
					                                       inside_grads ,  outside_grads ) ; 
 
			
		
	
		
			
				
					    } 
 
			
		
	
		
			
				
					  } 
 
			
		
	
		
			
				
					
 
			
		
	
		
			
				
					 private : 
 
			
		
	
		
			
				
					  void  AssignLocalGradientTo Global ( 
 
			
		
	
		
			
				
					  void  AssignLocalGradientToParentScope ( 
 
			
		
	
		
			
				
					      const  platform : : Place  & place ,  const  framework : : Scope  & cur_scope , 
 
			
		
	
		
			
				
					      const  std : : string  * p_grad_names ,  size_t  p_grad_names_num , 
 
			
		
	
		
			
				
					      const  std : : vector < std : : string >  & pg_names )  const  { 
 
			
		
	
		
			
				
					    for  ( size_t  i  =  0 ;  i  <  p_grad_names_num ;  + + i )  { 
 
			
		
	
		
			
				
					      auto  out_grad_name  =  pg_names [ i ] ; 
 
			
		
	
		
			
				
					      const  auto  & in_grad_name  =  p_grad_names [ i ] ; 
 
			
		
	
		
			
				
					      auto  * in_var  =  cur_scope . FindVar ( in_grad_name ) ; 
 
			
		
	
		
			
				
					      if  ( in_var  = =  nullptr )  { 
 
			
		
	
		
			
				
					      const  framework : : Scope  & parent_scope , 
 
			
		
	
		
			
				
					      const  std : : vector < std : : string >  & inside_grads , 
 
			
		
	
		
			
				
					      const  std : : vector < std : : string >  & outside_grads )  const  { 
 
			
		
	
		
			
				
					    for  ( size_t  i  =  0 ;  i  <  outside_grads . size ( ) ;  + + i )  { 
 
			
		
	
		
			
				
					      const  std : : string  & outside_grad_name  =  outside_grads [ i ] ; 
 
			
		
	
		
			
				
					      const  std : : string  & inside_grad_name  =  inside_grads [ i ] ; 
 
			
		
	
		
			
				
					      VLOG ( 4 )  < <  " inside_grad_name =  "  < <  inside_grad_name 
 
			
		
	
		
			
				
					              < <  " , outside_grad_name =  "  < <  outside_grad_name ; 
 
			
		
	
		
			
				
					      framework : : Variable  * inside_var  = 
 
			
		
	
		
			
				
					          cur_scope . FindLocalVar ( inside_grad_name ) ; 
 
			
		
	
		
			
				
					      if  ( inside_var  = =  nullptr )  { 
 
			
		
	
		
			
				
					        continue ; 
 
			
		
	
		
			
				
					      } 
 
			
		
	
		
			
				
					      auto  new_in_grad_name  =  cur_scope . Rename ( in_grad_name ) ; 
 
			
		
	
		
			
				
					      auto  assign  =  framework : : OpRegistry : : CreateOp ( 
 
			
		
	
		
			
				
					          " assign " ,  { { " X " ,  { new_in_grad_name } } } ,  { { " Out " ,  { out_grad_name } } } , 
 
			
		
	
		
			
				
					          framework : : AttributeMap { } ) ; 
 
			
		
	
		
			
				
					      assign - > Run ( cur_scope ,  place ) ; 
 
			
		
	
		
			
				
					      cur_scope . Rename ( new_in_grad_name ,  in_grad_name ) ; 
 
			
		
	
		
			
				
					      framework : : Variable  * outside_var  = 
 
			
		
	
		
			
				
					          parent_scope . FindVar ( outside_grad_name ) ; 
 
			
		
	
		
			
				
					      if  ( outside_var  = =  nullptr )  { 
 
			
		
	
		
			
				
					        continue ; 
 
			
		
	
		
			
				
					      } 
 
			
		
	
		
			
				
					      platform : : DeviceContext  * dev_ctx  = 
 
			
		
	
		
			
				
					          platform : : DeviceContextPool : : Instance ( ) . Get ( place ) ; 
 
			
		
	
		
			
				
					      framework : : VisitVarType ( * inside_var , 
 
			
		
	
		
			
				
					                              AssignFunctor ( outside_var ,  * dev_ctx ) ) ; 
 
			
		
	
		
			
				
					    } 
 
			
		
	
		
			
				
					  } 
 
			
		
	
		
			
				
					} ;  
			
		
	
	
		
			
				
					
						
						
						
							
								 
						
					 
				
				@ -154,17 +162,11 @@ class ConditionalBlockGradInferShape : public framework::InferShapeBase {
 
			
		
	
		
			
				
					 public : 
 
			
		
	
		
			
				
					  void  operator ( ) ( framework : : InferShapeContext  * context )  const  override  { 
 
			
		
	
		
			
				
					    PADDLE_ENFORCE ( context - > HasInputs ( ConditionalOp : : kCondition ) ) ; 
 
			
		
	
		
			
				
					    if  ( context - > HasInputs ( ConditionalOp : : kInputs ) )  { 
 
			
		
	
		
			
				
					      PADDLE_ENFORCE ( 
 
			
		
	
		
			
				
					          context - > HasOutputs ( framework : : GradVarName ( ConditionalOp : : kInputs ) ) ) ; 
 
			
		
	
		
			
				
					    if  ( context - > HasInputs ( ConditionalOp : : kInputs )  & & 
 
			
		
	
		
			
				
					        context - > HasOutputs ( framework : : GradVarName ( ConditionalOp : : kInputs ) ) )  { 
 
			
		
	
		
			
				
					      context - > SetOutputsDim ( framework : : GradVarName ( ConditionalOp : : kInputs ) , 
 
			
		
	
		
			
				
					                             context - > GetInputsDim ( ConditionalOp : : kInputs ) ) ; 
 
			
		
	
		
			
				
					    } 
 
			
		
	
		
			
				
					    if  ( context - > HasOutputs ( 
 
			
		
	
		
			
				
					            framework : : GradVarName ( ConditionalOp : : kCondition ) ) )  { 
 
			
		
	
		
			
				
					      context - > SetOutputsDim ( framework : : GradVarName ( ConditionalOp : : kCondition ) , 
 
			
		
	
		
			
				
					                             context - > GetInputsDim ( ConditionalOp : : kCondition ) ) ; 
 
			
		
	
		
			
				
					    } 
 
			
		
	
		
			
				
					  } 
 
			
		
	
		
			
				
					} ;  
			
		
	
		
			
				
					
 
			
		
	
	
		
			
				
					
						
						
						
							
								 
						
					 
				
				@ -187,8 +189,6 @@ class ConditionalBlockGradMaker : public framework::SingleGradOpMaker<T> {
 
			
		
	
		
			
				
					                      this - > OutputGrad ( ConditionalOp : : kOutputs ) ) ; 
 
			
		
	
		
			
				
					    grad_op - > SetInput ( ConditionalOp : : kScope , 
 
			
		
	
		
			
				
					                      this - > Output ( ConditionalOp : : kScope ) ) ; 
 
			
		
	
		
			
				
					    grad_op - > SetOutput ( framework : : GradVarName ( ConditionalOp : : kCondition ) , 
 
			
		
	
		
			
				
					                       this - > InputGrad ( ConditionalOp : : kCondition ,  false ) ) ; 
 
			
		
	
		
			
				
					    grad_op - > SetOutput ( framework : : GradVarName ( ConditionalOp : : kInputs ) , 
 
			
		
	
		
			
				
					                       this - > InputGrad ( ConditionalOp : : kInputs ,  false ) ) ; 
 
			
		
	
		
			
				
					    grad_op - > SetBlockAttr ( " sub_block " ,  this - > grad_block_ [ 0 ] ) ;