|
|
|
@ -14729,13 +14729,19 @@ def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None):
|
|
|
|
|
raise TypeError(
|
|
|
|
|
"The type of 'y' in mul must be Variable, but received %s" %
|
|
|
|
|
(type(y)))
|
|
|
|
|
if convert_dtype(x.dtype) not in ['float32', 'float64']:
|
|
|
|
|
if convert_dtype(x.dtype) in ['float16']:
|
|
|
|
|
warnings.warn(
|
|
|
|
|
"The data type of 'x' in mul only support float16 in GPU now.")
|
|
|
|
|
if convert_dtype(y.dtype) in ['float16']:
|
|
|
|
|
warnings.warn(
|
|
|
|
|
"The data type of 'y' in mul only support float16 in GPU now.")
|
|
|
|
|
if convert_dtype(x.dtype) not in ['float16', 'float32', 'float64']:
|
|
|
|
|
raise TypeError(
|
|
|
|
|
"The data type of 'x' in mul must be float32 or float64, but received %s."
|
|
|
|
|
"The data type of 'x' in mul must be float16, float32 or float64, but received %s."
|
|
|
|
|
% (convert_dtype(x.dtype)))
|
|
|
|
|
if convert_dtype(y.dtype) not in ['float32', 'float64']:
|
|
|
|
|
if convert_dtype(y.dtype) not in ['float16', 'float32', 'float64']:
|
|
|
|
|
raise TypeError(
|
|
|
|
|
"The data type of 'y' in softmax must be float32 or float64, but received %s."
|
|
|
|
|
"The data type of 'y' in mul must be float16, float32 or float64, but received %s."
|
|
|
|
|
% (convert_dtype(y.dtype)))
|
|
|
|
|
|
|
|
|
|
if name is None:
|
|
|
|
|