add float64 of mixed_precision_cast

pull/8947/head
chenfei 4 years ago
parent 68cb63d7f6
commit 369ee9ef9f

@ -20,6 +20,7 @@ from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype
def scalar_add(x, y):
"""Implement `scalar_add`."""
return x + y
@ -164,8 +165,9 @@ hyper_map = C.HyperMap()
def mixed_precision_cast(dst_type, x):
"""Implement `mixed_precision_cast`."""
def cast_inner(data):
if isinstance(data, Tensor) and data.dtype in (mstype.float32, mstype.float16):
if isinstance(data, Tensor) and data.dtype in (mstype.float32, mstype.float16, mstype.float64):
return F.cast(data, dst_type)
return data

Loading…
Cancel
Save