Op(relu) error message enhancement (#23510)

Branch: revert-23830-2.0-beta
Author: zhupengyang, committed via GitHub 5 years ago
Parent: 5d970b586b
Commit: 7b648ad1a5

@@ -8201,6 +8201,8 @@ def relu(x, name=None):
     if in_dygraph_mode():
         return core.ops.relu(x)
 
+    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'relu')
+
     inputs = {'X': [x]}
     helper = LayerHelper('relu', **locals())
     dtype = helper.input_dtype(input_param_name='x')
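The added check_variable_and_dtype call rejects unsupported inputs with a TypeError at graph-construction time, which is the error message enhancement the PR title refers to. A minimal sketch of the caller-facing behavior, assuming a PaddlePaddle 1.x fluid environment (everything other than fluid.data and fluid.layers.relu from this diff is illustrative):

    import paddle.fluid as fluid

    # Build a small static-graph program and feed relu an unsupported dtype.
    with fluid.program_guard(fluid.Program(), fluid.Program()):
        x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
        try:
            fluid.layers.relu(x_int32)  # the new dtype check rejects int32
        except TypeError as e:
            print(e)  # message names the 'relu' op and the allowed dtypes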

@@ -431,6 +431,20 @@ class TestRelu(TestActivation):
         self.check_grad(['X'], 'Out')
 
 
+class TestReluOpError(unittest.TestCase):
+    def test_errors(self):
+        with program_guard(Program()):
+            # The input type must be Variable.
+            self.assertRaises(TypeError, fluid.layers.relu, 1)
+            # The input dtype must be float16, float32 or float64.
+            x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
+            self.assertRaises(TypeError, fluid.layers.relu, x_int32)
+            # float16 input is supported.
+            x_fp16 = fluid.layers.data(
+                name='x_fp16', shape=[12, 10], dtype='float16')
+            fluid.layers.relu(x_fp16)
+
+
 class TestLeakyRelu(TestActivation):
     def setUp(self):
         self.op_type = "leaky_relu"
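The new test case can also be exercised on its own with the standard unittest runner. A sketch, assuming the class lives in Paddle's test_activation_op module (the module name is an assumption; it is not stated in this diff):

    import unittest

    # Hypothetical import: run from the directory containing the
    # activation-op tests so the module is importable.
    from test_activation_op import TestReluOpError

    suite = unittest.TestLoader().loadTestsFromTestCase(TestReluOpError)
    unittest.TextTestRunner(verbosity=2).run(suite)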
