|
|
|
@ -388,20 +388,20 @@ class IfConditionVisitor(object):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NameVisitor(gast.NodeVisitor):
|
|
|
|
|
def __init__(self, node_black_set=None):
|
|
|
|
|
# Set of nodes that will not be visited.
|
|
|
|
|
self.node_black_set = node_black_set or set()
|
|
|
|
|
def __init__(self, end_node=None):
|
|
|
|
|
# The terminate node of the visitor.
|
|
|
|
|
self.end_node = end_node
|
|
|
|
|
# Dict to store the names and ctxs of vars.
|
|
|
|
|
self.name_ids = defaultdict(list)
|
|
|
|
|
# List of current visited nodes
|
|
|
|
|
self.ancestor_nodes = []
|
|
|
|
|
# Available only when node_black_set is set.
|
|
|
|
|
# Available only when end_node is set.
|
|
|
|
|
self._is_finished = False
|
|
|
|
|
self._candidate_ctxs = (gast.Store, gast.Load, gast.Param)
|
|
|
|
|
|
|
|
|
|
def visit(self, node):
|
|
|
|
|
"""Visit a node."""
|
|
|
|
|
if node in self.node_black_set or self._is_finished:
|
|
|
|
|
if node == self.end_node or self._is_finished:
|
|
|
|
|
self._is_finished = True
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
@ -433,21 +433,28 @@ class NameVisitor(gast.NodeVisitor):
|
|
|
|
|
In above two cases, we should consider to manage the scope of vars to parsing
|
|
|
|
|
the arguments and returned vars correctly.
|
|
|
|
|
"""
|
|
|
|
|
before_if_name_ids = copy.deepcopy(self.name_ids)
|
|
|
|
|
body_name_ids = self._visit_child(node.body)
|
|
|
|
|
# If the traversal process stops early, just return the name_ids that have been seen.
|
|
|
|
|
if self._is_finished:
|
|
|
|
|
for name_id, ctxs in before_if_name_ids.items():
|
|
|
|
|
self.name_ids[name_id] = ctxs + self.name_ids[name_id]
|
|
|
|
|
# Blocks the vars in `if.body` and only inserts the vars both created in 'if/else' branch
|
|
|
|
|
# into name_ids.
|
|
|
|
|
if not self.end_node:
|
|
|
|
|
self.generic_visit(node)
|
|
|
|
|
else:
|
|
|
|
|
else_name_ids = self._visit_child(node.orelse)
|
|
|
|
|
new_name_ids = self._find_new_name_ids(body_name_ids, else_name_ids)
|
|
|
|
|
for new_name_id in new_name_ids:
|
|
|
|
|
before_if_name_ids[new_name_id].append(gast.Store())
|
|
|
|
|
|
|
|
|
|
self.name_ids = before_if_name_ids
|
|
|
|
|
before_if_name_ids = copy.deepcopy(self.name_ids)
|
|
|
|
|
body_name_ids = self._visit_child(node.body)
|
|
|
|
|
# If traversal process stops early in `if.body`, return the currently seen name_ids.
|
|
|
|
|
if self._is_finished:
|
|
|
|
|
self._update_name_ids(before_if_name_ids)
|
|
|
|
|
else:
|
|
|
|
|
else_name_ids = self._visit_child(node.orelse)
|
|
|
|
|
# If traversal process stops early in `if.orelse`, return the currently seen name_ids.
|
|
|
|
|
if self._is_finished:
|
|
|
|
|
self._update_name_ids(before_if_name_ids)
|
|
|
|
|
else:
|
|
|
|
|
# Blocks the vars in `if.body` and only inserts the vars both created in 'if/else' branch
|
|
|
|
|
# into name_ids.
|
|
|
|
|
new_name_ids = self._find_new_name_ids(body_name_ids,
|
|
|
|
|
else_name_ids)
|
|
|
|
|
for new_name_id in new_name_ids:
|
|
|
|
|
before_if_name_ids[new_name_id].append(gast.Store())
|
|
|
|
|
|
|
|
|
|
self.name_ids = before_if_name_ids
|
|
|
|
|
|
|
|
|
|
def visit_Attribute(self, node):
|
|
|
|
|
if not self._is_call_func_name_node(node):
|
|
|
|
@ -463,6 +470,19 @@ class NameVisitor(gast.NodeVisitor):
|
|
|
|
|
node._fields = ('value', 'targets')
|
|
|
|
|
self.generic_visit(node)
|
|
|
|
|
|
|
|
|
|
def visit_FunctionDef(self, node):
|
|
|
|
|
if not self.end_node:
|
|
|
|
|
self.generic_visit(node)
|
|
|
|
|
else:
|
|
|
|
|
before_name_ids = copy.deepcopy(self.name_ids)
|
|
|
|
|
self.name_ids = defaultdict(list)
|
|
|
|
|
self.generic_visit(node)
|
|
|
|
|
|
|
|
|
|
if self._is_finished:
|
|
|
|
|
self._update_name_ids(before_name_ids)
|
|
|
|
|
else:
|
|
|
|
|
self.name_ids = before_name_ids
|
|
|
|
|
|
|
|
|
|
def visit_Return(self, node):
|
|
|
|
|
# Ignore the vars in return
|
|
|
|
|
return
|
|
|
|
@ -505,12 +525,16 @@ class NameVisitor(gast.NodeVisitor):
|
|
|
|
|
return True
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
def _update_name_ids(self, new_name_ids):
|
|
|
|
|
for name_id, ctxs in new_name_ids.items():
|
|
|
|
|
self.name_ids[name_id] = ctxs + self.name_ids[name_id]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_name_ids(nodes, node_black_set=None):
|
|
|
|
|
def get_name_ids(nodes, end_node=None):
|
|
|
|
|
"""
|
|
|
|
|
Return all ast.Name.id of python variable in nodes.
|
|
|
|
|
"""
|
|
|
|
|
name_visitor = NameVisitor(node_black_set)
|
|
|
|
|
name_visitor = NameVisitor(end_node)
|
|
|
|
|
for node in nodes:
|
|
|
|
|
name_visitor.visit(node)
|
|
|
|
|
return name_visitor.name_ids
|
|
|
|
@ -611,7 +635,7 @@ def transform_if_else(node, root):
|
|
|
|
|
"""
|
|
|
|
|
Transform ast.If into control flow statement of Paddle static graph.
|
|
|
|
|
"""
|
|
|
|
|
parent_name_ids = get_name_ids([root], node_black_set=[node])
|
|
|
|
|
parent_name_ids = get_name_ids([root], end_node=node)
|
|
|
|
|
if_name_ids = get_name_ids(node.body)
|
|
|
|
|
else_name_ids = get_name_ids(node.orelse)
|
|
|
|
|
|
|
|
|
|