[Dy2Stat] Support grammar: for ele in var[idx] (#29541)

Support to transformfor ele in var stms in which var is a slice of Tensor.
revert-31562-mean
liym27 5 years ago committed by GitHub
parent b59b6d7ae6
commit a0b60716f1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -882,6 +882,8 @@ class ForNodeVisitor(object):
self.node.iter.func,
gast.Attribute) and self.node.iter.func.attr == 'numpy':
return True
elif isinstance(self.node.iter, gast.Subscript):
return True
else:
return False

@ -159,6 +159,7 @@ def for_enumerate_var_numpy_with_start_continue(x_array):
def for_iter_var(x_array):
z = fluid.layers.fill_constant([1], 'int32', 0)
x_array = fluid.dygraph.to_variable(x_array)
for x in x_array:
z = z + x
return z
@ -221,6 +222,17 @@ def for_enumerate_var_with_nested_range(x_array):
return x
# 16. for iter var[idx]
@paddle.jit.to_static
def for_iter_var_idx(x_array):
z = fluid.layers.fill_constant([1], 'int32', 0)
x_array = fluid.dygraph.to_variable(x_array)
for x in x_array[0:]:
z = z + x
return z
class TestTransformBase(unittest.TestCase):
def setUp(self):
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
@ -343,6 +355,11 @@ class TestForIterVar(TestForIterVarNumpy):
self.dygraph_func = for_iter_var
class TestForIterVarIdx(TestForIterVarNumpy):
def set_test_func(self):
self.dygraph_func = for_iter_var_idx
class TestForEnumerateVar(TestForIterVarNumpy):
def set_test_func(self):
self.dygraph_func = for_enumerate_var

Loading…
Cancel
Save