@ -39,8 +39,35 @@ FOR_CONDITION_PREFIX = 'for_loop_condition'
FOR_BODY_PREFIX = ' for_loop_body '
GENERATE_VARIABLE_PREFIX = ' generate_variable '
ATTRIBUTE_VARIABLE_PREFIX = ' __attribute_variable '
def create_while_node ( condition_name , body_name , loop_var_names ) :
def create_while_nodes ( condition_name , body_name , loop_var_names ) :
"""
Returns a list of gast . Node which represents the calling of Paddle
controlflow while_loop .
Usually , the list just contain 1 statement such as :
[ a , b , c ] = paddle . jit . dy2static . convert_while_loop (
condition_name , body_name , [ a , b , c ] )
where a , b , c are in loop_var_names .
However , if loop_var_names contains attribute such as foo . x , we cannot
assign the attribute as output of convert_while_loop because Python
property is a kind of read - only attribute . To handle the case , we replace
the attributes which are output of convert_while_loop with generated
variables , then if we know the attribute is not read - only at runtime , we
assign the attribute . The created statements are like :
[ a , b , __attribute_variable_1 ] = paddle . jit . dy2static . convert_while_loop (
condition_name , body_name , [ a , b , foo . x ] )
if not isinstance ( getattr ( type ( foo ) , x , None ) , property ) : foo . x = __attribute_variable_1
The number of above statements is not only 1 , that ' s why the return type is
a list of gast . Node .
"""
# NOTE(liym27):
# It's better to parse the source code into an AST node than to customize an AST node
# including child nodes, because it is easy to mistake the ast node type when customizing the node.
@ -48,14 +75,37 @@ def create_while_node(condition_name, body_name, loop_var_names):
# For example: loop_var_names = [a, b, foo.x], the type of `a` or `b` is gast.Name,
# but the type of `foo.x` gast.Attribute.
unique_name_to_origin = { }
# We have to make loop_var_names and assign_loop_var_names with same order
# set doesn't have order so we convert it to list
loop_var_names = list ( loop_var_names )
assign_loop_var_names = [ ]
for name in ( loop_var_names ) :
if " . " in name :
# name is an attribute variable such as foo.x
tmp_attr_name = unique_name . generate ( ATTRIBUTE_VARIABLE_PREFIX )
unique_name_to_origin [ tmp_attr_name ] = name
assign_loop_var_names . append ( tmp_attr_name )
else :
assign_loop_var_names . append ( name )
while_func_name = " paddle.jit.dy2static.convert_while_loop "
while_node_str = " [ {} ] = {} ( {} , {} , [ {} ]) " . format (
" , " . join ( loop_var_names ) , while_func_name , condition_name , body_name ,
" , " . join ( loop_var_names ) )
" , " . join ( assign_loop_var_names ) , while_func_name , condition_name ,
body_name , " , " . join ( loop_var_names ) )
while_node = gast . parse ( while_node_str ) . body [ 0 ]
return while_node
ret = [ while_node ]
for tmp_attr_name in unique_name_to_origin :
origin_attr_var = unique_name_to_origin [ tmp_attr_name ]
dot_pos = origin_attr_var . rindex ( " . " )
obj_name = origin_attr_var [ 0 : dot_pos ]
attr_name = origin_attr_var [ dot_pos + 1 : ]
assign_if_not_prop_str = " if not isinstance(getattr(type( {} ), ' {} ' , None), property): {} = {} " . format (
obj_name , attr_name , origin_attr_var , tmp_attr_name )
assign_if_not_prop_node = gast . parse ( assign_if_not_prop_str ) . body [ 0 ]
ret . append ( assign_if_not_prop_node )
return ret
class NameVisitor ( gast . NodeVisitor ) :
@ -573,9 +623,9 @@ class LoopTransformer(gast.NodeTransformer):
new_stmts . append ( body_func_node )
# 7. create & append while loop node
while_loop_node = create_while_node ( condition_func_node . name ,
body_func_node . name , loop_var_names )
new_stmts . append( while_loop_node )
while_loop_node s = create_while_node s (
condition_func_node . name , body_func_node . name , loop_var_names )
new_stmts . extend( while_loop_nodes )
return new_stmts
@ -655,7 +705,7 @@ class LoopTransformer(gast.NodeTransformer):
name , unique_name . generate ( GENERATE_VARIABLE_PREFIX ) )
new_stmts . append ( body_func_node )
while_loop_node = create_while_node ( condition_func_node . name ,
body_func_node . name , loop_var_names )
new_stmts . append( while_loop_node )
while_loop_node s = create_while_node s (
condition_func_node . name , body_func_node . name , loop_var_names )
new_stmts . extend( while_loop_nodes )
return new_stmts