@ -19,6 +19,7 @@ import gast
from paddle . fluid import unique_name
from paddle . fluid . dygraph . dygraph_to_static . utils import index_in_list
from paddle . fluid . dygraph . dygraph_to_static . utils import ForNodeVisitor
from paddle . fluid . dygraph . dygraph_to_static . utils import BaseNodeVisitor
from paddle . fluid . dygraph . dygraph_to_static . variable_trans_func import create_fill_constant_node
__all__ = [ ' BreakContinueTransformer ' ]
@ -83,7 +84,7 @@ class ForToWhileTransformer(gast.NodeTransformer):
return init_stmts
class BreakContinueTransformer ( gast. NodeTransforme r) :
class BreakContinueTransformer ( BaseNodeVisito r) :
"""
Rewrite ' break ' and ' continue ' key words in a if - else python way to make
it equivalent to original control flow
@ -103,41 +104,23 @@ class BreakContinueTransformer(gast.NodeTransformer):
set continue to False at the beginning of each loop
TODO : more details should be summarized as design document
Note : The class is inherited from BaseNodeVisitor instead of NodeTransformer ,
because ancestor nodes will be modified inplace for ` Break / Continue ` here .
In general , we recommend to inheriting NodeTransformer to modify node !
"""
def __init__ ( self , wrapper_root ) :
super ( BreakContinueTransformer , self ) . __init__ ( )
self . wrapper_root = wrapper_root
self . root = wrapper_root . node
self . ancestor_nodes = [ ]
def transform ( self ) :
self . visit ( self . root )
def generic_visit ( self , node ) :
# TODO: because we change ancestor nodes during visit_Break/Continue,
# not current node, so generic_visit of NodeTransformer will visit node
# which may be deleted. To prevent that node being added into
# transformed AST, I have to self-write a generic_visit, but this is
# NOT a good thing. Considering refactorying this whole class.
for field , value in gast . iter_fields ( node ) :
if isinstance ( value , list ) :
for item in value :
if isinstance ( item , gast . AST ) :
self . visit ( item )
elif isinstance ( value , gast . AST ) :
self . visit ( value )
def visit ( self , node ) :
self . ancestor_nodes . append ( node )
method = ' visit_ ' + node . __class__ . __name__
visitor = getattr ( self , method , self . generic_visit )
ret = visitor ( node )
self . ancestor_nodes . pop ( )
return ret
def visit_Break ( self , node ) :
loop_node_index = self . _find_ancestor_loop_index ( node )
loop_node_index = _find_ancestor_loop_index ( node , self . ancestor_nodes )
assert loop_node_index != - 1 , " SyntaxError: ' break ' outside loop "
loop_node = self . ancestor_nodes [ loop_node_index ]
@ -150,7 +133,7 @@ class BreakContinueTransformer(gast.NodeTransformer):
first_block_index = self . _remove_stmts_after_break_continue (
node , variable_name , loop_node_index )
# 3. Add 'if V' for stmts in ancestor blocks between the first one
# 3. Add 'if not V' for stmts in ancestor blocks between the first one
# (exclusive) and the ancestor loop (inclusive)
self . _replace_if_stmt ( loop_node_index , first_block_index , variable_name )
@ -165,6 +148,7 @@ class BreakContinueTransformer(gast.NodeTransformer):
ctx = gast . Load ( ) ,
annotation = None ,
type_comment = None ) )
if isinstance ( loop_node , gast . While ) :
loop_node . test = gast . BoolOp (
op = gast . And ( ) , values = [ loop_node . test , cond_var_node ] )
@ -175,7 +159,7 @@ class BreakContinueTransformer(gast.NodeTransformer):
for_to_while . transform ( )
def visit_Continue ( self , node ) :
loop_node_index = self . _find_ancestor_loop_index ( node )
loop_node_index = _find_ancestor_loop_index ( node , self . ancestor_nodes )
assert loop_node_index != - 1 , " SyntaxError: ' continue ' outside loop "
loop_node = self . ancestor_nodes [ loop_node_index ]
@ -188,7 +172,7 @@ class BreakContinueTransformer(gast.NodeTransformer):
first_block_index = self . _remove_stmts_after_break_continue (
node , variable_name , loop_node_index )
# 3. Add 'if V' for stmts in ancestor blocks between the first one
# 3. Add 'if not V' for stmts in ancestor blocks between the first one
# (exclusive) and the ancestor loop (inclusive)
self . _replace_if_stmt ( loop_node_index , first_block_index , variable_name )
@ -215,15 +199,6 @@ class BreakContinueTransformer(gast.NodeTransformer):
return first_block_index
def _replace_break_continue_in_stmt_list (
self , stmt_list , break_continue_node , break_continue_name ) :
i = index_in_list ( stmt_list , break_continue_node )
if i == - 1 :
return False
assign_true_node = create_fill_constant_node ( break_continue_name , True )
stmt_list [ i : ] = [ assign_true_node ]
return True
def _replace_if_stmt ( self , loop_node_index , first_block_index ,
break_continue_name ) :
for i in range ( first_block_index - 1 , loop_node_index - 1 , - 1 ) :
@ -239,6 +214,15 @@ class BreakContinueTransformer(gast.NodeTransformer):
cur_node . orelse , son_node , break_continue_name ) :
continue
def _replace_break_continue_in_stmt_list (
self , stmt_list , break_continue_node , break_continue_name ) :
i = index_in_list ( stmt_list , break_continue_node )
if i == - 1 :
return False
assign_true_node = create_fill_constant_node ( break_continue_name , True )
stmt_list [ i : ] = [ assign_true_node ]
return True
def _replace_after_node_to_if_in_stmt_list ( self , stmt_list , node ,
break_continue_name ) :
i = index_in_list ( stmt_list , node )
@ -282,8 +266,110 @@ class BreakContinueTransformer(gast.NodeTransformer):
stmt_list . insert ( i , stmt_node )
return True
def _find_ancestor_loop_index ( self , node ) :
for i in range ( len ( self . ancestor_nodes ) - 1 , - 1 , - 1 ) :
if isinstance ( self . ancestor_nodes [ i ] , ( gast . For , gast . While ) ) :
return i
return - 1
def _find_ancestor_loop_index ( node , ancestor_nodes ) :
for i in range ( len ( ancestor_nodes ) - 1 , - 1 , - 1 ) :
if isinstance ( ancestor_nodes [ i ] , ( gast . For , gast . While ) ) :
return i
return - 1
class BreakTransformOptimizer ( BaseNodeVisitor ) :
"""
In specific pattern , the transformed code could be optimized by joining the
If . test with while . test .
Currently supported pattern is :
` ` `
while cond1 : while cond1 and not cond2 :
if cond2 : - - - > do_something ( )
break
do_something ( )
` ` `
See following example :
>> > def foo ( x ) :
. . . i = paddle . to_tensor ( 1 , dtype = ' int32 ' )
. . . while i < 10 :
. . . if x . mean ( ) > 5 :
. . . break
. . . x + = i
. . . i + = 1
. . . return x
The generated code after applying optimization will be :
` ` `
def foo ( x ) :
i = paddle . to_tensor ( 1 , dtype = ' int32 ' )
while i < 10 and not x . mean ( ) > 5 :
x + = i
i + = 1
return x
` ` `
It can avoid wrapping all ops after ` break ` statement into ` cond_op ` that
usually brings very heavy overhead .
"""
def __init__ ( self , wrapper_root ) :
super ( BreakTransformOptimizer , self ) . __init__ ( )
self . wrapper_root = wrapper_root
self . root = wrapper_root . node
def transform ( self ) :
self . visit ( self . root )
def visit_Break ( self , node ) :
loop_node_index = _find_ancestor_loop_index ( node , self . ancestor_nodes )
assert loop_node_index != - 1 , " SyntaxError: ' break ' outside loop "
loop_node = self . ancestor_nodes [ loop_node_index ]
if self . _is_break_cond_pattern ( node , loop_node ) :
cond_var_node = self . _join_with_while_cond ( node , loop_node )
if isinstance ( loop_node , gast . While ) :
loop_node . test = gast . BoolOp (
op = gast . And ( ) , values = [ loop_node . test , cond_var_node ] )
elif isinstance ( loop_node , gast . For ) :
parent_node = self . ancestor_nodes [ loop_node_index - 1 ]
for_to_while = ForToWhileTransformer ( parent_node , loop_node ,
cond_var_node )
for_to_while . transform ( )
def _is_break_cond_pattern ( self , break_node , loop_node ) :
"""
Judge whether if match the pattern to join ` If . test ` with ` while . test `
"""
# while/for -> if -> break
if len ( self . ancestor_nodes ) < 3 or self . ancestor_nodes [ - 3 ] != loop_node :
return False
assert self . ancestor_nodes [ - 1 ] == break_node
parent_if_node = self . ancestor_nodes [ - 2 ]
is_matched = False
if isinstance ( parent_if_node , gast . If ) :
# gast.If only contains `break`
break_first_in_if = parent_if_node . body [ 0 ] == break_node and len (
parent_if_node . orelse ) == 0
# gast.If is first node of loop_node
if_first_in_loop = loop_node . body [ 0 ] == parent_if_node
is_matched = if_first_in_loop and break_first_in_if
return is_matched
def _join_with_while_cond ( self , break_node , loop_node ) :
"""
Join the ` If . test ` with ` While . test ` together .
"""
parent_if_node = self . ancestor_nodes [ - 2 ]
cond_var_node = gast . UnaryOp ( op = gast . Not ( ) , operand = parent_if_node . test )
# remove the gast.If node that contains the gast.Break.
assert loop_node . body [ 0 ] == parent_if_node
loop_node . body . pop ( 0 )
return cond_var_node