!10403 getitem support int64 tensor index

From: @yepei6
Reviewed-by: @kingxian,@zhunaipan
Signed-off-by: @kingxian
pull/10403/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 7f5b007f5a

@ -359,22 +359,22 @@ def get_index_tensor_dtype(dtype):
@constexpr
def check_index_tensors_dtype(dtypes, op_name):
def check_index_tensors_dtype(indexes_types, op_name):
"""Check a tuple of tensor data type."""
for ele in dtypes:
if not ele in mstype.int_type:
raise IndexError(f"For '{op_name}', the all index tensor "
f"data types should be mstype.int32, but got {dtypes}.")
for index_type in indexes_types:
if not index_type in (mstype.int32, mstype.int64):
raise IndexError(f"For '{op_name}', the all index tensor data types should be "
f"mstype.int32, but got {index_type}.")
return True
@constexpr
def check_index_tensor_dtype(dtype, op_name):
def check_index_tensor_dtype(index_type, op_name):
"""Check a tensor data type."""
if dtype in mstype.int_type:
if index_type in (mstype.int32, mstype.int64):
return True
raise IndexError(
f"For '{op_name}', the index tensor data type should be mstype.int32, but got {dtype}.")
raise IndexError(f"For '{op_name}', the index tensor data type should be mstype.int32, "
f"but got {index_type}.")
@constexpr

Loading…
Cancel
Save