@ -26,6 +26,7 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle . fluid . dygraph . dygraph_to_static . utils import generate_name_node
from paddle . fluid . dygraph . dygraph_to_static . utils import get_constant_variable_node
from paddle . fluid . dygraph . dygraph_to_static . utils import get_attribute_full_name
from paddle . fluid . dygraph . dygraph_to_static . utils import RenameTransformer
from paddle . fluid . dygraph . dygraph_to_static . variable_trans_func import create_static_variable_gast_node
from paddle . fluid . dygraph . dygraph_to_static . variable_trans_func import to_static_variable_gast_node
@ -36,6 +37,7 @@ WHILE_BODY_PREFIX = 'while_body'
FOR_CONDITION_PREFIX = ' for_loop_condition '
FOR_BODY_PREFIX = ' for_loop_body '
GENERATE_VARIABLE_PREFIX = ' generate_variable '
def create_while_node ( condition_name , body_name , loop_var_names ) :
@ -440,7 +442,8 @@ class LoopTransformer(gast.NodeTransformer):
#
# We need to create static variable for those variables
for name in create_var_names :
new_stmts . append ( create_static_variable_gast_node ( name ) )
if " . " not in name :
new_stmts . append ( create_static_variable_gast_node ( name ) )
new_stmts . append ( init_stmt )
@ -468,6 +471,11 @@ class LoopTransformer(gast.NodeTransformer):
decorator_list = [ ] ,
returns = None ,
type_comment = None )
for name in loop_var_names :
if " . " in name :
rename_transformer = RenameTransformer ( condition_func_node )
rename_transformer . rename (
name , unique_name . generate ( GENERATE_VARIABLE_PREFIX ) )
new_stmts . append ( condition_func_node )
new_body = node . body
@ -495,6 +503,11 @@ class LoopTransformer(gast.NodeTransformer):
decorator_list = [ ] ,
returns = None ,
type_comment = None )
for name in loop_var_names :
if " . " in name :
rename_transformer = RenameTransformer ( body_func_node )
rename_transformer . rename (
name , unique_name . generate ( GENERATE_VARIABLE_PREFIX ) )
new_stmts . append ( body_func_node )
while_loop_node = create_while_node ( condition_func_node . name ,
@ -521,7 +534,8 @@ class LoopTransformer(gast.NodeTransformer):
#
# We need to create static variable for those variables
for name in create_var_names :
new_stmts . append ( create_static_variable_gast_node ( name ) )
if " . " not in name :
new_stmts . append ( create_static_variable_gast_node ( name ) )
# while x < 10 in dygraph should be convert into static tensor < 10
for name in loop_var_names :
@ -550,6 +564,11 @@ class LoopTransformer(gast.NodeTransformer):
decorator_list = [ ] ,
returns = None ,
type_comment = None )
for name in loop_var_names :
if " . " in name :
rename_transformer = RenameTransformer ( condition_func_node )
rename_transformer . rename (
name , unique_name . generate ( GENERATE_VARIABLE_PREFIX ) )
new_stmts . append ( condition_func_node )
new_body = node . body
@ -576,6 +595,11 @@ class LoopTransformer(gast.NodeTransformer):
decorator_list = [ ] ,
returns = None ,
type_comment = None )
for name in loop_var_names :
if " . " in name :
rename_transformer = RenameTransformer ( body_func_node )
rename_transformer . rename (
name , unique_name . generate ( GENERATE_VARIABLE_PREFIX ) )
new_stmts . append ( body_func_node )
while_loop_node = create_while_node ( condition_func_node . name ,