|
|
|
@ -384,6 +384,23 @@ class TestApiWhileLoop_Error(unittest.TestCase):
|
|
|
|
|
def body_returns_error_type(i, ten):
|
|
|
|
|
return layers.increment(i)
|
|
|
|
|
|
|
|
|
|
def cond_returns_with_mutable_dict(i, test_dict):
|
|
|
|
|
return i > 0
|
|
|
|
|
|
|
|
|
|
def body_returns_with_mutable_dict(i, test_dict):
|
|
|
|
|
test_dict['new_key'] = layers.fill_constant(
|
|
|
|
|
shape=[1], dtype='int64', value=1)
|
|
|
|
|
return layers.increment(i), test_dict
|
|
|
|
|
|
|
|
|
|
def cond_returns_with_mutable_list(i, test_list):
|
|
|
|
|
return i > 0
|
|
|
|
|
|
|
|
|
|
def body_returns_with_mutable_list(i, test_list):
|
|
|
|
|
test_list.append(
|
|
|
|
|
layers.fill_constant(
|
|
|
|
|
shape=[1], dtype='int64', value=1))
|
|
|
|
|
return layers.increment(i), test_list
|
|
|
|
|
|
|
|
|
|
main_program = Program()
|
|
|
|
|
startup_program = Program()
|
|
|
|
|
with program_guard(main_program, startup_program):
|
|
|
|
@ -451,6 +468,31 @@ class TestApiWhileLoop_Error(unittest.TestCase):
|
|
|
|
|
|
|
|
|
|
self.assertRaises(ValueError, value_error_body_returns_error_type)
|
|
|
|
|
|
|
|
|
|
# The length of `output_vars` with mutable value should keep same with `loop_vars`
|
|
|
|
|
def value_error_body_returns_with_mutable_dict():
|
|
|
|
|
test_dict = {
|
|
|
|
|
"int_constant": layers.fill_constant(
|
|
|
|
|
shape=[2, 2], dtype='int64', value=1)
|
|
|
|
|
}
|
|
|
|
|
out = layers.while_loop(cond_returns_with_mutable_dict,
|
|
|
|
|
body_returns_with_mutable_dict,
|
|
|
|
|
[data, test_dict])
|
|
|
|
|
|
|
|
|
|
self.assertRaises(ValueError,
|
|
|
|
|
value_error_body_returns_with_mutable_dict)
|
|
|
|
|
|
|
|
|
|
def value_error_body_returns_with_mutable_list():
|
|
|
|
|
test_list = [
|
|
|
|
|
layers.fill_constant(
|
|
|
|
|
shape=[2, 2], dtype='int64', value=1)
|
|
|
|
|
]
|
|
|
|
|
out = layers.while_loop(cond_returns_with_mutable_list,
|
|
|
|
|
body_returns_with_mutable_list,
|
|
|
|
|
[data, test_list])
|
|
|
|
|
|
|
|
|
|
self.assertRaises(ValueError,
|
|
|
|
|
value_error_body_returns_with_mutable_list)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
unittest.main()
|
|
|
|
|