|
|
|
@ -91,22 +91,27 @@ class IfElseTransformer(gast.NodeTransformer):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NameVisitor(gast.NodeVisitor):
|
|
|
|
|
def __init__(self, end_node=None):
|
|
|
|
|
def __init__(self, after_node=None, end_node=None):
|
|
|
|
|
# The start node (exclusive) of the visitor
|
|
|
|
|
self.after_node = after_node
|
|
|
|
|
# The terminate node of the visitor.
|
|
|
|
|
self.end_node = end_node
|
|
|
|
|
# Dict to store the names and ctxs of vars.
|
|
|
|
|
self.name_ids = defaultdict(list)
|
|
|
|
|
# List of current visited nodes
|
|
|
|
|
self.ancestor_nodes = []
|
|
|
|
|
# Available only when end_node is set.
|
|
|
|
|
self._is_finished = False
|
|
|
|
|
# True when in range (after_node, end_node).
|
|
|
|
|
self._in_range = after_node is None
|
|
|
|
|
self._candidate_ctxs = (gast.Store, gast.Load, gast.Param)
|
|
|
|
|
self._def_func_names = set()
|
|
|
|
|
|
|
|
|
|
def visit(self, node):
|
|
|
|
|
"""Visit a node."""
|
|
|
|
|
if node == self.end_node or self._is_finished:
|
|
|
|
|
self._is_finished = True
|
|
|
|
|
if self.after_node is not None and node == self.after_node:
|
|
|
|
|
self._in_range = True
|
|
|
|
|
return
|
|
|
|
|
if node == self.end_node:
|
|
|
|
|
self._in_range = False
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
self.ancestor_nodes.append(node)
|
|
|
|
@ -137,18 +142,19 @@ class NameVisitor(gast.NodeVisitor):
|
|
|
|
|
In above two cases, we should consider to manage the scope of vars to parsing
|
|
|
|
|
the arguments and returned vars correctly.
|
|
|
|
|
"""
|
|
|
|
|
if not self.end_node:
|
|
|
|
|
if not self._in_range or not self.end_node:
|
|
|
|
|
self.generic_visit(node)
|
|
|
|
|
return
|
|
|
|
|
else:
|
|
|
|
|
before_if_name_ids = copy.deepcopy(self.name_ids)
|
|
|
|
|
body_name_ids = self._visit_child(node.body)
|
|
|
|
|
# If traversal process stops early in `if.body`, return the currently seen name_ids.
|
|
|
|
|
if self._is_finished:
|
|
|
|
|
if not self._in_range:
|
|
|
|
|
self._update_name_ids(before_if_name_ids)
|
|
|
|
|
else:
|
|
|
|
|
else_name_ids = self._visit_child(node.orelse)
|
|
|
|
|
# If traversal process stops early in `if.orelse`, return the currently seen name_ids.
|
|
|
|
|
if self._is_finished:
|
|
|
|
|
if not self._in_range:
|
|
|
|
|
self._update_name_ids(before_if_name_ids)
|
|
|
|
|
else:
|
|
|
|
|
# Blocks the vars in `if.body` and only inserts the vars both created in 'if/else' branch
|
|
|
|
@ -161,10 +167,13 @@ class NameVisitor(gast.NodeVisitor):
|
|
|
|
|
self.name_ids = before_if_name_ids
|
|
|
|
|
|
|
|
|
|
def visit_Attribute(self, node):
|
|
|
|
|
if not self._is_call_func_name_node(node):
|
|
|
|
|
if not self._in_range or not self._is_call_func_name_node(node):
|
|
|
|
|
self.generic_visit(node)
|
|
|
|
|
|
|
|
|
|
def visit_Name(self, node):
|
|
|
|
|
if not self._in_range:
|
|
|
|
|
self.generic_visit(node)
|
|
|
|
|
return
|
|
|
|
|
blacklist = {'True', 'False', 'None'}
|
|
|
|
|
if node.id in blacklist: return
|
|
|
|
|
if node.id in self._def_func_names:
|
|
|
|
@ -174,11 +183,17 @@ class NameVisitor(gast.NodeVisitor):
|
|
|
|
|
self.name_ids[node.id].append(node.ctx)
|
|
|
|
|
|
|
|
|
|
def visit_Assign(self, node):
|
|
|
|
|
if not self._in_range:
|
|
|
|
|
self.generic_visit(node)
|
|
|
|
|
return
|
|
|
|
|
# Visit `value` firstly.
|
|
|
|
|
node._fields = ('value', 'targets')
|
|
|
|
|
self.generic_visit(node)
|
|
|
|
|
|
|
|
|
|
def visit_FunctionDef(self, node):
|
|
|
|
|
if not self._in_range:
|
|
|
|
|
self.generic_visit(node)
|
|
|
|
|
return
|
|
|
|
|
self._def_func_names.add(node.name)
|
|
|
|
|
if not self.end_node:
|
|
|
|
|
self.generic_visit(node)
|
|
|
|
@ -187,7 +202,7 @@ class NameVisitor(gast.NodeVisitor):
|
|
|
|
|
self.name_ids = defaultdict(list)
|
|
|
|
|
self.generic_visit(node)
|
|
|
|
|
|
|
|
|
|
if self._is_finished:
|
|
|
|
|
if not self._in_range:
|
|
|
|
|
self._update_name_ids(before_name_ids)
|
|
|
|
|
else:
|
|
|
|
|
self.name_ids = before_name_ids
|
|
|
|
@ -235,11 +250,13 @@ class NameVisitor(gast.NodeVisitor):
|
|
|
|
|
self.name_ids[name_id] = ctxs + self.name_ids[name_id]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_name_ids(nodes, end_node=None):
|
|
|
|
|
def get_name_ids(nodes, after_node=None, end_node=None):
|
|
|
|
|
"""
|
|
|
|
|
Return all ast.Name.id of python variable in nodes.
|
|
|
|
|
Return all ast.Name.id of python variable in nodes range from
|
|
|
|
|
(after_node, end_node) exclusively. If after_node or end_node is None, the
|
|
|
|
|
range is unlimited.
|
|
|
|
|
"""
|
|
|
|
|
name_visitor = NameVisitor(end_node)
|
|
|
|
|
name_visitor = NameVisitor(after_node, end_node)
|
|
|
|
|
for node in nodes:
|
|
|
|
|
name_visitor.visit(node)
|
|
|
|
|
return name_visitor.name_ids
|
|
|
|
@ -434,20 +451,8 @@ def transform_if_else(node, root):
|
|
|
|
|
parent_name_ids = get_name_ids([root], end_node=node)
|
|
|
|
|
body_name_ids = get_name_ids(node.body)
|
|
|
|
|
orelse_name_ids = get_name_ids(node.orelse)
|
|
|
|
|
|
|
|
|
|
# Get after_ifelse_name_ids, which means used var names after If.body and If.orelse node.
|
|
|
|
|
after_ifelse_name_ids = defaultdict(list)
|
|
|
|
|
all_name_ids = get_name_ids([root])
|
|
|
|
|
for name in all_name_ids:
|
|
|
|
|
before_var_names_ids = parent_name_ids.get(name, []) + \
|
|
|
|
|
body_name_ids.get(name, []) + orelse_name_ids.get(name, [])
|
|
|
|
|
# Note: context of node.Name like gast.Load is a concrete object which has unique id different from other gast.Load
|
|
|
|
|
# E.g. ctx of `x` can be [<gast.Load object at 0x142a33c90>, <gast.Load object at 0x142a51950>, <gast.Param object at 0x1407d8250>]
|
|
|
|
|
after_var_names_ids = [
|
|
|
|
|
ctx for ctx in all_name_ids[name] if ctx not in before_var_names_ids
|
|
|
|
|
]
|
|
|
|
|
if after_var_names_ids:
|
|
|
|
|
after_ifelse_name_ids[name] = after_var_names_ids
|
|
|
|
|
after_ifelse_name_ids = get_name_ids([root], after_node=node)
|
|
|
|
|
|
|
|
|
|
return_name_ids, modified_name_ids_from_parent, new_vars_to_create = parse_cond_return(
|
|
|
|
|
parent_name_ids, body_name_ids, orelse_name_ids, after_ifelse_name_ids)
|
|
|
|
|