update mul_op input data type check test=develop (#20552)

revert-20712-fix_depthwise_conv
lijianshe02 5 years ago committed by GitHub
parent 40effc61af
commit 5c41805dc9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -14729,13 +14729,19 @@ def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None):
raise TypeError(
"The type of 'y' in mul must be Variable, but received %s" %
(type(y)))
if convert_dtype(x.dtype) not in ['float32', 'float64']:
if convert_dtype(x.dtype) in ['float16']:
warnings.warn(
"The data type of 'x' in mul only support float16 in GPU now.")
if convert_dtype(y.dtype) in ['float16']:
warnings.warn(
"The data type of 'y' in mul only support float16 in GPU now.")
if convert_dtype(x.dtype) not in ['float16', 'float32', 'float64']:
raise TypeError(
"The data type of 'x' in mul must be float32 or float64, but received %s."
"The data type of 'x' in mul must be float16, float32 or float64, but received %s."
% (convert_dtype(x.dtype)))
if convert_dtype(y.dtype) not in ['float32', 'float64']:
if convert_dtype(y.dtype) not in ['float16', 'float32', 'float64']:
raise TypeError(
"The data type of 'y' in softmax must be float32 or float64, but received %s."
"The data type of 'y' in mul must be float16, float32 or float64, but received %s."
% (convert_dtype(y.dtype)))
if name is None:

Loading…
Cancel
Save