Fix NameVisitor bugs (#22847)

1. copy.deepcopy in NameVisitor should be changed to copy.copy to make hash or set work
2. read_context should be type of gast.Load()/gast.AugLoad(), not gast.Load/gast.AugLoad
revert-22710-feature/integrated_ps_api
Huihuang Zheng 5 years ago committed by GitHub
parent f686310d81
commit 0d463d3bf7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -80,11 +80,11 @@ class NameVisitor(gast.NodeVisitor):
return True
def get_loop_var_names(self, node):
assert isinstance(node, gast.While) or isinstance(
while_node, gast.For), "Input node is not gast loop node"
assert isinstance(node, (gast.While,
gast.For)), "Input node is not gast loop node"
loop_var_names = set()
create_var_names = set()
read_context = {type(gast.Load), type(gast.AugLoad)}
read_context = {type(gast.Load()), type(gast.AugLoad())}
in_loop_vars = self.in_loop_vars[node]
in_loop_name_strs = set(name.id for name in in_loop_vars)
@ -114,13 +114,13 @@ class NameVisitor(gast.NodeVisitor):
def visit_For(self, node):
self.current_loop.append(node)
self.before_loop_vars[node] = copy.deepcopy(self.current_seen_vars)
self.before_loop_vars[node] = copy.copy(self.current_seen_vars)
self.generic_visit(node)
self.current_loop.pop()
def visit_While(self, node):
self.current_loop.append(node)
self.before_loop_vars[node] = copy.deepcopy(self.current_seen_vars)
self.before_loop_vars[node] = copy.copy(self.current_seen_vars)
self.generic_visit(node)
self.current_loop.pop()

@ -1,4 +1,4 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -21,7 +21,7 @@ import paddle.fluid as fluid
import unittest
from paddle.fluid.dygraph.jit import dygraph_to_static_graph
#from paddle.fluid.dygraph.dygraph_to_static import NameVistor
from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import NameVisitor
SEED = 2020
np.random.seed(SEED)
@ -37,8 +37,15 @@ def while_loop_dyfunc(x):
class TestNameVisitor(unittest.TestCase):
def test_loop_vars(self):
#TODO
pass
test_func = inspect.getsource(while_loop_dyfunc)
gast_root = gast.parse(test_func)
name_visitor = NameVisitor(gast_root)
for node in gast.walk(gast_root):
if isinstance(node, gast.While):
loop_var_names, create_var_names = name_visitor.get_loop_var_names(
node)
self.assertEqual(loop_var_names, set(["i", "x"]))
self.assertEqual(create_var_names, set())
class TestTransformWhile(unittest.TestCase):

Loading…
Cancel
Save