|
|
|
@ -20,6 +20,7 @@ from mindspore.common.tensor import Tensor
|
|
|
|
|
import mindspore.common.dtype as mstype
|
|
|
|
|
from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def scalar_add(x, y):
|
|
|
|
|
"""Implement `scalar_add`."""
|
|
|
|
|
return x + y
|
|
|
|
@ -164,8 +165,9 @@ hyper_map = C.HyperMap()
|
|
|
|
|
|
|
|
|
|
def mixed_precision_cast(dst_type, x):
|
|
|
|
|
"""Implement `mixed_precision_cast`."""
|
|
|
|
|
|
|
|
|
|
def cast_inner(data):
|
|
|
|
|
if isinstance(data, Tensor) and data.dtype in (mstype.float32, mstype.float16):
|
|
|
|
|
if isinstance(data, Tensor) and data.dtype in (mstype.float32, mstype.float16, mstype.float64):
|
|
|
|
|
return F.cast(data, dst_type)
|
|
|
|
|
return data
|
|
|
|
|
|
|
|
|
|