|
|
@ -16,8 +16,10 @@
|
|
|
|
# ============================================================================
|
|
|
|
# ============================================================================
|
|
|
|
"""standard_method"""
|
|
|
|
"""standard_method"""
|
|
|
|
from dataclasses import dataclass
|
|
|
|
from dataclasses import dataclass
|
|
|
|
|
|
|
|
from mindspore.common import dtype as mstype
|
|
|
|
from ...ops import functional as F
|
|
|
|
from ...ops import functional as F
|
|
|
|
from ...ops import operations as P
|
|
|
|
from ...ops import operations as P
|
|
|
|
|
|
|
|
from ...ops.primitive import constexpr
|
|
|
|
from ...ops.composite import tail, core, MultitypeFuncGraph, env_get, hyper_add, \
|
|
|
|
from ...ops.composite import tail, core, MultitypeFuncGraph, env_get, hyper_add, \
|
|
|
|
zeros_like, ones_like
|
|
|
|
zeros_like, ones_like
|
|
|
|
from ...ops.composite.base import _append
|
|
|
|
from ...ops.composite.base import _append
|
|
|
@ -102,11 +104,44 @@ def bool_(x):
|
|
|
|
return x.__bool__()
|
|
|
|
return x.__bool__()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def tensor_bool(x):
|
|
|
|
def while_cond(x):
|
|
|
|
"""return immedate x, x is a tensor of bool value"""
|
|
|
|
"""For while condtion, if the condition is a tensor, the loop will not be unrolled"""
|
|
|
|
|
|
|
|
if F.issubclass_(F.typeof(x), F.typeof(mstype.tensor)):
|
|
|
|
|
|
|
|
is_cond = check_is_tensor_bool_cond(F.shape(x))
|
|
|
|
|
|
|
|
if is_cond:
|
|
|
|
|
|
|
|
return F.cast(x, mstype.bool_)
|
|
|
|
return x
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@constexpr
|
|
|
|
|
|
|
|
def check_is_tensor_bool_cond(shp):
|
|
|
|
|
|
|
|
"""check if tensor is a bool condition"""
|
|
|
|
|
|
|
|
if shp in ((), (1,)):
|
|
|
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
raise ValueError("tensor as bool condition, its shape should be () or (1,), but got ", shp)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@constexpr
|
|
|
|
|
|
|
|
def const_tensor_to_bool(x):
|
|
|
|
|
|
|
|
"""convert bool tensor to bool condition"""
|
|
|
|
|
|
|
|
if x is None:
|
|
|
|
|
|
|
|
raise ValueError("Only constant tensor bool can be converted to bool")
|
|
|
|
|
|
|
|
x = x.asnumpy()
|
|
|
|
|
|
|
|
if x.shape not in ((), (1,)):
|
|
|
|
|
|
|
|
raise ValueError("Tensor to bool should input shape () or (1), but got ", x.shape)
|
|
|
|
|
|
|
|
if x.shape == ():
|
|
|
|
|
|
|
|
value = bool(x)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
value = bool(x[0])
|
|
|
|
|
|
|
|
return value
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def tensor_bool(x):
|
|
|
|
|
|
|
|
"""tensor as conditon, if is constant, return immediate bool value"""
|
|
|
|
|
|
|
|
is_cond = check_is_tensor_bool_cond(F.shape(x))
|
|
|
|
|
|
|
|
if is_cond and F.isconstant(x):
|
|
|
|
|
|
|
|
return const_tensor_to_bool(x)
|
|
|
|
|
|
|
|
return F.cast(x, mstype.bool_)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def and_(x, y):
|
|
|
|
def and_(x, y):
|
|
|
|
"""Implementation of `and` (`&`)."""
|
|
|
|
"""Implementation of `and` (`&`)."""
|
|
|
|
return x.__and__(y)
|
|
|
|
return x.__and__(y)
|
|
|
|