!10403 getitem support int64 tensor index

From: @yepei6
Reviewed-by: @kingxian,@zhunaipan
Signed-off-by: @kingxian
pull/10403/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 7f5b007f5a

@ -359,22 +359,22 @@ def get_index_tensor_dtype(dtype):
@constexpr @constexpr
def check_index_tensors_dtype(dtypes, op_name): def check_index_tensors_dtype(indexes_types, op_name):
"""Check a tuple of tensor data type.""" """Check a tuple of tensor data type."""
for ele in dtypes: for index_type in indexes_types:
if not ele in mstype.int_type: if not index_type in (mstype.int32, mstype.int64):
raise IndexError(f"For '{op_name}', the all index tensor " raise IndexError(f"For '{op_name}', the all index tensor data types should be "
f"data types should be mstype.int32, but got {dtypes}.") f"mstype.int32, but got {index_type}.")
return True return True
@constexpr @constexpr
def check_index_tensor_dtype(dtype, op_name): def check_index_tensor_dtype(index_type, op_name):
"""Check a tensor data type.""" """Check a tensor data type."""
if dtype in mstype.int_type: if index_type in (mstype.int32, mstype.int64):
return True return True
raise IndexError( raise IndexError(f"For '{op_name}', the index tensor data type should be mstype.int32, "
f"For '{op_name}', the index tensor data type should be mstype.int32, but got {dtype}.") f"but got {index_type}.")
@constexpr @constexpr

Loading…
Cancel
Save