diff --git a/mindspore/_extends/parse/standard_method.py b/mindspore/_extends/parse/standard_method.py index 936099a4fb..d06ba8fa56 100644 --- a/mindspore/_extends/parse/standard_method.py +++ b/mindspore/_extends/parse/standard_method.py @@ -108,7 +108,8 @@ def enumerate_(x, start=0): """Enumerate list or tuple.""" x_type = F.typeof(x) ret = () - if check_is_tuple_or_list(x_type, "enumerate"): + op_name = "enumerate" + if check_is_tuple_or_list(x_type, op_name, "first input") and check_is_const_int(start, op_name, "start"): ret = zip(range(start, start + len(x)), x) return ret @@ -123,11 +124,22 @@ def while_cond(x): @constexpr -def check_is_tuple_or_list(x, op_name): +def check_is_tuple_or_list(x, op_name, arg_name): """check whether x is list or tuple.""" if isinstance(x, (mstype.list_type, mstype.tuple_type)): return True - raise TypeError(f"For '{op_name}', the input parameter should be tuple or list, but got {x}.") + raise TypeError(f"For '{op_name}', the '{arg_name}' should be tuple or list, but got {x}.") + + +@constexpr +def check_is_const_int(x, op_name, arg_name): + """check whether x is const int.""" + if x is None: + raise ValueError(f"For '{op_name}', the '{arg_name}' should be a const int number, but got not const.") + if not isinstance(x, int): + raise ValueError(f"For '{op_name}', the '{arg_name}' should be a const int number, but got {x}.") + return True + @constexpr def check_is_tensor_bool_cond(shp): diff --git a/tests/ut/python/pipeline/parse/test_enumerate.py b/tests/ut/python/pipeline/parse/test_enumerate.py index cd808696f1..c6d4e08b7d 100644 --- a/tests/ut/python/pipeline/parse/test_enumerate.py +++ b/tests/ut/python/pipeline/parse/test_enumerate.py @@ -91,6 +91,7 @@ def test_enumerate_tuple_parameter(): index_sum += i ret += (j,) return index_sum, ret + x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))) net = Net() net(x, x, x) @@ -127,10 +128,12 @@ def test_enumerate_tuple_parameter_1(): index_sum += i[0] ret += (i[1],) return index_sum, ret + x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))) net = Net() net(x, x, x) + def test_enumerate_tuple_const_2(): class Net(nn.Cell): def __init__(self): @@ -162,20 +165,37 @@ def test_enumerate_tuple_parameter_2(): index_sum += i[0] ret += (i[1],) return index_sum, ret + x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))) net = Net() net(x, x, x) -def test_enumerate_parameter_type_error(): +def test_enumerate_first_input_type_error(): class Net(nn.Cell): def __init__(self): super(Net, self).__init__() def construct(self, x): return enumerate(x) + x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))) net = Net() with pytest.raises(TypeError) as ex: net(x) - assert "For 'enumerate', the input parameter should be tuple or list" in str(ex.value) + assert "For 'enumerate', the 'first input'" in str(ex.value) + + +def test_enumerate_start_type_error(): + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + + def construct(self, x): + return enumerate(x, start=1.2) + + x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))) + net = Net() + with pytest.raises(ValueError) as ex: + net((x, x)) + assert "For 'enumerate', the 'start'" in str(ex.value)