|
|
|
@ -26,7 +26,6 @@ from paddle.fluid.framework import default_startup_program
|
|
|
|
|
class TestSwitch(unittest.TestCase):
|
|
|
|
|
def check_switch(self, value):
|
|
|
|
|
x = layers.fill_constant(shape=[1], dtype='float32', value=value)
|
|
|
|
|
|
|
|
|
|
zero_var = layers.fill_constant(shape=[1], dtype='float32', value=0.0)
|
|
|
|
|
one_var = layers.fill_constant(shape=[1], dtype='float32', value=1.0)
|
|
|
|
|
two_var = layers.fill_constant(shape=[1], dtype='float32', value=2.0)
|
|
|
|
@ -62,5 +61,34 @@ class TestSwitch(unittest.TestCase):
|
|
|
|
|
self.assertEqual(result, expected_result)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestSwitchCaseError(unittest.TestCase):
|
|
|
|
|
def test_error(self):
|
|
|
|
|
main_program = framework.Program()
|
|
|
|
|
startup_program = framework.Program()
|
|
|
|
|
with framework.program_guard(main_program, startup_program):
|
|
|
|
|
cond = layers.fill_constant(shape=[1], dtype='float32', value=0.0)
|
|
|
|
|
zero_var = layers.fill_constant(
|
|
|
|
|
shape=[1], dtype='float32', value=0.0)
|
|
|
|
|
|
|
|
|
|
result = layers.create_global_var(
|
|
|
|
|
shape=[1], value=-1.0, dtype='float32', persistable=True)
|
|
|
|
|
|
|
|
|
|
# 1. The type of 'condition' in case must be Variable.
|
|
|
|
|
def test_condition_type():
|
|
|
|
|
with layers.Switch() as switch:
|
|
|
|
|
with switch.case(1):
|
|
|
|
|
layers.assign(zero_var, result)
|
|
|
|
|
|
|
|
|
|
self.assertRaises(TypeError, test_condition_type)
|
|
|
|
|
|
|
|
|
|
# 2. The dtype of 'condition' in case must be 'bool'.
|
|
|
|
|
def test_condition_dtype():
|
|
|
|
|
with layers.Switch() as switch:
|
|
|
|
|
with switch.case(cond):
|
|
|
|
|
layers.assign(zero_var, result)
|
|
|
|
|
|
|
|
|
|
self.assertRaises(TypeError, test_condition_dtype)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
unittest.main()
|
|
|
|
|