@ -16,8 +16,10 @@
# ============================================================================
# ============================================================================
""" standard_method """
""" standard_method """
from dataclasses import dataclass
from dataclasses import dataclass
from mindspore . common import dtype as mstype
from . . . ops import functional as F
from . . . ops import functional as F
from . . . ops import operations as P
from . . . ops import operations as P
from . . . ops . primitive import constexpr
from . . . ops . composite import tail , core , MultitypeFuncGraph , env_get , hyper_add , \
from . . . ops . composite import tail , core , MultitypeFuncGraph , env_get , hyper_add , \
zeros_like , ones_like
zeros_like , ones_like
from . . . ops . composite . base import _append
from . . . ops . composite . base import _append
@ -102,11 +104,44 @@ def bool_(x):
return x . __bool__ ( )
return x . __bool__ ( )
def tensor_bool ( x ) :
def while_cond ( x ) :
""" return immedate x, x is a tensor of bool value """
""" For while condtion, if the condition is a tensor, the loop will not be unrolled """
if F . issubclass_ ( F . typeof ( x ) , F . typeof ( mstype . tensor ) ) :
is_cond = check_is_tensor_bool_cond ( F . shape ( x ) )
if is_cond :
return F . cast ( x , mstype . bool_ )
return x
return x
@constexpr
def check_is_tensor_bool_cond ( shp ) :
""" check if tensor is a bool condition """
if shp in ( ( ) , ( 1 , ) ) :
return True
raise ValueError ( " tensor as bool condition, its shape should be () or (1,), but got " , shp )
@constexpr
def const_tensor_to_bool ( x ) :
""" convert bool tensor to bool condition """
if x is None :
raise ValueError ( " Only constant tensor bool can be converted to bool " )
x = x . asnumpy ( )
if x . shape not in ( ( ) , ( 1 , ) ) :
raise ValueError ( " Tensor to bool should input shape () or (1), but got " , x . shape )
if x . shape == ( ) :
value = bool ( x )
else :
value = bool ( x [ 0 ] )
return value
def tensor_bool ( x ) :
""" tensor as conditon, if is constant, return immediate bool value """
is_cond = check_is_tensor_bool_cond ( F . shape ( x ) )
if is_cond and F . isconstant ( x ) :
return const_tensor_to_bool ( x )
return F . cast ( x , mstype . bool_ )
def and_ ( x , y ) :
def and_ ( x , y ) :
""" Implementation of `and` (`&`). """
""" Implementation of `and` (`&`). """
return x . __and__ ( y )
return x . __and__ ( y )