|
|
|
@ -2380,9 +2380,13 @@ def softmax(input, use_cudnn=False, name=None, axis=-1):
|
|
|
|
|
raise TypeError(
|
|
|
|
|
"The type of 'input' in softmax must be Variable, but received %s" %
|
|
|
|
|
(type(input)))
|
|
|
|
|
if convert_dtype(input.dtype) not in ['float32', 'float64']:
|
|
|
|
|
if convert_dtype(input.dtype) in ['float16']:
|
|
|
|
|
warnings.warn(
|
|
|
|
|
"The data type of 'input' in softmax only support float16 in GPU now."
|
|
|
|
|
)
|
|
|
|
|
if convert_dtype(input.dtype) not in ['float16', 'float32', 'float64']:
|
|
|
|
|
raise TypeError(
|
|
|
|
|
"The data type of 'input' in softmax must be float32 or float64, but received %s."
|
|
|
|
|
"The data type of 'input' in softmax must be float16, float32 or float64, but received %s."
|
|
|
|
|
% (convert_dtype(input.dtype)))
|
|
|
|
|
|
|
|
|
|
dtype = helper.input_dtype()
|
|
|
|
|