|
|
@ -13772,24 +13772,24 @@ def _elementwise_op(helper):
|
|
|
|
(op_type, type(y)))
|
|
|
|
(op_type, type(y)))
|
|
|
|
if convert_dtype(x.dtype) in ['float16']:
|
|
|
|
if convert_dtype(x.dtype) in ['float16']:
|
|
|
|
warnings.warn(
|
|
|
|
warnings.warn(
|
|
|
|
"The data type of 'x' in batch_norm only support float16 on GPU now."
|
|
|
|
"The data type of 'x' in %s only support float16 on GPU now." %
|
|
|
|
)
|
|
|
|
(op_type))
|
|
|
|
if convert_dtype(y.dtype) in ['float16']:
|
|
|
|
if convert_dtype(y.dtype) in ['float16']:
|
|
|
|
warnings.warn(
|
|
|
|
warnings.warn(
|
|
|
|
"The data type of 'y' in batch_norm only support float16 on GPU now."
|
|
|
|
"The data type of 'y' in %s only support float16 on GPU now." %
|
|
|
|
)
|
|
|
|
(op_type))
|
|
|
|
if convert_dtype(x.dtype) not in [
|
|
|
|
if convert_dtype(x.dtype) not in [
|
|
|
|
'float16', 'float32', 'float64', 'int32', 'int64'
|
|
|
|
'float16', 'float32', 'float64', 'int32', 'int64'
|
|
|
|
]:
|
|
|
|
]:
|
|
|
|
raise TypeError(
|
|
|
|
raise TypeError(
|
|
|
|
"The data type of 'x' in batch_norm must be float16 or float32 or float64 or int32 or int64, but received %s."
|
|
|
|
"The data type of 'x' in %s must be float16 or float32 or float64 or int32 or int64, "
|
|
|
|
% (convert_dtype(x.dtype)))
|
|
|
|
"but received %s." % (op_type, convert_dtype(x.dtype)))
|
|
|
|
if convert_dtype(y.dtype) not in [
|
|
|
|
if convert_dtype(y.dtype) not in [
|
|
|
|
'float16', 'float32', 'float64', 'int32', 'int64'
|
|
|
|
'float16', 'float32', 'float64', 'int32', 'int64'
|
|
|
|
]:
|
|
|
|
]:
|
|
|
|
raise TypeError(
|
|
|
|
raise TypeError(
|
|
|
|
"The data type of 'y' in batch_norm must be float16 or float32 or float64 or int32 or int64, but received %s."
|
|
|
|
"The data type of 'y' in %s must be float16 or float32 or float64 or int32 or int64, "
|
|
|
|
% (convert_dtype(y.dtype)))
|
|
|
|
"but received %s." % (op_type, convert_dtype(y.dtype)))
|
|
|
|
|
|
|
|
|
|
|
|
axis = helper.kwargs.get('axis', -1)
|
|
|
|
axis = helper.kwargs.get('axis', -1)
|
|
|
|
use_mkldnn = helper.kwargs.get('use_mkldnn', False)
|
|
|
|
use_mkldnn = helper.kwargs.get('use_mkldnn', False)
|
|
|
|