fix sign op input error check on float16 (#20472)

fix sign op input error check
test=develop
revert-20712-fix_depthwise_conv
wawltor 6 years ago committed by GitHub
parent 101e20e92e
commit eb526e3f08
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -16120,9 +16120,12 @@ def sign(x):
"The type of 'x' in sign_op must be Variable or numpy.ndarray, but received %s."
% (type(x)))
if convert_dtype(x.dtype) not in ['float32', 'float64']:
if convert_dtype(x.dtype) in ['float16']:
warnings.warn(
"The data type of 'x' in sign_op only support float16 in GPU now.")
if convert_dtype(x.dtype) not in ['float16', 'float32', 'float64']:
raise TypeError(
"The data type of 'x' in sign_op must be float32 or float64, but received %s."
"The data type of 'x' in sign_op must be float16, float32 or float64, but received %s."
% (convert_dtype(x.dtype)))
out = helper.create_variable_for_type_inference(dtype=x.dtype)

@ -42,10 +42,13 @@ class TestSignOpError(OpTest):
# The input type of sign_op must be Variable or numpy.ndarray.
input1 = 12
self.assertRaises(TypeError, fluid.layers.sign, input1)
# The input dtype of sign_op must be float32, float64.
# The input dtype of sign_op must be float16, float32, float64.
input2 = fluid.layers.data(
name='input2', shape=[12, 10], dtype="int32")
self.assertRaises(TypeError, fluid.layers.sign, input2)
input3 = fluid.layers.data(
name='input3', shape=[4], dtype="float16")
fluid.layers.sign(input3)
if __name__ == "__main__":

Loading…
Cancel
Save