!11502 getitem add tensor-index check

From: @yepei6
Reviewed-by: 
Signed-off-by:
pull/11502/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 82f68350f1

@ -213,10 +213,19 @@ def tensor_index_by_list(data, list_index):
def tensor_index_by_tuple(data, tuple_index):
"""Tensor getitem by tuple of various types with None"""
tuple_index_len = len(tuple_index)
if tuple_index_len == 0:
return data
op_name = const_utils.TENSOR_GETITEM
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
data, tuple_index = _expand_data_dims(data, tuple_index, op_name)
data_shape = F.shape(data)
data_rank = len(data_shape)
min_data_rank, max_data_rank = 0, 8
const_utils.judge_data_rank(data_rank, min_data_rank, max_data_rank)
indexes_types = hyper_map(F.typeof, tuple_index)
contain_type = const_utils.tuple_index_type_cnt(indexes_types, op_name)
if contain_type == const_utils.ALL_TENSOR:
@ -280,12 +289,13 @@ def _generate_indices_from_tuple(data, tuple_index, op_name):
tensor_positions.append(i)
elif i in sequence_positions:
sequence_index = const_utils.transform_sequence_index(index, dim_size, op_name)
tensor_index = F.tuple_to_array(sequence_index)
tensor_index = const_utils.make_tensor(sequence_index)
tensor_index = F.cast(tensor_index, mstype.int64)
tuple_index_new += (tensor_index,)
tensor_indexes.append(tensor_index)
tensor_positions.append(i)
elif i in tensor_positions:
const_utils.check_index_type_valid(F.dtype(index), mstype.int_type, op_name)
tensor_index = F.cast(index, mstype.int64)
tuple_index_new += (tensor_index,)
tensor_indexes.append(tensor_index)

@ -80,6 +80,13 @@ def make_tensor(data, data_type=mstype.int64, data_shape=None):
return Tensor(data, data_type)
@constexpr
def judge_data_rank(data_rank, min_data_rank=0, max_data_rank=8):
if data_rank < min_data_rank or data_rank > max_data_rank:
raise ValueError(f"The input data's rank should in the range of[{min_data_rank}, "
f"{max_data_rank}], bug actually is '{data_rank}'")
@constexpr
def check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size):
"""Checks the shape and size of the sensor and value."""
@ -148,14 +155,14 @@ def judge_index_type(index_type, target_type):
def check_type_valid(dtype, target_type, op_name):
if dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type):
raise TypeError(
f"The '{op_name}' doesn't supoort {dtype}' and expecte to receive {target_type}.")
f"The '{op_name}' doesn't supoort {dtype}' and expect to receive {target_type}.")
@constexpr
def check_index_type_valid(dtype, target_type, op_name):
if dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type):
raise IndexError(
f"The '{op_name}' doesn't supoort {dtype}' and expecte to receive {target_type}.")
f"The '{op_name}' doesn't supoort {dtype}' and expect to receive {target_type}.")
@constexpr
@ -330,12 +337,12 @@ def integer_to_indices(index, shape):
@constexpr
def tuple_element_is_int(indexs):
def tuple_element_is_int(indexes):
"""Judges tuple element type."""
if not indexs:
if not indexes:
raise IndexError("Tensor's index cannot be empty.")
if isinstance(indexs, tuple):
for _, ele in enumerate(indexs):
if isinstance(indexes, tuple):
for _, ele in enumerate(indexes):
if not isinstance(ele, int):
return False
return True
@ -509,7 +516,7 @@ def transform_sequence_index(sequence_index, shape, op_name):
if bool_count == shape:
list_index = list(filter(lambda i: sequence_index[i], range(bool_count)))
else:
raise IndexError("The boolean array should have the same length with the corresponding dimensiton")
raise IndexError("The boolean array should have the same length with the corresponding dimension")
else:
list_index = [int(index) for index in sequence_index]

Loading…
Cancel
Save