[Dy2static] Add for enumerate Variable support (#24398)

* initial test

* for enumerate basic implement, test=develop

* update unittests, test=develop

* refine unittests to adapt new training mode, test=develop

* refactor for node stmts parsing code, test=develop

* self-review & polish details, test=develop
v1.8
Chen Weihang 5 years ago committed by GitHub
parent d980d251f0
commit 03ba5b748d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -19,6 +19,7 @@ import gast
from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.utils import get_constant_variable_node
from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list
from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeParser
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node
__all__ = ['BreakContinueTransformer']
@ -61,87 +62,26 @@ class ForToWhileTransformer(gast.NodeTransformer):
raise ValueError(
"parent_node doesn't contain the loop_node in ForToWhileTransformer")
def get_for_range_node(self, node):
if not isinstance(node.iter, gast.Call):
return None
if not isinstance(node.iter.func, gast.Name):
return None
if node.iter.func.id != "range":
return None
return node.iter
def get_for_args_stmts(self, iter_name, args_list):
'''
Returns 3 gast stmt nodes for argument.
1. Initailize of iterate variable
2. Condition for the loop
3. Statement for changing of iterate variable during the loop
'''
len_range_args = len(args_list)
assert len_range_args >= 1 and len_range_args <= 3, "range() function takes 1 to 3 arguments"
if len_range_args == 1:
init_stmt = get_constant_variable_node(iter_name, 0)
else:
init_stmt = gast.Assign(
targets=[
gast.Name(
id=iter_name,
ctx=gast.Store(),
annotation=None,
type_comment=None)
],
value=args_list[0])
range_max_node = args_list[0] if len_range_args == 1 else args_list[1]
step_node = args_list[2] if len_range_args == 3 else gast.Constant(
value=1, kind=None)
old_cond_stmt = gast.Compare(
left=gast.BinOp(
left=gast.Name(
id=iter_name,
ctx=gast.Load(),
annotation=None,
type_comment=None),
op=gast.Add(),
right=step_node),
ops=[gast.LtE()],
comparators=[range_max_node])
cond_stmt = gast.BoolOp(
op=gast.And(), values=[old_cond_stmt, self.condition_node])
change_stmt = gast.AugAssign(
target=gast.Name(
id=iter_name,
ctx=gast.Store(),
annotation=None,
type_comment=None),
op=gast.Add(),
value=step_node)
return init_stmt, cond_stmt, change_stmt
def get_for_stmt_nodes(self, node):
assert isinstance(
node, gast.For), "Input node is NOT gast.For in get_for_stmt_nodes"
# TODO: support non-range case
range_call_node = self.get_for_range_node(node)
if range_call_node is None:
return [node]
if not isinstance(node.target, gast.Name):
# 1. parse current gast.For node
current_for_node_parser = ForNodeParser(node)
stmts_tuple = current_for_node_parser.parse()
if stmts_tuple is None:
return [node]
iter_var_name = node.target.id
init_stmts, cond_stmt, body_stmts = stmts_tuple
init_stmt, cond_stmt, change_stmt = self.get_for_args_stmts(
iter_var_name, range_call_node.args)
# 2. append break statement
new_cond_stmt = gast.BoolOp(
op=gast.And(), values=[cond_stmt, self.condition_node])
new_body = node.body
new_body.append(change_stmt)
# 3. construct gast.While node
while_node = gast.While(
test=cond_stmt, body=new_body, orelse=node.orelse)
return [init_stmt, while_node]
test=new_cond_stmt, body=body_stmts, orelse=node.orelse)
init_stmts.append(while_node)
return init_stmts
class BreakContinueTransformer(gast.NodeTransformer):

@ -23,10 +23,12 @@ from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrappe
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import create_api_shape_node
from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node
from paddle.fluid.dygraph.dygraph_to_static.utils import get_constant_variable_node
from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name
from paddle.fluid.dygraph.dygraph_to_static.utils import is_control_flow_to_transform
from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeParser
from paddle.fluid.dygraph.dygraph_to_static.utils import RenameTransformer
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_static_variable_gast_node
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable_gast_node
@ -321,7 +323,7 @@ class LoopTransformer(gast.NodeTransformer):
def __init__(self, wrapper_root):
assert isinstance(
wrapper_root, AstNodeWrapper
), "Input non-AstNodeWrapper node for the initialization of WhileTransformer."
), "Input non-AstNodeWrapper node for the initialization of LoopTransformer."
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
self.name_visitor = NameVisitor(self.root)
@ -355,86 +357,45 @@ class LoopTransformer(gast.NodeTransformer):
else:
i += 1
def get_for_range_node(self, node):
if not isinstance(node.iter, gast.Call):
return None
if not isinstance(node.iter.func, gast.Name):
return None
if node.iter.func.id != "range":
return None
return node.iter
def get_for_args_stmts(self, iter_name, args_list):
'''
Returns 3 gast stmt nodes for argument.
1. Initailize of iterate variable
2. Condition for the loop
3. Statement for changing of iterate variable during the loop
NOTE(TODO): Python allows to access iteration variable after loop, such
as "for i in range(10)" will create i = 9 after the loop. But using
current conversion will make i = 10. We should find a way to change it
'''
len_range_args = len(args_list)
assert len_range_args >= 1 and len_range_args <= 3, "range() function takes 1 to 3 arguments"
if len_range_args == 1:
init_stmt = get_constant_variable_node(iter_name, 0)
else:
init_stmt = gast.Assign(
targets=[
gast.Name(
id=iter_name,
ctx=gast.Store(),
annotation=None,
type_comment=None)
],
value=args_list[0])
range_max_node = args_list[0] if len_range_args == 1 else args_list[1]
step_node = args_list[2] if len_range_args == 3 else gast.Constant(
value=1, kind=None)
cond_stmt = gast.Compare(
left=gast.BinOp(
left=gast.Name(
id=iter_name,
ctx=gast.Load(),
annotation=None,
type_comment=None),
op=gast.Add(),
right=step_node),
ops=[gast.LtE()],
comparators=[range_max_node])
change_stmt = gast.AugAssign(
target=gast.Name(
id=iter_name,
ctx=gast.Store(),
annotation=None,
type_comment=None),
op=gast.Add(),
value=step_node)
return init_stmt, cond_stmt, change_stmt
def get_for_stmt_nodes(self, node):
# TODO: consider for - else in python
if not self.name_visitor.is_control_flow_loop(node):
return [node]
# TODO: support non-range case
range_call_node = self.get_for_range_node(node)
if range_call_node is None:
# 1. check whether need to transform
# NOTE: Current need transform cases:
# 1). for x in range(VarBase.numpy()[0])
# 2). for x in VarBase.numpy()
# 3). for i, x in enumerate(VarBase.numpy())
if not self.name_visitor.is_control_flow_loop(node):
return [node]
if not isinstance(node.target, gast.Name):
# 2. get key statements for different cases
# NOTE: three key statements:
# 1). init_stmts: list[node], prepare nodes of for loop, may not only one
# 2). cond_stmt: node, condition node to judge whether continue loop
# 3). body_stmts: list[node], updated loop body, sometimes we should change
# the original statement in body, not just append new statement
current_for_node_parser = ForNodeParser(node)
stmts_tuple = current_for_node_parser.parse()
if stmts_tuple is None:
return [node]
iter_var_name = node.target.id
init_stmt, cond_stmt, change_stmt = self.get_for_args_stmts(
iter_var_name, range_call_node.args)
init_stmts, cond_stmt, body_stmts = stmts_tuple
# 3. get original loop vars
loop_var_names, create_var_names = self.name_visitor.get_loop_var_names(
node)
# NOTE: in 'for x in var' or 'for i, x in enumerate(var)' cases,
# we need append new loop var & remove useless loop var
# 1. for x in var -> x is no need
# 2. for i, x in enumerate(var) -> x is no need
if current_for_node_parser.is_for_iter(
) or current_for_node_parser.is_for_enumerate_iter():
iter_var_name = current_for_node_parser.iter_var_name
iter_idx_name = current_for_node_parser.iter_idx_name
loop_var_names.add(iter_idx_name)
if iter_var_name not in create_var_names:
loop_var_names.remove(iter_var_name)
# 4. prepare result statement list
new_stmts = []
# Python can create variable in loop and use it out of loop, E.g.
#
@ -447,12 +408,13 @@ class LoopTransformer(gast.NodeTransformer):
if "." not in name:
new_stmts.append(create_static_variable_gast_node(name))
new_stmts.append(init_stmt)
# 5. append init statements
new_stmts.extend(init_stmts)
# for x in range(10) in dygraph should be convert into static tensor + 1 <= 10
for name in loop_var_names:
new_stmts.append(to_static_variable_gast_node(name))
# 6. create & append condition function node
condition_func_node = gast.FunctionDef(
name=unique_name.generate(FOR_CONDITION_PREFIX),
args=gast.arguments(
@ -480,9 +442,9 @@ class LoopTransformer(gast.NodeTransformer):
name, unique_name.generate(GENERATE_VARIABLE_PREFIX))
new_stmts.append(condition_func_node)
new_body = node.body
new_body.append(change_stmt)
new_body.append(
# 7. create & append loop body function node
# append return values for loop body
body_stmts.append(
gast.Return(value=generate_name_node(
loop_var_names, ctx=gast.Load())))
body_func_node = gast.FunctionDef(
@ -501,7 +463,7 @@ class LoopTransformer(gast.NodeTransformer):
kw_defaults=None,
kwarg=None,
defaults=[]),
body=new_body,
body=body_stmts,
decorator_list=[],
returns=None,
type_comment=None)
@ -512,6 +474,7 @@ class LoopTransformer(gast.NodeTransformer):
name, unique_name.generate(GENERATE_VARIABLE_PREFIX))
new_stmts.append(body_func_node)
# 8. create & append while loop node
while_loop_node = create_while_node(condition_func_node.name,
body_func_node.name, loop_var_names)
new_stmts.append(while_loop_node)

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save