|
|
|
@ -8524,9 +8524,15 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None):
|
|
|
|
|
"The type of 'x' in reshape must be Variable, but received %s." %
|
|
|
|
|
(type(x)))
|
|
|
|
|
|
|
|
|
|
if convert_dtype(x.dtype) not in ['float32', 'float64', 'int32', 'int64']:
|
|
|
|
|
if convert_dtype(x.dtype) in ['float16']:
|
|
|
|
|
warnings.warn(
|
|
|
|
|
"The data type of 'x' in reshape only support float16 in GPU now.")
|
|
|
|
|
|
|
|
|
|
if convert_dtype(x.dtype) not in [
|
|
|
|
|
'float16', 'float32', 'float64', 'int32', 'int64'
|
|
|
|
|
]:
|
|
|
|
|
raise TypeError(
|
|
|
|
|
"The data type of 'x' in reshape must be float32, float64, int32 or int64, "
|
|
|
|
|
"The data type of 'x' in reshape must be float16, float32, float64, int32 or int64, "
|
|
|
|
|
"but received %s." % (convert_dtype(x.dtype)))
|
|
|
|
|
|
|
|
|
|
if not isinstance(shape, (list, tuple, Variable)):
|
|
|
|
|