optimize the error information when the input for while op has a wron… (#19872)

* optimize the error information when the input for while op has a wrong shape test=develop
expand_as_op_1
wopeizl 6 years ago committed by GitHub
parent d31c92a2cd
commit e606b1754e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -679,9 +679,11 @@ class While(object):
raise TypeError("condition should be a variable")
assert isinstance(cond, Variable)
if cond.dtype != core.VarDesc.VarType.BOOL:
raise TypeError("condition should be a bool variable")
raise TypeError("condition should be a boolean variable")
if reduce(lambda a, b: a * b, cond.shape, 1) != 1:
raise TypeError("condition should be a bool scalar")
raise TypeError(
"condition expected shape as [], but given shape as {0}.".
format(list(cond.shape)))
self.cond_var = cond
self.is_test = is_test

@ -96,6 +96,16 @@ class TestWhileOp(unittest.TestCase):
fetch_list=[sum_result])
self.assertAlmostEqual(numpy.sum(d), numpy.sum(outs[0]), delta=0.01)
def test_exceptions(self):
i = layers.zeros(shape=[2], dtype='int64')
array_len = layers.fill_constant(shape=[2], dtype='int64', value=1)
cond = layers.less_than(x=i, y=array_len)
with self.assertRaises(TypeError):
layers.While(cond=cond)
cond = layers.cast(cond, dtype='float64')
with self.assertRaises(TypeError):
layers.While(cond=cond)
if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save