@ -22,14 +22,11 @@ from paddle.fluid import unique_name
from paddle . fluid . dygraph . dygraph_to_static . static_analysis import AstNodeWrapper
from paddle . fluid . dygraph . dygraph_to_static . static_analysis import NodeVarType
from paddle . fluid . dygraph . dygraph_to_static . static_analysis import StaticAnalysisVisitor
from paddle . fluid . dygraph . dygraph_to_static . utils import ast_to_source_code
from paddle . fluid . dygraph . dygraph_to_static . utils import generate_name_node
from paddle . fluid . dygraph . dygraph_to_static . utils import get_attribute_full_name
from paddle . fluid . dygraph . dygraph_to_static . utils import is_control_flow_to_transform
from paddle . fluid . dygraph . dygraph_to_static . utils import ForNodeVisitor
from paddle . fluid . dygraph . dygraph_to_static . utils import RenameTransformer
from paddle . fluid . dygraph . dygraph_to_static . variable_trans_func import create_static_variable_gast_node
from paddle . fluid . dygraph . dygraph_to_static . variable_trans_func import to_static_variable_gast_node
__all__ = [ ' LoopTransformer ' , ' NameVisitor ' ]
@ -89,7 +86,8 @@ class NameVisitor(gast.NodeVisitor):
# Mapping from gast.While/gast.For to variable nodes
self . before_loop_body_vars = defaultdict ( set )
self . in_loop_vars = defaultdict ( set )
# NOTE: Use ordered list as dict value
self . in_loop_vars = defaultdict ( list )
# Mapping from gast.While/gast.For to variable nodes which is condition
# of loop or being modified during the loop
@ -103,11 +101,6 @@ class NameVisitor(gast.NodeVisitor):
self . visit ( root_node )
def is_control_flow_loop ( self , node ) :
need_transform = is_control_flow_to_transform (
node , self . static_analysis_visitor )
return need_transform
def get_loop_var_names ( self , node ) :
assert isinstance (
node , ( gast . While , gast . For ) ) , " Input node is not gast loop node "
@ -115,7 +108,15 @@ class NameVisitor(gast.NodeVisitor):
create_var_names = set ( )
read_context = { type ( gast . Load ( ) ) , type ( gast . AugLoad ( ) ) }
in_loop_vars = self . in_loop_vars [ node ]
in_loop_vars_list = self . in_loop_vars [ node ]
# get dict `var_name_to_ctxs`
var_name_to_ctxs = defaultdict ( list )
for var_node in in_loop_vars_list :
var_name_to_ctxs [ self . _var_node_to_name ( var_node ) ] . append (
var_node . ctx )
in_loop_vars = set ( in_loop_vars_list )
in_loop_name_strs = self . _var_nodes_to_names ( in_loop_vars )
before_loop_body_vars = self . before_loop_body_vars [ node ]
@ -160,6 +161,22 @@ class NameVisitor(gast.NodeVisitor):
# vars out
loop_var_names . add ( name )
create_var_names . add ( name )
else :
# If a variable is used and created in loop, but used before created,
# it should be in loop_var and we should create it.
# For example, `var_a` should be in loop_var and we should create it.
#
# res = 0
# for i, x in enumerate(x_array):
# if i > 2:
# x = func1(var_a)
# var_a = func2(x)
#
if isinstance ( var_name_to_ctxs [ name ] [ 0 ] , gast . Load ) :
loop_var_names . add ( name )
create_var_names . add ( name )
return loop_var_names , create_var_names
@ -176,7 +193,7 @@ class NameVisitor(gast.NodeVisitor):
type ( gast . Store ( ) ) , type ( gast . AugStore ( ) ) , type ( gast . Del ( ) )
}
for loop_node in self . current_loop :
self . in_loop_vars [ loop_node ] . a d d( node )
self . in_loop_vars [ loop_node ] . a ppen d( node )
if type ( node . ctx ) in write_context :
self . write_in_loop [ loop_node ] . add ( node )
if self . in_condition :
@ -219,7 +236,7 @@ class NameVisitor(gast.NodeVisitor):
self . current_seen_vars . add ( node )
for loop_node in self . current_loop :
self . in_loop_vars [ loop_node ] . a d d( node )
self . in_loop_vars [ loop_node ] . a ppen d( node )
# sub-nodes are visited during get_attribute_full_name and we shouldn't
# visit again
@ -367,27 +384,25 @@ class LoopTransformer(gast.NodeTransformer):
def get_for_stmt_nodes ( self , node ) :
# TODO: consider for - else in python
# 1. check whether need to transform
# NOTE: Current need transform cases:
# 1). for x in range(VarBase[0]|VarBase.numpy()[0])
# 2). for x in VarBase|VarBase.numpy()
# 3). for i, x in enumerate(VarBase|VarBase.numpy())
if not self . name_visitor . is_control_flow_loop ( node ) :
return [ node ]
# 2. get key statements for different cases
# NOTE: three key statements:
# 1. get key statements for different cases
# NOTE 1: three key statements:
# 1). init_stmts: list[node], prepare nodes of for loop, may not only one
# 2). cond_stmt: node, condition node to judge whether continue loop
# 3). body_stmts: list[node], updated loop body, sometimes we should change
# the original statement in body, not just append new statement
#
# NOTE 2: The following `for` statements will be transformed to `while` statements:
# 1). for x in range(*)
# 2). for x in iter_var
# 3). for i, x in enumerate(*)
current_for_node_parser = ForNodeVisitor ( node )
stmts_tuple = current_for_node_parser . parse ( )
if stmts_tuple is None :
return [ node ]
init_stmts , cond_stmt , body_stmts = stmts_tuple
# 3 . get original loop vars
# 2 . get original loop vars
loop_var_names , create_var_names = self . name_visitor . get_loop_var_names (
node )
# NOTE: in 'for x in var' or 'for i, x in enumerate(var)' cases,
@ -402,7 +417,7 @@ class LoopTransformer(gast.NodeTransformer):
if iter_var_name not in create_var_names :
loop_var_names . remove ( iter_var_name )
# 4 . prepare result statement list
# 3 . prepare result statement list
new_stmts = [ ]
# Python can create variable in loop and use it out of loop, E.g.
#
@ -415,13 +430,10 @@ class LoopTransformer(gast.NodeTransformer):
if " . " not in name :
new_stmts . append ( create_static_variable_gast_node ( name ) )
# 5 . append init statements
# 4 . append init statements
new_stmts . extend ( init_stmts )
# for x in range(10) in dygraph should be convert into static tensor + 1 <= 10
for name in loop_var_names :
new_stmts . append ( to_static_variable_gast_node ( name ) )
# 6 . create & append condition function node
# 5. create & append condition function node
condition_func_node = gast . FunctionDef (
name = unique_name . generate ( FOR_CONDITION_PREFIX ) ,
args = gast . arguments (
@ -449,7 +461,7 @@ class LoopTransformer(gast.NodeTransformer):
name , unique_name . generate ( GENERATE_VARIABLE_PREFIX ) )
new_stmts . append ( condition_func_node )
# 7 . create & append loop body function node
# 6 . create & append loop body function node
# append return values for loop body
body_stmts . append (
gast . Return ( value = generate_name_node (
@ -481,7 +493,7 @@ class LoopTransformer(gast.NodeTransformer):
name , unique_name . generate ( GENERATE_VARIABLE_PREFIX ) )
new_stmts . append ( body_func_node )
# 8. create & append while loop node
# 7. create & append while loop node
while_loop_node = create_while_node ( condition_func_node . name ,
body_func_node . name , loop_var_names )
new_stmts . append ( while_loop_node )