|
|
|
@ -17,6 +17,7 @@ import numpy as np
|
|
|
|
|
from mindspore.ops import functional as F
|
|
|
|
|
from mindspore.ops import composite as C
|
|
|
|
|
from mindspore.common.tensor import Tensor
|
|
|
|
|
import mindspore.common.dtype as mstype
|
|
|
|
|
from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -115,6 +116,7 @@ def bool_or(x, y):
|
|
|
|
|
"""Implement `bool_or`."""
|
|
|
|
|
return x or y
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def vm_compare(*args):
|
|
|
|
|
"""Implement `vm_compare` for tensor."""
|
|
|
|
|
obj_str = args[-1]
|
|
|
|
@ -143,10 +145,12 @@ def list_len(x):
|
|
|
|
|
"""Implement `list_len`."""
|
|
|
|
|
return len(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def Depend(value, expr):
|
|
|
|
|
"""Implement `Depend`."""
|
|
|
|
|
return value
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# only used in PyNative mode
|
|
|
|
|
def make_ref(key, value, ref):
|
|
|
|
|
return value
|
|
|
|
@ -177,8 +181,12 @@ def stop_gradient(x):
|
|
|
|
|
|
|
|
|
|
hyper_map = C.HyperMap()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def mixed_precision_cast(dst_type, x):
|
|
|
|
|
"""Implement `mixed_precision_cast`."""
|
|
|
|
|
def cast_inner(data):
|
|
|
|
|
return F.cast(data, dst_type)
|
|
|
|
|
if isinstance(data, Tensor) and data.dtype in (mstype.float32, mstype.float16):
|
|
|
|
|
return F.cast(data, dst_type)
|
|
|
|
|
return data
|
|
|
|
|
|
|
|
|
|
return hyper_map(cast_inner, x)
|
|
|
|
|