@ -20,6 +20,7 @@ import gast
from collections import defaultdict
from paddle . fluid import unique_name
from paddle . fluid . dygraph . dygraph_to_static . static_analysis import AstNodeWrapper
from paddle . fluid . dygraph . dygraph_to_static . static_analysis import NodeVarType
from paddle . fluid . dygraph . dygraph_to_static . static_analysis import StaticAnalysisVisitor
from paddle . fluid . dygraph . dygraph_to_static . utils import ast_to_source_code
from paddle . fluid . dygraph . dygraph_to_static . utils import generate_name_node
@ -134,6 +135,12 @@ class NameVisitor(gast.NodeVisitor):
self . before_loop_body_vars = defaultdict ( set )
self . in_loop_vars = defaultdict ( set )
# Mapping from gast.While/gast.For to variable nodes which is condition
# of loop or being modified during the loop
self . write_in_loop = defaultdict ( set )
self . condition_vars = defaultdict ( set )
self . in_condition = False
self . static_analysis_visitor = StaticAnalysisVisitor ( root_node )
self . node_to_wrapper_map = self . static_analysis_visitor . get_node_to_wrapper_map (
)
@ -158,14 +165,36 @@ class NameVisitor(gast.NodeVisitor):
after_loop_vars = self . current_seen_vars - before_loop_body_vars - in_loop_vars
after_loop_name_strs = self . _var_nodes_to_names ( after_loop_vars ,
read_context )
condition_vars = self . condition_vars [ node ]
condition_names = self . _var_nodes_to_names ( condition_vars )
write_vars = self . write_in_loop [ node ]
write_names = self . _var_nodes_to_names ( write_vars )
name_to_type = { }
for var in in_loop_vars :
wrapper = self . node_to_wrapper_map [ var ]
name_to_type [ self . _var_node_to_name ( var ) ] = wrapper . node_var_type
for name in in_loop_name_strs :
if name in before_loop_name_strs :
# If a variable is used in loop and created before loop, it
# should be in loop_var as input
# If a variable is used in loop and created before loop
# If this var is a basic variable and read-only and not
# condition var, it may not be loop_var else it should
# be in loop_var as input
if ( not name in condition_names ) and (
not name in write_names
) and self . _node_var_type_is_basic ( name_to_type [ name ] ) :
continue
loop_var_names . add ( name )
elif name in after_loop_name_strs :
# If a variable is created in the while loop and read after
# loop, it should be in loop_var and we should create it
# because name in after_loop_name must be initialized in loop
# So it is write-only, we don't have to filter read-only basic
# vars out
loop_var_names . add ( name )
create_var_names . add ( name )
return loop_var_names , create_var_names
@ -179,8 +208,15 @@ class NameVisitor(gast.NodeVisitor):
return
self . current_seen_vars . add ( node )
write_context = {
type ( gast . Store ( ) ) , type ( gast . AugStore ( ) ) , type ( gast . Del ( ) )
}
for loop_node in self . current_loop :
self . in_loop_vars [ loop_node ] . add ( node )
if type ( node . ctx ) in write_context :
self . write_in_loop [ loop_node ] . add ( node )
if self . in_condition :
self . condition_vars [ loop_node ] . add ( node )
self . generic_visit ( node )
def visit_FunctionDef ( self , node ) :
@ -217,21 +253,28 @@ class NameVisitor(gast.NodeVisitor):
if attr_full_name . startswith ( " self. " ) :
return
self . current_seen_vars . add ( node )
for loop_node in self . current_loop :
self . in_loop_vars [ loop_node ] . add ( node )
# sub-nodes are visited during get_attribute_full_name and we shouldn't
# visit again
def visit_For ( self , node ) :
self . current_loop . append ( node )
self . in_condition = True
self . visit ( node . target )
self . visit ( node . iter )
self . in_condition = False
self . before_loop_body_vars [ node ] = copy . copy ( self . current_seen_vars )
self . generic_visit ( node )
self . current_loop . pop ( )
def visit_While ( self , node ) :
self . current_loop . append ( node )
self . in_condition = True
self . visit ( node . test )
self . in_condition = False
self . before_loop_body_vars [ node ] = copy . copy ( self . current_seen_vars )
self . generic_visit ( node )
self . current_loop . pop ( )
@ -240,12 +283,25 @@ class NameVisitor(gast.NodeVisitor):
ret = set ( )
for node in node_set :
if ctx_filter_set is None or type ( node . ctx ) in ctx_filter_set :
if isinstance ( node , gast . Name ) :
ret . add ( node . id )
elif isinstance ( node , gast . Attribute ) :
ret . add ( get_attribute_full_name ( node ) )
ret . add ( self . _var_node_to_name ( node ) )
return ret
def _var_node_to_name ( self , node ) :
if isinstance ( node , gast . Name ) :
return node . id
elif isinstance ( node , gast . Attribute ) :
return get_attribute_full_name ( node )
def _node_var_type_is_basic ( self , node_var_type ) :
basic_types = {
NodeVarType . BOOLEAN , NodeVarType . INT , NodeVarType . FLOAT ,
NodeVarType . STRING
}
for t in node_var_type :
if t in basic_types :
return True
return False
def _is_call_func_name_node ( self , node ) :
parent_node = self . node_to_wrapper_map [ node ] . parent . node
if isinstance ( parent_node , gast . Call ) and parent_node . func == node :