From 369ee9ef9f18b9f9a96349d2d02c77c2df92f09f Mon Sep 17 00:00:00 2001 From: chenfei Date: Tue, 24 Nov 2020 15:27:12 +0800 Subject: [PATCH] add float64 of mixed_precision_cast --- mindspore/_extends/builtin_operations.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mindspore/_extends/builtin_operations.py b/mindspore/_extends/builtin_operations.py index 0fd95eb13c..a38135c295 100644 --- a/mindspore/_extends/builtin_operations.py +++ b/mindspore/_extends/builtin_operations.py @@ -20,6 +20,7 @@ from mindspore.common.tensor import Tensor import mindspore.common.dtype as mstype from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype + def scalar_add(x, y): """Implement `scalar_add`.""" return x + y @@ -164,8 +165,9 @@ hyper_map = C.HyperMap() def mixed_precision_cast(dst_type, x): """Implement `mixed_precision_cast`.""" + def cast_inner(data): - if isinstance(data, Tensor) and data.dtype in (mstype.float32, mstype.float16): + if isinstance(data, Tensor) and data.dtype in (mstype.float32, mstype.float16, mstype.float64): return F.cast(data, dst_type) return data