diff --git a/mindspore/_extends/builtin_operations.py b/mindspore/_extends/builtin_operations.py index 0fd95eb13c..a38135c295 100644 --- a/mindspore/_extends/builtin_operations.py +++ b/mindspore/_extends/builtin_operations.py @@ -20,6 +20,7 @@ from mindspore.common.tensor import Tensor import mindspore.common.dtype as mstype from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype + def scalar_add(x, y): """Implement `scalar_add`.""" return x + y @@ -164,8 +165,9 @@ hyper_map = C.HyperMap() def mixed_precision_cast(dst_type, x): """Implement `mixed_precision_cast`.""" + def cast_inner(data): - if isinstance(data, Tensor) and data.dtype in (mstype.float32, mstype.float16): + if isinstance(data, Tensor) and data.dtype in (mstype.float32, mstype.float16, mstype.float64): return F.cast(data, dst_type) return data