diff --git a/mindspore/ops/composite/multitype_ops/_constexpr_utils.py b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py index 6e3d1fb8c6..a3b9c1cd5b 100644 --- a/mindspore/ops/composite/multitype_ops/_constexpr_utils.py +++ b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py @@ -359,22 +359,22 @@ def get_index_tensor_dtype(dtype): @constexpr -def check_index_tensors_dtype(dtypes, op_name): +def check_index_tensors_dtype(indexes_types, op_name): """Check a tuple of tensor data type.""" - for ele in dtypes: - if not ele in mstype.int_type: - raise IndexError(f"For '{op_name}', the all index tensor " - f"data types should be mstype.int32, but got {dtypes}.") + for index_type in indexes_types: + if not index_type in (mstype.int32, mstype.int64): + raise IndexError(f"For '{op_name}', the all index tensor data types should be " + f"mstype.int32, but got {index_type}.") return True @constexpr -def check_index_tensor_dtype(dtype, op_name): +def check_index_tensor_dtype(index_type, op_name): """Check a tensor data type.""" - if dtype in mstype.int_type: + if index_type in (mstype.int32, mstype.int64): return True - raise IndexError( - f"For '{op_name}', the index tensor data type should be mstype.int32, but got {dtype}.") + raise IndexError(f"For '{op_name}', the index tensor data type should be mstype.int32, " + f"but got {index_type}.") @constexpr