|
|
@ -57,6 +57,68 @@ def _generate_indices_from_tuple_of_tensor(data, tuple_index, op_name):
|
|
|
|
return indices
|
|
|
|
return indices
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _generate_indices_from_tuple(data, tuple_index, op_name):
|
|
|
|
|
|
|
|
"""Generate an indices tensor from a tuple that contains slice, int, ellipsis, tensor."""
|
|
|
|
|
|
|
|
data_shape = F.shape(data)
|
|
|
|
|
|
|
|
indexes_types = hyper_map(F.typeof, tuple_index)
|
|
|
|
|
|
|
|
int_positions, sequence_positions = const_utils.get_pos_of_int_sequence(indexes_types)
|
|
|
|
|
|
|
|
tuple_index_new = ()
|
|
|
|
|
|
|
|
tuple_len = len(tuple_index)
|
|
|
|
|
|
|
|
for i in range(tuple_len):
|
|
|
|
|
|
|
|
index = tuple_index[i]
|
|
|
|
|
|
|
|
shape = data_shape[i]
|
|
|
|
|
|
|
|
if i in int_positions:
|
|
|
|
|
|
|
|
int_index = const_utils.check_and_transform_int_index(index, shape, op_name)
|
|
|
|
|
|
|
|
tensor_index = F.scalar_to_tensor(int_index, mstype.int32)
|
|
|
|
|
|
|
|
tuple_index_new += (tensor_index,)
|
|
|
|
|
|
|
|
elif i in sequence_positions:
|
|
|
|
|
|
|
|
sequence_index = const_utils.transform_sequence_index(index, shape, op_name)
|
|
|
|
|
|
|
|
tensor_index = F.tuple_to_array(sequence_index)
|
|
|
|
|
|
|
|
tuple_index_new += (tensor_index,)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
tuple_index_new += (index,)
|
|
|
|
|
|
|
|
indexes_types = hyper_map(F.typeof, tuple_index_new)
|
|
|
|
|
|
|
|
tensor_positions, slice_positions, ellipsis_position = \
|
|
|
|
|
|
|
|
const_utils.separate_mixed_tensors_index(indexes_types, op_name)
|
|
|
|
|
|
|
|
tensor_indexes = []
|
|
|
|
|
|
|
|
slice_indexes = []
|
|
|
|
|
|
|
|
for i in tensor_positions:
|
|
|
|
|
|
|
|
tensor_indexes.append(tuple_index_new[i])
|
|
|
|
|
|
|
|
for j in slice_positions:
|
|
|
|
|
|
|
|
slice_indexes.append(tuple_index_new[j])
|
|
|
|
|
|
|
|
tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes)
|
|
|
|
|
|
|
|
tensor_indexes_dtypes = hyper_map(F.dtype, tensor_indexes)
|
|
|
|
|
|
|
|
broadcast_shape, final_shape, indexes_shapes_info, ellipsis_occupied_dims = \
|
|
|
|
|
|
|
|
const_utils.generate_index_info_from_tuple_of_mixed_tensors(data_shape,
|
|
|
|
|
|
|
|
indexes_types,
|
|
|
|
|
|
|
|
tensor_indexes_shapes,
|
|
|
|
|
|
|
|
tensor_indexes_dtypes,
|
|
|
|
|
|
|
|
slice_indexes,
|
|
|
|
|
|
|
|
op_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
slice_number = 0
|
|
|
|
|
|
|
|
final_index_tensors = []
|
|
|
|
|
|
|
|
tuple_index_size = len(tuple_index_new)
|
|
|
|
|
|
|
|
index_tensor_new_shape = const_utils.compute_new_shape(broadcast_shape, indexes_shapes_info)
|
|
|
|
|
|
|
|
for i in range(tuple_index_size):
|
|
|
|
|
|
|
|
if i in tensor_positions:
|
|
|
|
|
|
|
|
transform_tensor = _transform_indexing_tensor(
|
|
|
|
|
|
|
|
broadcast_shape, final_shape, index_tensor_new_shape, tuple_index_new[i])
|
|
|
|
|
|
|
|
final_index_tensors.append(transform_tensor)
|
|
|
|
|
|
|
|
if i in slice_positions:
|
|
|
|
|
|
|
|
slice_tensor = const_utils.convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name)
|
|
|
|
|
|
|
|
final_index_tensors.append(slice_tensor)
|
|
|
|
|
|
|
|
slice_number += 1
|
|
|
|
|
|
|
|
if i == ellipsis_position:
|
|
|
|
|
|
|
|
ellipsis_tensors = const_utils.convert_ellipsis_to_tensors(
|
|
|
|
|
|
|
|
slice_number, ellipsis_occupied_dims, final_shape, indexes_shapes_info, op_name)
|
|
|
|
|
|
|
|
for ele in ellipsis_tensors:
|
|
|
|
|
|
|
|
final_index_tensors.append(ele)
|
|
|
|
|
|
|
|
slice_number += ellipsis_occupied_dims
|
|
|
|
|
|
|
|
indices = pack(final_index_tensors)
|
|
|
|
|
|
|
|
return indices
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name):
|
|
|
|
def _generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name):
|
|
|
|
"""Generate an indices tensor from a tuple that contains slice, int, ellipsis, tensor."""
|
|
|
|
"""Generate an indices tensor from a tuple that contains slice, int, ellipsis, tensor."""
|
|
|
|
data_shape = F.shape(data)
|
|
|
|
data_shape = F.shape(data)
|
|
|
@ -160,6 +222,8 @@ def _tensor_getitem(self, index):
|
|
|
|
return tensor_index_by_tensor(self, index)
|
|
|
|
return tensor_index_by_tensor(self, index)
|
|
|
|
if isinstance(index, tuple):
|
|
|
|
if isinstance(index, tuple):
|
|
|
|
return tensor_index_by_tuple(self, index)
|
|
|
|
return tensor_index_by_tuple(self, index)
|
|
|
|
|
|
|
|
if isinstance(index, list):
|
|
|
|
|
|
|
|
return tensor_index_by_list(self, index)
|
|
|
|
# bool type should be judged before int
|
|
|
|
# bool type should be judged before int
|
|
|
|
if isinstance(index, bool):
|
|
|
|
if isinstance(index, bool):
|
|
|
|
return _tensor_index_by_bool(self, index)
|
|
|
|
return _tensor_index_by_bool(self, index)
|
|
|
@ -187,6 +251,13 @@ def _tensor_getitem_by_tuple_of_tensor(data, tuple_index):
|
|
|
|
return result
|
|
|
|
return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _tensor_getitem_by_tuple(data, tuple_index):
|
|
|
|
|
|
|
|
"""Tensor getitem by a tuple of mixed tensor."""
|
|
|
|
|
|
|
|
indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_GETITEM)
|
|
|
|
|
|
|
|
result = F.gather_nd(data, indices)
|
|
|
|
|
|
|
|
return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index):
|
|
|
|
def _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index):
|
|
|
|
"""Tensor getitem by a tuple of mixed tensor."""
|
|
|
|
"""Tensor getitem by a tuple of mixed tensor."""
|
|
|
|
indices = _generate_indices_from_tuple_of_mixed_tensors(data,
|
|
|
|
indices = _generate_indices_from_tuple_of_mixed_tensors(data,
|
|
|
@ -273,12 +344,12 @@ def tensor_index_by_tuple(data, tuple_index):
|
|
|
|
if len(tuple_index) == 1:
|
|
|
|
if len(tuple_index) == 1:
|
|
|
|
return data[tuple_index_without_none[0]]
|
|
|
|
return data[tuple_index_without_none[0]]
|
|
|
|
indexes_types = hyper_map(F.typeof, tuple_index_without_none)
|
|
|
|
indexes_types = hyper_map(F.typeof, tuple_index_without_none)
|
|
|
|
tensor_cnt = const_utils.tuple_index_tensor_cnt(indexes_types, const_utils.TENSOR_GETITEM)
|
|
|
|
contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_GETITEM)
|
|
|
|
if tensor_cnt == const_utils.NO_TENSOR:
|
|
|
|
if contain_type == const_utils.ALL_TENSOR:
|
|
|
|
return _tensor_index_by_tuple_slice(data, tuple_index_without_none)
|
|
|
|
return _tensor_getitem_by_tuple_of_tensor(data, tuple_index)
|
|
|
|
if tensor_cnt == const_utils.ALL_TENSOR:
|
|
|
|
if contain_type == const_utils.ALL_BASIC:
|
|
|
|
return _tensor_getitem_by_tuple_of_tensor(data, tuple_index_without_none)
|
|
|
|
return _tensor_index_by_tuple_slice(data, tuple_index)
|
|
|
|
return _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index_without_none)
|
|
|
|
return _tensor_getitem_by_tuple(data, tuple_index_without_none)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _tensor_setitem(self, index, value):
|
|
|
|
def _tensor_setitem(self, index, value):
|
|
|
|