getitem support int64 tensor index

pull/10403/head
Payne 4 years ago
parent 2599aefad0
commit 4d97a4e2e8

@ -345,22 +345,22 @@ def get_index_tensor_dtype(dtype):
@constexpr
def check_index_tensors_dtype(dtypes, op_name):
def check_index_tensors_dtype(indexes_types, op_name):
"""Check a tuple of tensor data type."""
for ele in dtypes:
if not ele in mstype.int_type:
raise IndexError(f"For '{op_name}', the all index tensor "
f"data types should be mstype.int32, but got {dtypes}.")
for index_type in indexes_types:
if not index_type in (mstype.int32, mstype.int64):
raise IndexError(f"For '{op_name}', the all index tensor data types should be "
f"mstype.int32, but got {index_type}.")
return True
@constexpr
def check_index_tensor_dtype(dtype, op_name):
def check_index_tensor_dtype(index_type, op_name):
"""Check a tensor data type."""
if dtype in mstype.int_type:
if index_type in (mstype.int32, mstype.int64):
return True
raise IndexError(
f"For '{op_name}', the index tensor data type should be mstype.int32, but got {dtype}.")
raise IndexError(f"For '{op_name}', the index tensor data type should be mstype.int32, "
f"but got {index_type}.")
@constexpr

Loading…
Cancel
Save