diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py index 3c7b91e62f..6b9ee9cbbe 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py @@ -117,15 +117,16 @@ class NameVisitor(gast.NodeVisitor): var_node.ctx) in_loop_vars = set(in_loop_vars_list) + in_loop_vars = self._remove_unnecessary_vars(in_loop_vars, node) in_loop_name_strs = self._var_nodes_to_names(in_loop_vars) before_loop_body_vars = self.before_loop_body_vars[node] - before_loop_body_vars = self._remove_target_vars_of_for( + before_loop_body_vars = self._remove_unnecessary_vars( before_loop_body_vars, node) before_loop_name_strs = self._var_nodes_to_names(before_loop_body_vars) after_loop_vars = self.current_seen_vars - before_loop_body_vars - in_loop_vars - after_loop_vars = self._remove_target_vars_of_for(after_loop_vars, node) + after_loop_vars = self._remove_unnecessary_vars(after_loop_vars, node) after_loop_name_strs = self._var_nodes_to_names(after_loop_vars, read_context) condition_vars = self.condition_vars[node] @@ -138,7 +139,6 @@ class NameVisitor(gast.NodeVisitor): for var in in_loop_vars: wrapper = self.node_to_wrapper_map[var] name_to_type[self._var_node_to_name(var)] = wrapper.node_var_type - for name in in_loop_name_strs: if name in before_loop_name_strs: # If a variable is used in loop and created before loop @@ -296,47 +296,83 @@ class NameVisitor(gast.NodeVisitor): return parent_node return None - def _remove_target_vars_of_for(self, before_or_after_loop_vars, loop_node): + def _remove_unnecessary_vars(self, loop_vars, loop_node): """ - Remove target vars of gast.For from before_loop_vars or after_loop_vars. - :param before_or_after_loop_vars: before_loop_vars or after_loop_vars of loop_node. + Remove unnecessary vars from before_loop_vars, after_loop_vars or in_loop_vars about loop_node. + 1. Remove target vars of gast.For from before_loop_vars or after_loop_vars. + 2. Remove vars only in gast.comprehension. + :param loop_vars: before_loop_vars, after_loop_vars or in_loop_vars of loop_node. :param loop_node: Current loop node. """ - removed_vars = set() - for name_node in before_or_after_loop_vars: + vars_of_list_generator = set() + target_vars_of_for_node = set() + + for name_node in loop_vars: if not isinstance(name_node, gast.Name): continue parent_node = self._get_parent_node(name_node) - # NOTE: gast.For.target can be gast.Tuple. - # For example: `for i, j in enumerate(x)` has two target vars: i and j + # NOTE: gast.For.target or gast.comprehension.target can be gast.Tuple. + # For examples: + # 1) `for i, j in enumerate(x)` has two target vars: i and j + # 2) `[x for x,y in array]` has two target vars: x and y if isinstance(parent_node, gast.Tuple): parent_node = self._get_parent_node(parent_node) - if isinstance(parent_node, - gast.For) and parent_node is not loop_node: + # 1. Get vars only in gast.comprehension. + # For examples: + # 1) [x for x,y in array] -> x, x, y + # 2) [f(x) for x in array] -> x + # 3) [func(x, y) for x in array] -> x, x + if isinstance(parent_node, gast.comprehension): + # 1.1 target vars in list/set comprehensions target_node = parent_node.target - if isinstance(target_node, gast.Tuple): target_vars = target_node.elts else: target_vars = [target_node] - if name_node in target_vars: - removed_vars.add(name_node) + vars_of_list_generator = vars_of_list_generator | set( + target_vars) + + # 1.2 vars from target vars used in elt_node + target_var_names = {var.id for var in target_vars} + listcomp_node = self._get_parent_node(parent_node) + elt_node = listcomp_node.elt + if isinstance(elt_node, gast.Name): + if elt_node.id in target_var_names: + vars_of_list_generator.add(elt_node) + for child_node in gast.walk(elt_node): + if isinstance(child_node, gast.Name): + if child_node.id in target_var_names: + vars_of_list_generator.add(child_node) + + # 2. Get target vars or vars from target vars used in for-loop. + elif isinstance(parent_node, + gast.For) and parent_node is not loop_node: + # 2.1 target vars in gast.For node. + target_node = parent_node.target + if isinstance(target_node, gast.Tuple): + target_vars = target_node.elts + else: + target_vars = [target_node] - removed_vars_name_strs = {var.id for var in removed_vars} + target_vars_of_for_node = target_vars_of_for_node | set( + target_vars) - for var in before_or_after_loop_vars: + # 2.2 vars from target vars used in for-loop + target_vars_name_strs = {var.id for var in target_vars_of_for_node} + for var in loop_vars: if not isinstance(var, gast.Name): continue - if var.id in removed_vars_name_strs and var not in self.condition_vars[ + if var.id in target_vars_name_strs and var not in self.condition_vars[ loop_node]: - removed_vars.add(var) + target_vars_of_for_node.add(var) - return before_or_after_loop_vars - removed_vars + removed_vars = target_vars_of_for_node | vars_of_list_generator + return loop_vars - removed_vars class LoopTransformer(gast.NodeTransformer): diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py index 3f47ea4cb1..bf9b579b68 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py @@ -169,15 +169,28 @@ def nested_for_loop_dyfunc(): return b +def for_loop_dufunc_with_listcomp(array): + a = 1 + for j in range(array): + res = [x + a for x in array] + res = [i for i in array] + x = 1 + b = [i for i in array] + print(x) + return res + + class TestNameVisitor(unittest.TestCase): def setUp(self): self.loop_funcs = [ - while_loop_dyfunc, for_loop_dyfunc, while_loop_dyfunc_with_none + while_loop_dyfunc, for_loop_dyfunc, while_loop_dyfunc_with_none, + for_loop_dufunc_with_listcomp ] self.loop_var_names = [ - set(["i", "x"]), set(["i", "ret", "max_len"]), set(["i", "x"]) + set(["i", "x"]), set(["i", "ret", "max_len"]), set(["i", "x"]), + set(["j", "array", "res", "x"]) ] - self.create_var_names = [set(), set(["ret"]), set()] + self.create_var_names = [set(), set(["ret"]), set(), set(["res", "x"])] self.nested_for_loop_func = nested_for_loop_dyfunc @@ -211,7 +224,6 @@ class TestNameVisitor(unittest.TestCase): if isinstance(node, (gast.While, gast.For)): loop_var_names, create_var_names = name_visitor.get_loop_var_names( node) - # print(loop_var_names) self.assertEqual(loop_var_names, self.loop_var_names[i]) self.assertEqual(create_var_names, self.create_var_names[i]) i += 1