|
|
|
@ -23,10 +23,12 @@ from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrappe
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.utils import create_api_shape_node
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.utils import get_constant_variable_node
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.utils import is_control_flow_to_transform
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeParser
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.utils import RenameTransformer
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_static_variable_gast_node
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable_gast_node
|
|
|
|
@ -321,7 +323,7 @@ class LoopTransformer(gast.NodeTransformer):
|
|
|
|
|
def __init__(self, wrapper_root):
|
|
|
|
|
assert isinstance(
|
|
|
|
|
wrapper_root, AstNodeWrapper
|
|
|
|
|
), "Input non-AstNodeWrapper node for the initialization of WhileTransformer."
|
|
|
|
|
), "Input non-AstNodeWrapper node for the initialization of LoopTransformer."
|
|
|
|
|
self.wrapper_root = wrapper_root
|
|
|
|
|
self.root = wrapper_root.node
|
|
|
|
|
self.name_visitor = NameVisitor(self.root)
|
|
|
|
@ -355,86 +357,45 @@ class LoopTransformer(gast.NodeTransformer):
|
|
|
|
|
else:
|
|
|
|
|
i += 1
|
|
|
|
|
|
|
|
|
|
def get_for_range_node(self, node):
|
|
|
|
|
if not isinstance(node.iter, gast.Call):
|
|
|
|
|
return None
|
|
|
|
|
if not isinstance(node.iter.func, gast.Name):
|
|
|
|
|
return None
|
|
|
|
|
if node.iter.func.id != "range":
|
|
|
|
|
return None
|
|
|
|
|
return node.iter
|
|
|
|
|
|
|
|
|
|
def get_for_args_stmts(self, iter_name, args_list):
|
|
|
|
|
'''
|
|
|
|
|
Returns 3 gast stmt nodes for argument.
|
|
|
|
|
1. Initailize of iterate variable
|
|
|
|
|
2. Condition for the loop
|
|
|
|
|
3. Statement for changing of iterate variable during the loop
|
|
|
|
|
NOTE(TODO): Python allows to access iteration variable after loop, such
|
|
|
|
|
as "for i in range(10)" will create i = 9 after the loop. But using
|
|
|
|
|
current conversion will make i = 10. We should find a way to change it
|
|
|
|
|
'''
|
|
|
|
|
len_range_args = len(args_list)
|
|
|
|
|
assert len_range_args >= 1 and len_range_args <= 3, "range() function takes 1 to 3 arguments"
|
|
|
|
|
if len_range_args == 1:
|
|
|
|
|
init_stmt = get_constant_variable_node(iter_name, 0)
|
|
|
|
|
else:
|
|
|
|
|
init_stmt = gast.Assign(
|
|
|
|
|
targets=[
|
|
|
|
|
gast.Name(
|
|
|
|
|
id=iter_name,
|
|
|
|
|
ctx=gast.Store(),
|
|
|
|
|
annotation=None,
|
|
|
|
|
type_comment=None)
|
|
|
|
|
],
|
|
|
|
|
value=args_list[0])
|
|
|
|
|
|
|
|
|
|
range_max_node = args_list[0] if len_range_args == 1 else args_list[1]
|
|
|
|
|
step_node = args_list[2] if len_range_args == 3 else gast.Constant(
|
|
|
|
|
value=1, kind=None)
|
|
|
|
|
|
|
|
|
|
cond_stmt = gast.Compare(
|
|
|
|
|
left=gast.BinOp(
|
|
|
|
|
left=gast.Name(
|
|
|
|
|
id=iter_name,
|
|
|
|
|
ctx=gast.Load(),
|
|
|
|
|
annotation=None,
|
|
|
|
|
type_comment=None),
|
|
|
|
|
op=gast.Add(),
|
|
|
|
|
right=step_node),
|
|
|
|
|
ops=[gast.LtE()],
|
|
|
|
|
comparators=[range_max_node])
|
|
|
|
|
|
|
|
|
|
change_stmt = gast.AugAssign(
|
|
|
|
|
target=gast.Name(
|
|
|
|
|
id=iter_name,
|
|
|
|
|
ctx=gast.Store(),
|
|
|
|
|
annotation=None,
|
|
|
|
|
type_comment=None),
|
|
|
|
|
op=gast.Add(),
|
|
|
|
|
value=step_node)
|
|
|
|
|
|
|
|
|
|
return init_stmt, cond_stmt, change_stmt
|
|
|
|
|
|
|
|
|
|
def get_for_stmt_nodes(self, node):
|
|
|
|
|
# TODO: consider for - else in python
|
|
|
|
|
if not self.name_visitor.is_control_flow_loop(node):
|
|
|
|
|
return [node]
|
|
|
|
|
|
|
|
|
|
# TODO: support non-range case
|
|
|
|
|
range_call_node = self.get_for_range_node(node)
|
|
|
|
|
if range_call_node is None:
|
|
|
|
|
# 1. check whether need to transform
|
|
|
|
|
# NOTE: Current need transform cases:
|
|
|
|
|
# 1). for x in range(VarBase.numpy()[0])
|
|
|
|
|
# 2). for x in VarBase.numpy()
|
|
|
|
|
# 3). for i, x in enumerate(VarBase.numpy())
|
|
|
|
|
if not self.name_visitor.is_control_flow_loop(node):
|
|
|
|
|
return [node]
|
|
|
|
|
|
|
|
|
|
if not isinstance(node.target, gast.Name):
|
|
|
|
|
# 2. get key statements for different cases
|
|
|
|
|
# NOTE: three key statements:
|
|
|
|
|
# 1). init_stmts: list[node], prepare nodes of for loop, may not only one
|
|
|
|
|
# 2). cond_stmt: node, condition node to judge whether continue loop
|
|
|
|
|
# 3). body_stmts: list[node], updated loop body, sometimes we should change
|
|
|
|
|
# the original statement in body, not just append new statement
|
|
|
|
|
current_for_node_parser = ForNodeParser(node)
|
|
|
|
|
stmts_tuple = current_for_node_parser.parse()
|
|
|
|
|
if stmts_tuple is None:
|
|
|
|
|
return [node]
|
|
|
|
|
iter_var_name = node.target.id
|
|
|
|
|
|
|
|
|
|
init_stmt, cond_stmt, change_stmt = self.get_for_args_stmts(
|
|
|
|
|
iter_var_name, range_call_node.args)
|
|
|
|
|
init_stmts, cond_stmt, body_stmts = stmts_tuple
|
|
|
|
|
|
|
|
|
|
# 3. get original loop vars
|
|
|
|
|
loop_var_names, create_var_names = self.name_visitor.get_loop_var_names(
|
|
|
|
|
node)
|
|
|
|
|
# NOTE: in 'for x in var' or 'for i, x in enumerate(var)' cases,
|
|
|
|
|
# we need append new loop var & remove useless loop var
|
|
|
|
|
# 1. for x in var -> x is no need
|
|
|
|
|
# 2. for i, x in enumerate(var) -> x is no need
|
|
|
|
|
if current_for_node_parser.is_for_iter(
|
|
|
|
|
) or current_for_node_parser.is_for_enumerate_iter():
|
|
|
|
|
iter_var_name = current_for_node_parser.iter_var_name
|
|
|
|
|
iter_idx_name = current_for_node_parser.iter_idx_name
|
|
|
|
|
loop_var_names.add(iter_idx_name)
|
|
|
|
|
if iter_var_name not in create_var_names:
|
|
|
|
|
loop_var_names.remove(iter_var_name)
|
|
|
|
|
|
|
|
|
|
# 4. prepare result statement list
|
|
|
|
|
new_stmts = []
|
|
|
|
|
# Python can create variable in loop and use it out of loop, E.g.
|
|
|
|
|
#
|
|
|
|
@ -447,12 +408,13 @@ class LoopTransformer(gast.NodeTransformer):
|
|
|
|
|
if "." not in name:
|
|
|
|
|
new_stmts.append(create_static_variable_gast_node(name))
|
|
|
|
|
|
|
|
|
|
new_stmts.append(init_stmt)
|
|
|
|
|
|
|
|
|
|
# 5. append init statements
|
|
|
|
|
new_stmts.extend(init_stmts)
|
|
|
|
|
# for x in range(10) in dygraph should be convert into static tensor + 1 <= 10
|
|
|
|
|
for name in loop_var_names:
|
|
|
|
|
new_stmts.append(to_static_variable_gast_node(name))
|
|
|
|
|
|
|
|
|
|
# 6. create & append condition function node
|
|
|
|
|
condition_func_node = gast.FunctionDef(
|
|
|
|
|
name=unique_name.generate(FOR_CONDITION_PREFIX),
|
|
|
|
|
args=gast.arguments(
|
|
|
|
@ -480,9 +442,9 @@ class LoopTransformer(gast.NodeTransformer):
|
|
|
|
|
name, unique_name.generate(GENERATE_VARIABLE_PREFIX))
|
|
|
|
|
new_stmts.append(condition_func_node)
|
|
|
|
|
|
|
|
|
|
new_body = node.body
|
|
|
|
|
new_body.append(change_stmt)
|
|
|
|
|
new_body.append(
|
|
|
|
|
# 7. create & append loop body function node
|
|
|
|
|
# append return values for loop body
|
|
|
|
|
body_stmts.append(
|
|
|
|
|
gast.Return(value=generate_name_node(
|
|
|
|
|
loop_var_names, ctx=gast.Load())))
|
|
|
|
|
body_func_node = gast.FunctionDef(
|
|
|
|
@ -501,7 +463,7 @@ class LoopTransformer(gast.NodeTransformer):
|
|
|
|
|
kw_defaults=None,
|
|
|
|
|
kwarg=None,
|
|
|
|
|
defaults=[]),
|
|
|
|
|
body=new_body,
|
|
|
|
|
body=body_stmts,
|
|
|
|
|
decorator_list=[],
|
|
|
|
|
returns=None,
|
|
|
|
|
type_comment=None)
|
|
|
|
@ -512,6 +474,7 @@ class LoopTransformer(gast.NodeTransformer):
|
|
|
|
|
name, unique_name.generate(GENERATE_VARIABLE_PREFIX))
|
|
|
|
|
new_stmts.append(body_func_node)
|
|
|
|
|
|
|
|
|
|
# 8. create & append while loop node
|
|
|
|
|
while_loop_node = create_while_node(condition_func_node.name,
|
|
|
|
|
body_func_node.name, loop_var_names)
|
|
|
|
|
new_stmts.append(while_loop_node)
|
|
|
|
|