@ -19,6 +19,7 @@ import gast
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					from  paddle . fluid  import  unique_name 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					from  paddle . fluid . dygraph . dygraph_to_static . utils  import  index_in_list 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					from  paddle . fluid . dygraph . dygraph_to_static . utils  import  ForNodeVisitor 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					from  paddle . fluid . dygraph . dygraph_to_static . utils  import  BaseNodeVisitor 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					from  paddle . fluid . dygraph . dygraph_to_static . variable_trans_func  import  create_fill_constant_node 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					__all__  =  [ ' BreakContinueTransformer ' ] 
 
				
			 
			
		
	
	
		
			
				
					
						
							
								 
							 
						
						
							
								 
							 
						
						
					 
				
				 
				 
				
					@ -83,7 +84,7 @@ class ForToWhileTransformer(gast.NodeTransformer):
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        return  init_stmts 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					class  BreakContinueTransformer ( gast. NodeTransforme  r) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					class  BreakContinueTransformer ( BaseNodeVisito r) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    """ 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    Rewrite  ' break '  and  ' continue '  key  words  in  a  if - else  python  way  to  make 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    it  equivalent  to  original  control  flow 
 
				
			 
			
		
	
	
		
			
				
					
						
						
						
							
								 
							 
						
					 
				
				 
				 
				
					@ -103,41 +104,23 @@ class BreakContinueTransformer(gast.NodeTransformer):
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        set  continue  to  False  at  the  beginning  of  each  loop 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        TODO :  more  details  should  be  summarized  as  design  document 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    Note :  The  class  is  inherited  from  BaseNodeVisitor  instead  of  NodeTransformer , 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					          because  ancestor  nodes  will  be  modified  inplace  for  ` Break / Continue `  here . 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					          In  general ,  we  recommend  to  inheriting  NodeTransformer  to  modify  node ! 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    """ 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    def  __init__ ( self ,  wrapper_root ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        super ( BreakContinueTransformer ,  self ) . __init__ ( ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        self . wrapper_root  =  wrapper_root 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        self . root  =  wrapper_root . node 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        self . ancestor_nodes  =  [ ] 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    def  transform ( self ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        self . visit ( self . root ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    def  generic_visit ( self ,  node ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        # TODO: because we change ancestor nodes during visit_Break/Continue, 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        # not current node, so generic_visit of NodeTransformer will visit node 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        # which may be deleted. To prevent that node being added into 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        # transformed AST, I have to self-write a generic_visit, but this is 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        # NOT a good thing. Considering refactorying this whole class. 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        for  field ,  value  in  gast . iter_fields ( node ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            if  isinstance ( value ,  list ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                for  item  in  value : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                    if  isinstance ( item ,  gast . AST ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                        self . visit ( item ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            elif  isinstance ( value ,  gast . AST ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                self . visit ( value ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    def  visit ( self ,  node ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        self . ancestor_nodes . append ( node ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        method  =  ' visit_ '  +  node . __class__ . __name__ 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        visitor  =  getattr ( self ,  method ,  self . generic_visit ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        ret  =  visitor ( node ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        self . ancestor_nodes . pop ( ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        return  ret 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    def  visit_Break ( self ,  node ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        loop_node_index  =  self . _find_ancestor_loop_index ( node ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        loop_node_index  =  _find_ancestor_loop_index ( node ,  self . ancestor_nodes ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        assert  loop_node_index  !=  - 1 ,  " SyntaxError:  ' break '  outside loop " 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        loop_node  =  self . ancestor_nodes [ loop_node_index ] 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
	
		
			
				
					
						
						
						
							
								 
							 
						
					 
				
				 
				 
				
					@ -150,7 +133,7 @@ class BreakContinueTransformer(gast.NodeTransformer):
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        first_block_index  =  self . _remove_stmts_after_break_continue ( 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            node ,  variable_name ,  loop_node_index ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        # 3. Add 'if  V' for stmts in ancestor blocks between the first one
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        # 3. Add 'if  not  V' for stmts in ancestor blocks between the first one
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        # (exclusive) and the ancestor loop (inclusive) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        self . _replace_if_stmt ( loop_node_index ,  first_block_index ,  variable_name ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
	
		
			
				
					
						
						
						
							
								 
							 
						
					 
				
				 
				 
				
					@ -165,6 +148,7 @@ class BreakContinueTransformer(gast.NodeTransformer):
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                ctx = gast . Load ( ) , 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                annotation = None , 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                type_comment = None ) ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        if  isinstance ( loop_node ,  gast . While ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            loop_node . test  =  gast . BoolOp ( 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                op = gast . And ( ) ,  values = [ loop_node . test ,  cond_var_node ] ) 
 
				
			 
			
		
	
	
		
			
				
					
						
						
						
							
								 
							 
						
					 
				
				 
				 
				
					@ -175,7 +159,7 @@ class BreakContinueTransformer(gast.NodeTransformer):
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            for_to_while . transform ( ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    def  visit_Continue ( self ,  node ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        loop_node_index  =  self . _find_ancestor_loop_index ( node ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        loop_node_index  =  _find_ancestor_loop_index ( node ,  self . ancestor_nodes ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        assert  loop_node_index  !=  - 1 ,  " SyntaxError:  ' continue '  outside loop " 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        loop_node  =  self . ancestor_nodes [ loop_node_index ] 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
	
		
			
				
					
						
						
						
							
								 
							 
						
					 
				
				 
				 
				
					@ -188,7 +172,7 @@ class BreakContinueTransformer(gast.NodeTransformer):
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        first_block_index  =  self . _remove_stmts_after_break_continue ( 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            node ,  variable_name ,  loop_node_index ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        # 3. Add 'if  V' for stmts in ancestor blocks between the first one
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        # 3. Add 'if  not  V' for stmts in ancestor blocks between the first one
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        # (exclusive) and the ancestor loop (inclusive) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        self . _replace_if_stmt ( loop_node_index ,  first_block_index ,  variable_name ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
	
		
			
				
					
						
							
								 
							 
						
						
							
								 
							 
						
						
					 
				
				 
				 
				
					@ -215,15 +199,6 @@ class BreakContinueTransformer(gast.NodeTransformer):
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        return  first_block_index 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    def  _replace_break_continue_in_stmt_list ( 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            self ,  stmt_list ,  break_continue_node ,  break_continue_name ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        i  =  index_in_list ( stmt_list ,  break_continue_node ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        if  i  ==  - 1 : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            return  False 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        assign_true_node  =  create_fill_constant_node ( break_continue_name ,  True ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        stmt_list [ i : ]  =  [ assign_true_node ] 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        return  True 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    def  _replace_if_stmt ( self ,  loop_node_index ,  first_block_index , 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                         break_continue_name ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        for  i  in  range ( first_block_index  -  1 ,  loop_node_index  -  1 ,  - 1 ) : 
 
				
			 
			
		
	
	
		
			
				
					
						
						
						
							
								 
							 
						
					 
				
				 
				 
				
					@ -239,6 +214,15 @@ class BreakContinueTransformer(gast.NodeTransformer):
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                        cur_node . orelse ,  son_node ,  break_continue_name ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                continue 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    def  _replace_break_continue_in_stmt_list ( 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            self ,  stmt_list ,  break_continue_node ,  break_continue_name ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        i  =  index_in_list ( stmt_list ,  break_continue_node ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        if  i  ==  - 1 : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            return  False 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        assign_true_node  =  create_fill_constant_node ( break_continue_name ,  True ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        stmt_list [ i : ]  =  [ assign_true_node ] 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        return  True 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    def  _replace_after_node_to_if_in_stmt_list ( self ,  stmt_list ,  node , 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                                               break_continue_name ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        i  =  index_in_list ( stmt_list ,  node ) 
 
				
			 
			
		
	
	
		
			
				
					
						
							
								 
							 
						
						
							
								 
							 
						
						
					 
				
				 
				 
				
					@ -282,8 +266,110 @@ class BreakContinueTransformer(gast.NodeTransformer):
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        stmt_list . insert ( i ,  stmt_node ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        return  True 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    def  _find_ancestor_loop_index ( self ,  node ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        for  i  in  range ( len ( self . ancestor_nodes )  -  1 ,  - 1 ,  - 1 ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            if  isinstance ( self . ancestor_nodes [ i ] ,  ( gast . For ,  gast . While ) ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					def  _find_ancestor_loop_index ( node ,  ancestor_nodes ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    for  i  in  range ( len ( ancestor_nodes )  -  1 ,  - 1 ,  - 1 ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        if  isinstance ( ancestor_nodes [ i ] ,  ( gast . For ,  gast . While ) ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            return  i 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    return  - 1 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					class  BreakTransformOptimizer ( BaseNodeVisitor ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    """ 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    In  specific  pattern ,  the  transformed  code  could  be  optimized  by  joining  the  
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    If . test  with  while . test .  
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    Currently  supported  pattern  is : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    ` ` ` 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        while  cond1 :             while  cond1  and  not  cond2 : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            if  cond2 :     - - - >        do_something ( ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                break 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            do_something ( ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    ` ` ` 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    See  following  example : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    >> >  def  foo ( x ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    . . .      i  =  paddle . to_tensor ( 1 ,  dtype = ' int32 ' ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    . . .      while  i  <  10 : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    . . .          if  x . mean ( )  >  5 : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    . . .              break 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    . . .          x  + =  i 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    . . .          i  + =  1 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    . . .      return  x 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    The  generated  code  after  applying  optimization  will  be : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    ` ` ` 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        def  foo ( x ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            i  =  paddle . to_tensor ( 1 ,  dtype = ' int32 ' ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            while  i  <  10  and  not  x . mean ( )  >  5 : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                x  + =  i 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                i  + =  1 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            return  x 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    ` ` ` 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    It  can  avoid  wrapping  all  ops  after  ` break `  statement  into  ` cond_op `  that  
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    usually  brings  very  heavy  overhead . 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    """ 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    def  __init__ ( self ,  wrapper_root ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        super ( BreakTransformOptimizer ,  self ) . __init__ ( ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        self . wrapper_root  =  wrapper_root 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        self . root  =  wrapper_root . node 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    def  transform ( self ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        self . visit ( self . root ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    def  visit_Break ( self ,  node ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        loop_node_index  =  _find_ancestor_loop_index ( node ,  self . ancestor_nodes ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        assert  loop_node_index  !=  - 1 ,  " SyntaxError:  ' break '  outside loop " 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        loop_node  =  self . ancestor_nodes [ loop_node_index ] 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        if  self . _is_break_cond_pattern ( node ,  loop_node ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            cond_var_node  =  self . _join_with_while_cond ( node ,  loop_node ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            if  isinstance ( loop_node ,  gast . While ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                loop_node . test  =  gast . BoolOp ( 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                    op = gast . And ( ) ,  values = [ loop_node . test ,  cond_var_node ] ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            elif  isinstance ( loop_node ,  gast . For ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                parent_node  =  self . ancestor_nodes [ loop_node_index  -  1 ] 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                for_to_while  =  ForToWhileTransformer ( parent_node ,  loop_node , 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                                                     cond_var_node ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                for_to_while . transform ( ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    def  _is_break_cond_pattern ( self ,  break_node ,  loop_node ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        """ 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        Judge  whether  if  match  the  pattern  to  join  ` If . test `  with  ` while . test ` 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        """ 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        # while/for -> if -> break 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        if  len ( self . ancestor_nodes )  <  3  or  self . ancestor_nodes [ - 3 ]  !=  loop_node : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            return  False 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        assert  self . ancestor_nodes [ - 1 ]  ==  break_node 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        parent_if_node  =  self . ancestor_nodes [ - 2 ] 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        is_matched  =  False 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        if  isinstance ( parent_if_node ,  gast . If ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            # gast.If only contains `break` 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            break_first_in_if  =  parent_if_node . body [ 0 ]  ==  break_node  and  len ( 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                parent_if_node . orelse )  ==  0 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            # gast.If is first node of loop_node 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            if_first_in_loop  =  loop_node . body [ 0 ]  ==  parent_if_node 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            is_matched  =  if_first_in_loop  and  break_first_in_if 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        return  is_matched 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    def  _join_with_while_cond ( self ,  break_node ,  loop_node ) : 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        """ 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        Join  the  ` If . test `  with  ` While . test `  together . 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        """ 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        parent_if_node  =  self . ancestor_nodes [ - 2 ] 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        cond_var_node  =  gast . UnaryOp ( op = gast . Not ( ) ,  operand = parent_if_node . test ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        # remove the gast.If node that contains the gast.Break. 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        assert  loop_node . body [ 0 ]  ==  parent_if_node 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        loop_node . body . pop ( 0 ) 
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        return  cond_var_node