refine huber loss unittest, test=develop (#24263)

revert-24314-dev/fix_err_msg
huangjun12 5 years ago committed by GitHub
parent 356f5ee220
commit d0b0e27408
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -89,14 +89,17 @@ class TestHuberLossOpError(unittest.TestCase):
xr = fluid.data(name='xr', shape=[None, 6], dtype="float32")
lw = np.random.random((6, 6)).astype("float32")
lr = fluid.data(name='lr', shape=[None, 6], dtype="float32")
self.assertRaises(TypeError, fluid.layers.huber_loss, xw, lr)
self.assertRaises(TypeError, fluid.layers.huber_loss, xr, lw)
delta = 1.0
self.assertRaises(TypeError, fluid.layers.huber_loss, xr, lw, delta)
self.assertRaises(TypeError, fluid.layers.huber_loss, xw, lr, delta)
# the dtype of input and label must be float32 or float64
xw2 = fluid.data(name='xw2', shape=[None, 6], dtype="int32")
lw2 = fluid.data(name='lw2', shape=[None, 6], dtype="int32")
self.assertRaises(TypeError, fluid.layers.huber_loss, xw2, lr)
self.assertRaises(TypeError, fluid.layers.huber_loss, xr, lw2)
self.assertRaises(TypeError, fluid.layers.huber_loss, xw2, lr,
delta)
self.assertRaises(TypeError, fluid.layers.huber_loss, xr, lw2,
delta)
if __name__ == '__main__':

Loading…
Cancel
Save