!8947 [MixedPrecision]Add float64 of mixed_precision_cast

From: @chenfei52
Reviewed-by: @zhunaipan,@zh_qh
Signed-off-by: @zh_qh
pull/8947/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit e682bfffdf

@ -20,6 +20,7 @@ from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype
def scalar_add(x, y):
"""Implement `scalar_add`."""
return x + y
@ -164,8 +165,9 @@ hyper_map = C.HyperMap()
def mixed_precision_cast(dst_type, x):
"""Implement `mixed_precision_cast`."""
def cast_inner(data):
if isinstance(data, Tensor) and data.dtype in (mstype.float32, mstype.float16):
if isinstance(data, Tensor) and data.dtype in (mstype.float32, mstype.float16, mstype.float64):
return F.cast(data, dst_type)
return data

Loading…
Cancel
Save