|
|
|
@ -359,22 +359,22 @@ def get_index_tensor_dtype(dtype):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@constexpr
|
|
|
|
|
def check_index_tensors_dtype(dtypes, op_name):
|
|
|
|
|
def check_index_tensors_dtype(indexes_types, op_name):
|
|
|
|
|
"""Check a tuple of tensor data type."""
|
|
|
|
|
for ele in dtypes:
|
|
|
|
|
if not ele in mstype.int_type:
|
|
|
|
|
raise IndexError(f"For '{op_name}', the all index tensor "
|
|
|
|
|
f"data types should be mstype.int32, but got {dtypes}.")
|
|
|
|
|
for index_type in indexes_types:
|
|
|
|
|
if not index_type in (mstype.int32, mstype.int64):
|
|
|
|
|
raise IndexError(f"For '{op_name}', the all index tensor data types should be "
|
|
|
|
|
f"mstype.int32, but got {index_type}.")
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@constexpr
|
|
|
|
|
def check_index_tensor_dtype(dtype, op_name):
|
|
|
|
|
def check_index_tensor_dtype(index_type, op_name):
|
|
|
|
|
"""Check a tensor data type."""
|
|
|
|
|
if dtype in mstype.int_type:
|
|
|
|
|
if index_type in (mstype.int32, mstype.int64):
|
|
|
|
|
return True
|
|
|
|
|
raise IndexError(
|
|
|
|
|
f"For '{op_name}', the index tensor data type should be mstype.int32, but got {dtype}.")
|
|
|
|
|
raise IndexError(f"For '{op_name}', the index tensor data type should be mstype.int32, "
|
|
|
|
|
f"but got {index_type}.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@constexpr
|
|
|
|
|