fix reshape input(x) error check on float16. test=develop (#20529)

revert-20712-fix_depthwise_conv
liym27 6 years ago committed by Tao Luo
parent 2384589383
commit 5219efb14f

@ -8524,9 +8524,15 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None):
"The type of 'x' in reshape must be Variable, but received %s." %
(type(x)))
if convert_dtype(x.dtype) not in ['float32', 'float64', 'int32', 'int64']:
if convert_dtype(x.dtype) in ['float16']:
warnings.warn(
"The data type of 'x' in reshape only support float16 in GPU now.")
if convert_dtype(x.dtype) not in [
'float16', 'float32', 'float64', 'int32', 'int64'
]:
raise TypeError(
"The data type of 'x' in reshape must be float32, float64, int32 or int64, "
"The data type of 'x' in reshape must be float16, float32, float64, int32 or int64, "
"but received %s." % (convert_dtype(x.dtype)))
if not isinstance(shape, (list, tuple, Variable)):

@ -238,17 +238,27 @@ class TestReshapeOpError(OpTest):
self.assertRaises(TypeError, test_x_type)
# The x dtype of reshape_op must be float32, float64, int32 or int64.
# The x dtype of reshape_op must be float16, float32, float64, int32 or int64.
def test_x_dtype():
x2 = fluid.layers.data(
name="x2",
shape=[2, 25],
append_batch_size=False,
dtype="float16")
dtype="bool")
fluid.layers.reshape(x2, shape=[2, 5, 5])
self.assertRaises(TypeError, test_x_dtype)
def test_x_dtype_float16():
x_float16 = fluid.layers.data(
name="x_float16",
shape=[2, 25],
append_batch_size=False,
dtype="float16")
fluid.layers.reshape(x_float16, shape=[2, 5, 5])
test_x_dtype_float16()
x3 = fluid.layers.data(
name="x3",
shape=[2, 25],

Loading…
Cancel
Save