API (Switch) error message enhancement. test=develop (#23459)

* API (Switch) error message enhancement. 

* fix bug: dtype of out in api isfinite is set incorrectly. The dtype should be bool.
revert-23830-2.0-beta
liym27 5 years ago committed by GitHub
parent dc225ed2fc
commit 067134f1b3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -2173,7 +2173,7 @@ def cond(pred, true_fn=None, false_fn=None, name=None):
def _error_message(what, arg_name, op_name, right_value, error_value):
error_message = "{what} of '{arg_name}' in Op({op_name}) must be " \
error_message = "{what} of '{arg_name}' in {op_name} must be " \
"{right_value}, but received: {error_value}.".format(
what=what,
arg_name=arg_name,
@ -2309,7 +2309,7 @@ class Switch(object):
OP :ref:`api_fluid_layers_case` is easier to use and is called with less code but does the same thing as ``Switch`` .
Member Functions:
case(cond): The case branch of Switch whose parameter cond is a scalar Variable of bool type. Only if the cond of the current case branch is True and the cond of the previous case branch is False, the statement after the case branch will be executed, and the statement after the case branch will not be executed.
case(condition): The case branch of Switch whose parameter cond is a scalar Variable of bool type. Only if the cond of the current case branch is True and the cond of the previous case branch is False, the statement after the case branch will be executed, and the statement after the case branch will not be executed.
default(): The default branch of Switch. When cond of all case branches is False, the statement after default branch is executed.
@ -2372,6 +2372,10 @@ class Switch(object):
if not self.inside_scope:
raise ValueError("case should be called inside with")
check_variable_and_dtype(
condition, 'condition', ['bool'],
'the member function case of fluid.layers.Switch')
if len(self.pre_not_conditions) == 0:
cond_block = ConditionalBlock([condition], is_scalar_condition=True)
not_cond = logical_not(x=condition)

@ -1154,7 +1154,7 @@ def isfinite(x):
out = fluid.layers.isfinite(var)
"""
helper = LayerHelper("isfinite", **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
out = helper.create_variable_for_type_inference(dtype='bool')
helper.append_op(type="isfinite", inputs={"X": x}, outputs={"Out": out})
return out

@ -26,7 +26,6 @@ from paddle.fluid.framework import default_startup_program
class TestSwitch(unittest.TestCase):
def check_switch(self, value):
x = layers.fill_constant(shape=[1], dtype='float32', value=value)
zero_var = layers.fill_constant(shape=[1], dtype='float32', value=0.0)
one_var = layers.fill_constant(shape=[1], dtype='float32', value=1.0)
two_var = layers.fill_constant(shape=[1], dtype='float32', value=2.0)
@ -62,5 +61,34 @@ class TestSwitch(unittest.TestCase):
self.assertEqual(result, expected_result)
class TestSwitchCaseError(unittest.TestCase):
def test_error(self):
main_program = framework.Program()
startup_program = framework.Program()
with framework.program_guard(main_program, startup_program):
cond = layers.fill_constant(shape=[1], dtype='float32', value=0.0)
zero_var = layers.fill_constant(
shape=[1], dtype='float32', value=0.0)
result = layers.create_global_var(
shape=[1], value=-1.0, dtype='float32', persistable=True)
# 1. The type of 'condition' in case must be Variable.
def test_condition_type():
with layers.Switch() as switch:
with switch.case(1):
layers.assign(zero_var, result)
self.assertRaises(TypeError, test_condition_type)
# 2. The dtype of 'condition' in case must be 'bool'.
def test_condition_dtype():
with layers.Switch() as switch:
with switch.case(cond):
layers.assign(zero_var, result)
self.assertRaises(TypeError, test_condition_dtype)
if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save