fix bug of loop_vars in loop_transformer.test=develop (#23180)

revert-23830-2.0-beta
liym27 5 years ago committed by GitHub
parent ebe4eab985
commit af92630666
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -169,7 +169,7 @@ class NameVisitor(gast.NodeVisitor):
if self._is_call_func_name_node(node):
self.generic_visit(node)
return
if node.id == "False" or node.id == "True":
if node.id == "False" or node.id == "True" or node.id == "None":
self.generic_visit(node)
return
@ -187,7 +187,6 @@ class NameVisitor(gast.NodeVisitor):
def visit_Attribute(self, node):
if self._is_call_func_name_node(node):
return
attr_full_name = get_attribute_full_name(node)
self.current_seen_vars.add(node)
for loop_node in self.current_loop:

@ -35,6 +35,17 @@ def while_loop_dyfunc(x):
return i
def while_loop_dyfunc_with_none(x):
i = fluid.dygraph.to_variable(x)\
if x is not None \
else fluid.dygraph.to_variable(x+1)
flag = 1
while x < 10:
i = i + x if flag is not None else x + i
x = x + 1
return i
def for_loop_dyfunc(max_len):
for i in range(max_len):
ret = fluid.layers.zeros(shape=[1], dtype='float32')
@ -58,9 +69,14 @@ def var_create_in_for_loop(max_len):
class TestNameVisitor(unittest.TestCase):
def setUp(self):
self.loop_funcs = [while_loop_dyfunc, for_loop_dyfunc]
self.loop_var_names = [set(["i", "x"]), set(["i", "ret", "max_len"])]
self.create_var_names = [set(), set(["ret"])]
self.loop_funcs = [
while_loop_dyfunc, for_loop_dyfunc, while_loop_dyfunc_with_none
]
self.loop_var_names = [
set(["i", "x"]), set(["i", "ret", "max_len"]),
set(["i", "x", "flag"])
]
self.create_var_names = [set(), set(["ret"]), set()]
def test_loop_vars(self):
for i in range(len(self.loop_funcs)):
@ -115,6 +131,11 @@ class TestTransformWhileLoop(unittest.TestCase):
# self.assertTrue(np.allclose(self._run_dygraph(), self._run_static()))
class TestTransformWhileLoopWithNone(TestTransformWhileLoop):
def _init_dyfunc(self):
self.dyfunc = while_loop_dyfunc_with_none
class TestWhileLoopBoolOp(TestTransformWhileLoop):
def _init_dyfunc(self):
self.dyfunc = while_loop_bool_op

Loading…
Cancel
Save