@ -520,3 +520,83 @@ def test_while_in_while():
out = out + 3
return out
while_in_while ( c1 , c2 , c3 , c4 )
def test_tensor_cond ( ) :
class Net ( nn . Cell ) :
def __init__ ( self ) :
super ( Net , self ) . __init__ ( )
self . t = Tensor ( np . array ( 0 , np . bool ) )
self . t1 = Tensor ( np . array ( [ True ] , np . bool ) )
def construct ( self , x , y ) :
t = 0
if self . t :
t = t - x * y
else :
t = t - x / y
if self . t1 :
t = t + x / y
else :
t = t + x * y
return t
x = Tensor ( np . ones ( [ 6 , 8 , 10 ] , np . int32 ) )
y = Tensor ( np . ones ( [ 6 , 8 , 10 ] , np . int32 ) )
net = Net ( )
out = net ( x , y )
def test_tensor_cond_exception ( ) :
class Net ( nn . Cell ) :
def __init__ ( self ) :
super ( Net , self ) . __init__ ( )
self . t = Tensor ( np . array ( [ True , False ] , np . bool ) )
def construct ( self , x , y ) :
t = 0
if self . t :
t = t - x * y
else :
t = t - x / y
return t
x = Tensor ( np . ones ( [ 6 , 8 , 10 ] , np . int32 ) )
y = Tensor ( np . ones ( [ 6 , 8 , 10 ] , np . int32 ) )
net = Net ( )
with pytest . raises ( ValueError ) :
out = net ( x , y )
def test_while_scalar ( ) :
class Net ( nn . Cell ) :
def __init__ ( self ) :
super ( Net , self ) . __init__ ( )
self . x = 10
def construct ( self , x , y ) :
i = 0
t = 0
while ( i < 10 ) :
t = t + x + y
i = i + 1
return t
net = Net ( )
x = Tensor ( np . ones ( [ 6 , 8 , 10 ] , np . int32 ) )
y = Tensor ( np . ones ( [ 6 , 8 , 10 ] , np . int32 ) )
out = net ( x , y )
def test_while_tensor ( ) :
class Net ( nn . Cell ) :
def __init__ ( self ) :
super ( Net , self ) . __init__ ( )
self . t = Tensor ( np . ones ( [ 6 , 8 , 10 ] , np . int32 ) )
self . count = Tensor ( np . array ( [ 10 ] , np . int32 ) )
def construct ( self , x , y ) :
i = 0
t = self . t
while ( i < self . count ) :
t = t + x + y
i = i + 1
return t
net = Net ( )
x = Tensor ( np . ones ( [ 6 , 8 , 10 ] , np . int32 ) )
y = Tensor ( np . ones ( [ 6 , 8 , 10 ] , np . int32 ) )
out = net ( x , y )