diff --git a/mindspore/_extends/parse/standard_method.py b/mindspore/_extends/parse/standard_method.py index c29aacbaef..c68f960519 100644 --- a/mindspore/_extends/parse/standard_method.py +++ b/mindspore/_extends/parse/standard_method.py @@ -30,7 +30,6 @@ trans = P.Transpose() shape_ = P.Shape() dtype_ = P.DType() - def all_(x, axis=(), keep_dims=False): """ Check all array elements along a given axis evaluate to True. @@ -144,12 +143,16 @@ def bool_(x): def enumerate_(x, start=0): - """Enumerate list or tuple.""" + """Enumerate list or tuple or tensor.""" x_type = F.typeof(x) ret = () op_name = "enumerate" - if check_is_tuple_or_list(x_type, op_name, "first input") and check_is_const_int(start, op_name, "start"): - ret = zip(range(start, start + len(x)), x) + if check_is_tuple_or_list_or_tensor(x_type, op_name, "first input") and check_is_const_int(start, op_name, "start"): + if check_is_tensor(x_type): + for i in range(x.shape[0]): + ret += ((start + i, x[i]),) + else: + ret = zip(range(start, start + len(x)), x) return ret @@ -177,11 +180,19 @@ def check_type_same(x_type, base_type): @constexpr -def check_is_tuple_or_list(x, op_name, arg_name): +def check_is_tensor(x): """check whether x is list or tuple.""" - if isinstance(x, (mstype.list_type, mstype.tuple_type)): + if isinstance(x, mstype.tensor_type): + return True + return False + + +@constexpr +def check_is_tuple_or_list_or_tensor(x, op_name, arg_name): + """check whether x is list or tuple or tensor.""" + if isinstance(x, (mstype.list_type, mstype.tuple_type, mstype.tensor_type)): return True - raise TypeError(f"For '{op_name}', the '{arg_name}' should be tuple or list, but got {x}.") + raise TypeError(f"For '{op_name}', the '{arg_name}' should be tuple or list or tensor, but got {x}.") @constexpr diff --git a/tests/ut/python/pipeline/parse/test_enumerate.py b/tests/ut/python/pipeline/parse/test_enumerate.py index c7ce74b913..a1de60fdea 100644 --- a/tests/ut/python/pipeline/parse/test_enumerate.py +++ b/tests/ut/python/pipeline/parse/test_enumerate.py @@ -59,23 +59,36 @@ def test_enumerate_tuple_const(): assert net() == (6, 110) +def test_enumerate_tensor_const(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.value = Tensor(np.arange(2 * 3).reshape(2, 3)) + + def construct(self): + return enumerate(self.value) + + net = Net() + net() + + def test_enumerate_list_parameter(): class Net(nn.Cell): def __init__(self): super(Net, self).__init__() - def construct(self, x, y, z): + def construct(self, x, y): index_sum = 0 - value = [x, y, z] + value = [x, y] ret = () for i, j in enumerate(value): index_sum += i ret += (j,) return index_sum, ret - x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))) + x = Tensor(np.arange(4)) net = Net() - net(x, x, x) + net(x, x) def test_enumerate_tuple_parameter(): @@ -83,18 +96,36 @@ def test_enumerate_tuple_parameter(): def __init__(self): super(Net, self).__init__() - def construct(self, x, y, z): + def construct(self, x, y): index_sum = 0 - value = (x, y, z) + value = (x, y) ret = () for i, j in enumerate(value): index_sum += i ret += (j,) return index_sum, ret - x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))) + x = Tensor(np.arange(4)) + net = Net() + net(x, x) + + +def test_enumerate_tensor_parameter(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def construct(self, x): + index_sum = 0 + ret = () + for i, j in enumerate(x): + index_sum += i + ret += (j,) + return index_sum, ret + + x = Tensor(np.arange(2 * 3).reshape(2, 3)) net = Net() - net(x, x, x) + net(x) def test_enumerate_tuple_const_1(): @@ -115,23 +146,59 @@ def test_enumerate_tuple_const_1(): assert net() == (6, 110) +def test_enumerate_tensor_const_1(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.value = Tensor(np.arange(2*3).reshape(2, 3)) + + def construct(self): + index_sum = 0 + ret = () + for i in enumerate(self.value): + index_sum += i[0] + ret += (i[1],) + return index_sum, ret + + net = Net() + net() + + def test_enumerate_tuple_parameter_1(): class Net(nn.Cell): def __init__(self): super(Net, self).__init__() - def construct(self, x, y, z): + def construct(self, x, y): index_sum = 0 - value = (x, y, z) + value = (x, y) ret = () for i in enumerate(value): index_sum += i[0] ret += (i[1],) return index_sum, ret - x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))) + x = Tensor(np.arange(4)) + net = Net() + net(x, x) + + +def test_enumerate_tensor_parameter_1(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def construct(self, x): + index_sum = 0 + ret = () + for i in enumerate(x): + index_sum += i[0] + ret += (i[1],) + return index_sum, ret + + x = Tensor(np.arange(2 * 3).reshape(2, 3)) net = Net() - net(x, x, x) + net(x) def test_enumerate_tuple_const_2(): @@ -152,38 +219,59 @@ def test_enumerate_tuple_const_2(): assert net() == (10, 110) +def test_enumerate_tensor_const_2(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.value = Tensor(np.arange(2 * 3).reshape(2, 3)) + + def construct(self): + index_sum = 0 + ret = () + for i in enumerate(self.value, 1): + index_sum += i[0] + ret += (i[1],) + return index_sum, ret + + net = Net() + net() + + def test_enumerate_tuple_parameter_2(): class Net(nn.Cell): def __init__(self): super(Net, self).__init__() - def construct(self, x, y, z): + def construct(self, x, y): index_sum = 0 - value = (x, y, z) + value = (x, y) ret = () - for i in enumerate(value, 2): + for i in enumerate(value, 1): index_sum += i[0] ret += (i[1],) return index_sum, ret - x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))) + x = Tensor(np.arange(4)) net = Net() - net(x, x, x) + net(x, x) -def test_enumerate_first_input_type_error(): +def test_enumerate_tensor_parameter_2(): class Net(nn.Cell): def __init__(self): super(Net, self).__init__() def construct(self, x): - return enumerate(x) + index_sum = 0 + ret = () + for i, j in enumerate(x, 1): + index_sum += i + ret += (j,) + return index_sum, ret - x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))) + x = Tensor(np.arange(2 * 3).reshape(2, 3)) net = Net() - with pytest.raises(TypeError) as ex: - net(x) - assert "For 'enumerate', the 'first input'" in str(ex.value) + net(x) def test_enumerate_start_type_error():