enumerate function and enumerate test case added

pull/7838/head
l00591931 4 years ago
parent 4c9f75d13c
commit 6f165ee5e3

@ -30,7 +30,6 @@ trans = P.Transpose()
shape_ = P.Shape()
dtype_ = P.DType()
def all_(x, axis=(), keep_dims=False):
"""
Check all array elements along a given axis evaluate to True.
@ -144,12 +143,16 @@ def bool_(x):
def enumerate_(x, start=0):
"""Enumerate list or tuple."""
"""Enumerate list or tuple or tensor."""
x_type = F.typeof(x)
ret = ()
op_name = "enumerate"
if check_is_tuple_or_list(x_type, op_name, "first input") and check_is_const_int(start, op_name, "start"):
ret = zip(range(start, start + len(x)), x)
if check_is_tuple_or_list_or_tensor(x_type, op_name, "first input") and check_is_const_int(start, op_name, "start"):
if check_is_tensor(x_type):
for i in range(x.shape[0]):
ret += ((start + i, x[i]),)
else:
ret = zip(range(start, start + len(x)), x)
return ret
@ -177,11 +180,19 @@ def check_type_same(x_type, base_type):
@constexpr
def check_is_tuple_or_list(x, op_name, arg_name):
def check_is_tensor(x):
"""check whether x is list or tuple."""
if isinstance(x, (mstype.list_type, mstype.tuple_type)):
if isinstance(x, mstype.tensor_type):
return True
return False
@constexpr
def check_is_tuple_or_list_or_tensor(x, op_name, arg_name):
"""check whether x is list or tuple or tensor."""
if isinstance(x, (mstype.list_type, mstype.tuple_type, mstype.tensor_type)):
return True
raise TypeError(f"For '{op_name}', the '{arg_name}' should be tuple or list, but got {x}.")
raise TypeError(f"For '{op_name}', the '{arg_name}' should be tuple or list or tensor, but got {x}.")
@constexpr

@ -59,23 +59,36 @@ def test_enumerate_tuple_const():
assert net() == (6, 110)
def test_enumerate_tensor_const():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = Tensor(np.arange(2 * 3).reshape(2, 3))
def construct(self):
return enumerate(self.value)
net = Net()
net()
def test_enumerate_list_parameter():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def construct(self, x, y, z):
def construct(self, x, y):
index_sum = 0
value = [x, y, z]
value = [x, y]
ret = ()
for i, j in enumerate(value):
index_sum += i
ret += (j,)
return index_sum, ret
x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)))
x = Tensor(np.arange(4))
net = Net()
net(x, x, x)
net(x, x)
def test_enumerate_tuple_parameter():
@ -83,18 +96,36 @@ def test_enumerate_tuple_parameter():
def __init__(self):
super(Net, self).__init__()
def construct(self, x, y, z):
def construct(self, x, y):
index_sum = 0
value = (x, y, z)
value = (x, y)
ret = ()
for i, j in enumerate(value):
index_sum += i
ret += (j,)
return index_sum, ret
x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)))
x = Tensor(np.arange(4))
net = Net()
net(x, x)
def test_enumerate_tensor_parameter():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def construct(self, x):
index_sum = 0
ret = ()
for i, j in enumerate(x):
index_sum += i
ret += (j,)
return index_sum, ret
x = Tensor(np.arange(2 * 3).reshape(2, 3))
net = Net()
net(x, x, x)
net(x)
def test_enumerate_tuple_const_1():
@ -115,23 +146,59 @@ def test_enumerate_tuple_const_1():
assert net() == (6, 110)
def test_enumerate_tensor_const_1():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = Tensor(np.arange(2*3).reshape(2, 3))
def construct(self):
index_sum = 0
ret = ()
for i in enumerate(self.value):
index_sum += i[0]
ret += (i[1],)
return index_sum, ret
net = Net()
net()
def test_enumerate_tuple_parameter_1():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def construct(self, x, y, z):
def construct(self, x, y):
index_sum = 0
value = (x, y, z)
value = (x, y)
ret = ()
for i in enumerate(value):
index_sum += i[0]
ret += (i[1],)
return index_sum, ret
x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)))
x = Tensor(np.arange(4))
net = Net()
net(x, x)
def test_enumerate_tensor_parameter_1():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def construct(self, x):
index_sum = 0
ret = ()
for i in enumerate(x):
index_sum += i[0]
ret += (i[1],)
return index_sum, ret
x = Tensor(np.arange(2 * 3).reshape(2, 3))
net = Net()
net(x, x, x)
net(x)
def test_enumerate_tuple_const_2():
@ -152,38 +219,59 @@ def test_enumerate_tuple_const_2():
assert net() == (10, 110)
def test_enumerate_tensor_const_2():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = Tensor(np.arange(2 * 3).reshape(2, 3))
def construct(self):
index_sum = 0
ret = ()
for i in enumerate(self.value, 1):
index_sum += i[0]
ret += (i[1],)
return index_sum, ret
net = Net()
net()
def test_enumerate_tuple_parameter_2():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def construct(self, x, y, z):
def construct(self, x, y):
index_sum = 0
value = (x, y, z)
value = (x, y)
ret = ()
for i in enumerate(value, 2):
for i in enumerate(value, 1):
index_sum += i[0]
ret += (i[1],)
return index_sum, ret
x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)))
x = Tensor(np.arange(4))
net = Net()
net(x, x, x)
net(x, x)
def test_enumerate_first_input_type_error():
def test_enumerate_tensor_parameter_2():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
def construct(self, x):
return enumerate(x)
index_sum = 0
ret = ()
for i, j in enumerate(x, 1):
index_sum += i
ret += (j,)
return index_sum, ret
x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)))
x = Tensor(np.arange(2 * 3).reshape(2, 3))
net = Net()
with pytest.raises(TypeError) as ex:
net(x)
assert "For 'enumerate', the 'first input'" in str(ex.value)
net(x)
def test_enumerate_start_type_error():

Loading…
Cancel
Save