Op(brelu) error message enhancement (#23606)

zhupengyang committed 5 years ago via GitHub
parent df538439f5
commit 17bee1d9a0

@@ -9179,6 +9179,8 @@ def brelu(x, t_min=0.0, t_max=24.0, name=None):
            #[[ 1. 6.]
            #[ 1. 10.]]
    """
    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'brelu')
    helper = LayerHelper('brelu', **locals())
    out = helper.create_variable_for_type_inference(dtype=x.dtype)
    helper.append_op(
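
For reference, the docstring output preserved above ([[1. 6.], [1. 10.]]) is what brelu produces when it clips its input into [t_min, t_max]. The following is a minimal sketch of that usage, not taken from the commit; the input array and the t_min/t_max values are assumptions chosen to be consistent with the printed result, and it assumes fluid's dygraph mode.

import numpy as np
import paddle.fluid as fluid

# Sketch only: the input values and thresholds below are assumptions
# chosen to match the docstring output in the hunk above.
input_brelu = np.array([[-1.0, 6.0], [1.0, 15.6]])
with fluid.dygraph.guard():
    x = fluid.dygraph.to_variable(input_brelu)
    y = fluid.layers.brelu(x, t_min=1.0, t_max=10.0)
    print(y.numpy())
    # [[ 1.  6.]
    #  [ 1. 10.]]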

@@ -520,6 +520,20 @@ class TestBRelu(TestActivation):
        self.check_grad(['X'], 'Out')

class TestBReluOpError(unittest.TestCase):
    def test_errors(self):
        with program_guard(Program()):
            # The input type must be Variable.
            self.assertRaises(TypeError, fluid.layers.brelu, 1)
            # The input dtype must be float16, float32, or float64.
            x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
            self.assertRaises(TypeError, fluid.layers.brelu, x_int32)
            # float16 input dtype is supported.
            x_fp16 = fluid.layers.data(
                name='x_fp16', shape=[12, 10], dtype='float16')
            fluid.layers.brelu(x_fp16)


class TestRelu6(TestActivation):
    def setUp(self):
        self.op_type = "relu6"
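
The behavior exercised by TestBReluOpError can also be reproduced directly. Below is a rough sketch of what a caller now sees when passing an unsupported input to brelu, using the same static-graph fluid APIs as the test; the exact wording of the error messages comes from check_variable_and_dtype and is not shown in this commit, so treat the printed text as illustrative.

import paddle.fluid as fluid
from paddle.fluid import Program, program_guard

with program_guard(Program()):
    # A plain Python int is not a Variable, so brelu raises TypeError.
    try:
        fluid.layers.brelu(1)
    except TypeError as e:
        print(e)

    # An int32 tensor falls outside the supported dtypes
    # (float16/float32/float64), so this also raises TypeError.
    x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
    try:
        fluid.layers.brelu(x_int32)
    except TypeError as e:
        print(e)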
