[Dy2static] Add for iterate or enumerate variable list unittest (#25100)

* add for iter var list, test=develop

* add enumerate unittest, test=develop
fix-sync_batch_norm-hang-in-fleet
Chen Weihang 5 years ago committed by GitHub
parent eb1c0901a6
commit 509d3ec5b7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -176,6 +176,40 @@ def for_enumerate_var(x_array):
return y, z
# 13. for iter list[var]
@declarative
def for_iter_var_list(x):
# 1. prepare data, ref test_list.py
x = fluid.dygraph.to_variable(x)
iter_num = fluid.layers.fill_constant(shape=[1], value=5, dtype="int32")
a = []
for i in range(iter_num):
a.append(x + i)
# 2. iter list[var]
y = fluid.layers.fill_constant([1], 'int32', 0)
for x in a:
y = y + x
return y
# 14. for enumerate list[var]
@declarative
def for_enumerate_var_list(x):
# 1. prepare data, ref test_list.py
x = fluid.dygraph.to_variable(x)
iter_num = fluid.layers.fill_constant(shape=[1], value=5, dtype="int32")
a = []
for i in range(iter_num):
a.append(x + i)
# 2. iter list[var]
y = fluid.layers.fill_constant([1], 'int32', 0)
z = fluid.layers.fill_constant([1], 'int32', 0)
for i, x in enumerate(a):
y = y + i
z = z + x
return y, z
class TestTransformBase(unittest.TestCase):
def setUp(self):
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
@ -303,5 +337,15 @@ class TestForEnumerateVar(TestForIterVarNumpy):
self.dygraph_func = for_enumerate_var
class TestForIterVarList(TestForInRange):
def set_test_func(self):
self.dygraph_func = for_iter_var_list
class TestForEnumerateVarList(TestForInRange):
def set_test_func(self):
self.dygraph_func = for_enumerate_var_list
if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save