|
|
|
@ -95,7 +95,7 @@ class IfElseTransformer(gast.NodeTransformer):
|
|
|
|
|
"""
|
|
|
|
|
self._insert_func_nodes(node)
|
|
|
|
|
|
|
|
|
|
def _insert_func_nodes(self, parent_node):
|
|
|
|
|
def _insert_func_nodes(self, node):
|
|
|
|
|
"""
|
|
|
|
|
Defined `true_func` and `false_func` will be inserted in front of corresponding
|
|
|
|
|
`layers.cond` statement instead of inserting them all into body of parent node.
|
|
|
|
@ -103,13 +103,18 @@ class IfElseTransformer(gast.NodeTransformer):
|
|
|
|
|
For example, `self.var_dict["key"]`. In this case, nested structure of newly
|
|
|
|
|
defined functions is easier to understand.
|
|
|
|
|
"""
|
|
|
|
|
if not (self.new_func_nodes and hasattr(parent_node, 'body')):
|
|
|
|
|
if not self.new_func_nodes:
|
|
|
|
|
return
|
|
|
|
|
idx = len(parent_node.body) - 1
|
|
|
|
|
idx = -1
|
|
|
|
|
if isinstance(node, list):
|
|
|
|
|
idx = len(node) - 1
|
|
|
|
|
elif isinstance(node, gast.AST):
|
|
|
|
|
for _, child in gast.iter_fields(node):
|
|
|
|
|
self._insert_func_nodes(child)
|
|
|
|
|
while idx >= 0:
|
|
|
|
|
child_node = parent_node.body[idx]
|
|
|
|
|
child_node = node[idx]
|
|
|
|
|
if child_node in self.new_func_nodes:
|
|
|
|
|
parent_node.body[idx:idx] = self.new_func_nodes[child_node]
|
|
|
|
|
node[idx:idx] = self.new_func_nodes[child_node]
|
|
|
|
|
idx = idx + len(self.new_func_nodes[child_node]) - 1
|
|
|
|
|
del self.new_func_nodes[child_node]
|
|
|
|
|
else:
|
|
|
|
@ -366,51 +371,133 @@ class IfConditionVisitor(object):
|
|
|
|
|
return new_node, new_assign_nodes
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_name_ids(nodes, not_name_set=None, node_black_list=None):
|
|
|
|
|
class NameVisitor(gast.NodeVisitor):
|
|
|
|
|
def __init__(self, node_black_set=None):
|
|
|
|
|
# Set of nodes that will not be visited.
|
|
|
|
|
self.node_black_set = node_black_set or set()
|
|
|
|
|
# Dict to store the names and ctxs of vars.
|
|
|
|
|
self.name_ids = defaultdict(list)
|
|
|
|
|
# List of current visited nodes
|
|
|
|
|
self.ancestor_nodes = []
|
|
|
|
|
# Available only when node_black_set is set.
|
|
|
|
|
self._is_finished = False
|
|
|
|
|
self._candidate_ctxs = (gast.Store, gast.Load, gast.Param)
|
|
|
|
|
|
|
|
|
|
def visit(self, node):
|
|
|
|
|
"""Visit a node."""
|
|
|
|
|
if node in self.node_black_set or self._is_finished:
|
|
|
|
|
self._is_finished = True
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
self.ancestor_nodes.append(node)
|
|
|
|
|
method = 'visit_' + node.__class__.__name__
|
|
|
|
|
visitor = getattr(self, method, self.generic_visit)
|
|
|
|
|
ret = visitor(node)
|
|
|
|
|
self.ancestor_nodes.pop()
|
|
|
|
|
|
|
|
|
|
return ret
|
|
|
|
|
|
|
|
|
|
def visit_If(self, node):
|
|
|
|
|
"""
|
|
|
|
|
Return all ast.Name.id of python variable in nodes.
|
|
|
|
|
For nested `if/else`, the created vars are not always visible for parent node.
|
|
|
|
|
In addition, the vars created in `if.body` are not visible for `if.orelse`.
|
|
|
|
|
|
|
|
|
|
Case 1:
|
|
|
|
|
x = 1
|
|
|
|
|
if m > 1:
|
|
|
|
|
res = new_tensor
|
|
|
|
|
res = res + 1 # Error, `res` is not visible here.
|
|
|
|
|
|
|
|
|
|
Case 2:
|
|
|
|
|
if x_tensor > 0:
|
|
|
|
|
res = new_tensor
|
|
|
|
|
else:
|
|
|
|
|
res = res + 1 # Error, `res` is not visible here.
|
|
|
|
|
|
|
|
|
|
In above two cases, we should consider to manage the scope of vars to parsing
|
|
|
|
|
the arguments and returned vars correctly.
|
|
|
|
|
"""
|
|
|
|
|
if not isinstance(nodes, (list, tuple, set)):
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"nodes must be one of list, tuple, set, but received %s" %
|
|
|
|
|
type(nodes))
|
|
|
|
|
if not_name_set is None:
|
|
|
|
|
not_name_set = set()
|
|
|
|
|
|
|
|
|
|
def update(old_dict, new_dict):
|
|
|
|
|
for k, v in new_dict.items():
|
|
|
|
|
old_dict[k].extend(v)
|
|
|
|
|
|
|
|
|
|
name_ids = defaultdict(list)
|
|
|
|
|
for node in nodes:
|
|
|
|
|
if node_black_list and node in node_black_list:
|
|
|
|
|
break
|
|
|
|
|
if isinstance(node, gast.AST):
|
|
|
|
|
# In two case, the ast.Name should be filtered.
|
|
|
|
|
# 1. Function name like `my_func` of my_func(x)
|
|
|
|
|
# 2. api prefix like `fluid` of `fluid.layers.mean`
|
|
|
|
|
if isinstance(node, gast.Return):
|
|
|
|
|
continue
|
|
|
|
|
elif isinstance(node, gast.Call) and isinstance(node.func,
|
|
|
|
|
gast.Name):
|
|
|
|
|
not_name_set.add(node.func.id)
|
|
|
|
|
elif isinstance(node, gast.Attribute) and isinstance(node.value,
|
|
|
|
|
gast.Name):
|
|
|
|
|
not_name_set.add(node.value.id)
|
|
|
|
|
if isinstance(
|
|
|
|
|
node, gast.Name
|
|
|
|
|
) and node.id not in name_ids and node.id not in not_name_set:
|
|
|
|
|
if isinstance(node.ctx, (gast.Store, gast.Load, gast.Param)):
|
|
|
|
|
name_ids[node.id].append(node.ctx)
|
|
|
|
|
before_if_name_ids = copy.deepcopy(self.name_ids)
|
|
|
|
|
body_name_ids = self._visit_child(node.body)
|
|
|
|
|
# If the traversal process stops early, just return the name_ids that have been seen.
|
|
|
|
|
if self._is_finished:
|
|
|
|
|
for name_id, ctxs in before_if_name_ids.items():
|
|
|
|
|
self.name_ids[name_id] = ctxs + self.name_ids[name_id]
|
|
|
|
|
# Blocks the vars in `if.body` and only inserts the vars both created in 'if/else' branch
|
|
|
|
|
# into name_ids.
|
|
|
|
|
else:
|
|
|
|
|
if isinstance(node, gast.Assign):
|
|
|
|
|
node = copy.copy(node)
|
|
|
|
|
else_name_ids = self._visit_child(node.orelse)
|
|
|
|
|
new_name_ids = self._find_new_name_ids(body_name_ids, else_name_ids)
|
|
|
|
|
for new_name_id in new_name_ids:
|
|
|
|
|
before_if_name_ids[new_name_id].append(gast.Store())
|
|
|
|
|
|
|
|
|
|
self.name_ids = before_if_name_ids
|
|
|
|
|
|
|
|
|
|
def visit_Attribute(self, node):
|
|
|
|
|
if not self._is_call_func_name_node(node):
|
|
|
|
|
self.generic_visit(node)
|
|
|
|
|
|
|
|
|
|
def visit_Name(self, node):
|
|
|
|
|
if not self._is_call_func_name_node(node):
|
|
|
|
|
if isinstance(node.ctx, self._candidate_ctxs):
|
|
|
|
|
self.name_ids[node.id].append(node.ctx)
|
|
|
|
|
|
|
|
|
|
def visit_Assign(self, node):
|
|
|
|
|
# Visit `value` firstly.
|
|
|
|
|
node._fields = ('value', 'targets')
|
|
|
|
|
for field, value in gast.iter_fields(node):
|
|
|
|
|
value = value if isinstance(value, list) else [value]
|
|
|
|
|
update(name_ids,
|
|
|
|
|
get_name_ids(value, not_name_set, node_black_list))
|
|
|
|
|
return name_ids
|
|
|
|
|
self.generic_visit(node)
|
|
|
|
|
|
|
|
|
|
def visit_Return(self, node):
|
|
|
|
|
# Ignore the vars in return
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
def _visit_child(self, node):
|
|
|
|
|
self.name_ids = defaultdict(list)
|
|
|
|
|
if isinstance(node, list):
|
|
|
|
|
for item in node:
|
|
|
|
|
if isinstance(item, gast.AST):
|
|
|
|
|
self.visit(item)
|
|
|
|
|
elif isinstance(node, gast.AST):
|
|
|
|
|
self.visit(node)
|
|
|
|
|
|
|
|
|
|
return copy.deepcopy(self.name_ids)
|
|
|
|
|
|
|
|
|
|
def _find_new_name_ids(self, body_name_ids, else_name_ids):
|
|
|
|
|
def is_required_ctx(ctxs, required_ctx):
|
|
|
|
|
for ctx in ctxs:
|
|
|
|
|
if isinstance(ctx, required_ctx):
|
|
|
|
|
return True
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
candidate_name_ids = set(body_name_ids.keys()) & set(else_name_ids.keys(
|
|
|
|
|
))
|
|
|
|
|
store_ctx = gast.Store
|
|
|
|
|
new_name_ids = set()
|
|
|
|
|
for name_id in candidate_name_ids:
|
|
|
|
|
if is_required_ctx(body_name_ids[name_id],
|
|
|
|
|
store_ctx) and is_required_ctx(
|
|
|
|
|
else_name_ids[name_id], store_ctx):
|
|
|
|
|
new_name_ids.add(name_id)
|
|
|
|
|
|
|
|
|
|
return new_name_ids
|
|
|
|
|
|
|
|
|
|
def _is_call_func_name_node(self, node):
|
|
|
|
|
if len(self.ancestor_nodes) > 1:
|
|
|
|
|
assert self.ancestor_nodes[-1] == node
|
|
|
|
|
parent_node = self.ancestor_nodes[-2]
|
|
|
|
|
if isinstance(parent_node, gast.Call) and parent_node.func == node:
|
|
|
|
|
return True
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_name_ids(nodes, node_black_set=None):
|
|
|
|
|
"""
|
|
|
|
|
Return all ast.Name.id of python variable in nodes.
|
|
|
|
|
"""
|
|
|
|
|
name_visitor = NameVisitor(node_black_set)
|
|
|
|
|
for node in nodes:
|
|
|
|
|
name_visitor.visit(node)
|
|
|
|
|
return name_visitor.name_ids
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_cond_args(var_ids_dict, return_ids=None, ctx=gast.Load):
|
|
|
|
@ -508,7 +595,7 @@ def transform_if_else(node, root):
|
|
|
|
|
"""
|
|
|
|
|
Transform ast.If into control flow statement of Paddle static graph.
|
|
|
|
|
"""
|
|
|
|
|
parent_name_ids = get_name_ids([root], node_black_list=[node])
|
|
|
|
|
parent_name_ids = get_name_ids([root], node_black_set=[node])
|
|
|
|
|
if_name_ids = get_name_ids(node.body)
|
|
|
|
|
else_name_ids = get_name_ids(node.orelse)
|
|
|
|
|
|
|
|
|
|