|
|
|
@ -137,13 +137,37 @@ def _expand_data_dims_with_none(data, tuple_index, op_name):
|
|
|
|
|
none_type_tag = const_utils.judge_index_type(index_type, mstype.type_none)
|
|
|
|
|
tuple_index_without_none += (const_utils.make_empty_slice(),) if none_type_tag else(index,)
|
|
|
|
|
none_positions += (i,) if none_type_tag else ()
|
|
|
|
|
|
|
|
|
|
for dim in none_positions:
|
|
|
|
|
data = F.expand_dims(data, dim)
|
|
|
|
|
|
|
|
|
|
return data, tuple_index_without_none
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _expand_data_dims_with_bool(data, tuple_index, op_name):
|
|
|
|
|
"""expand the data's dim with 'True/False' in tuple_index"""
|
|
|
|
|
indexes_types = hyper_map(F.typeof, tuple_index)
|
|
|
|
|
bool_positions, tuple_index_without_bool = (), ()
|
|
|
|
|
|
|
|
|
|
for i, (index, index_type) in enumerate(zip(tuple_index, indexes_types)):
|
|
|
|
|
bool_type_tag = const_utils.judge_index_type(index_type, mstype.type_bool)
|
|
|
|
|
if bool_type_tag:
|
|
|
|
|
if index:
|
|
|
|
|
tuple_index_without_bool += (const_utils.make_tensor([0], mstype.int64),)
|
|
|
|
|
else:
|
|
|
|
|
# todo wait to complete the operations' support for zero dim-size, then could make 0 length tensor.
|
|
|
|
|
# to replace the 'False'
|
|
|
|
|
|
|
|
|
|
return const_utils.raise_index_error("When tensor is indexed by a tuple which contains bool object, "
|
|
|
|
|
"the value only support 'True'.")
|
|
|
|
|
else:
|
|
|
|
|
tuple_index_without_bool += (index,)
|
|
|
|
|
bool_positions += (i,) if bool_type_tag else ()
|
|
|
|
|
|
|
|
|
|
for dim in bool_positions:
|
|
|
|
|
data = F.expand_dims(data, dim)
|
|
|
|
|
|
|
|
|
|
return data, tuple_index_without_bool
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def tensor_index_by_slice(data, slice_index):
|
|
|
|
|
"""Tensor getitem by a single slice"""
|
|
|
|
|
shape = F.shape(data)
|
|
|
|
@ -168,7 +192,7 @@ def _tensor_index_by_bool(data, bool_value):
|
|
|
|
|
"""Tensor getitem by a single bool value"""
|
|
|
|
|
if bool_value:
|
|
|
|
|
return F.expand_dims(data, 0)
|
|
|
|
|
return const_utils.raise_index_error("When tensor is indexed by a bool object, the value only support 'True'.")
|
|
|
|
|
return const_utils.make_tensor([], data.dtype, (0,) + F.shape(data))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _tensor_index_by_integer(data, number):
|
|
|
|
@ -207,8 +231,11 @@ def tensor_index_by_tuple(data, tuple_index):
|
|
|
|
|
op_name = const_utils.TENSOR_GETITEM
|
|
|
|
|
if len(tuple_index) == 1:
|
|
|
|
|
return data[tuple_index[0]]
|
|
|
|
|
|
|
|
|
|
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
|
|
|
|
|
data, tuple_index = _expand_data_dims_with_none(data, tuple_index, op_name)
|
|
|
|
|
data, tuple_index = _expand_data_dims_with_bool(data, tuple_index, op_name)
|
|
|
|
|
|
|
|
|
|
indexes_types = hyper_map(F.typeof, tuple_index)
|
|
|
|
|
contain_type = const_utils.tuple_index_type_cnt(indexes_types, op_name)
|
|
|
|
|
if contain_type == const_utils.ALL_TENSOR:
|
|
|
|
@ -228,8 +255,8 @@ def _tensor_getitem_by_tuple_of_tensor(data, tuple_index):
|
|
|
|
|
def _tensor_getitem_by_tuple_slice(data, tuple_index):
|
|
|
|
|
"""Tensor getitem by a tuple of slice"""
|
|
|
|
|
data_shape = F.shape(data)
|
|
|
|
|
begin_strides, end_strides, step_strides, shrink_axis_mask = \
|
|
|
|
|
const_utils.get_stride_info_from_tuple(data_shape, tuple_index)
|
|
|
|
|
begin_strides, end_strides, step_strides, shrink_axis_mask = const_utils.get_stride_info_from_tuple(
|
|
|
|
|
data_shape, tuple_index)
|
|
|
|
|
return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -259,8 +286,8 @@ def _generate_indices_from_tuple(data, tuple_index, op_name):
|
|
|
|
|
tuple_index_len = len(tuple_index)
|
|
|
|
|
tensor_indexes, slice_indexes = [], []
|
|
|
|
|
indexes_types = hyper_map(F.typeof, tuple_index)
|
|
|
|
|
slice_positions, _, _, int_positions, _, \
|
|
|
|
|
tensor_positions, sequence_positions = const_utils.get_pos_of_indexes_types(indexes_types, op_name)
|
|
|
|
|
slice_positions, _, _, int_positions, _, tensor_positions, sequence_positions = \
|
|
|
|
|
const_utils.get_pos_of_indexes_types(indexes_types, op_name)
|
|
|
|
|
tuple_index_new = ()
|
|
|
|
|
|
|
|
|
|
for i, (index, dim_size) in enumerate(zip(tuple_index, data_shape)):
|
|
|
|
@ -296,8 +323,8 @@ def _generate_indices_from_tuple(data, tuple_index, op_name):
|
|
|
|
|
index_tensor_new_shape = const_utils.compute_new_shape(broadcast_shape, indexes_shapes_info)
|
|
|
|
|
for i in range(tuple_index_len):
|
|
|
|
|
if i in tensor_positions:
|
|
|
|
|
transform_tensor = _transform_indexing_tensor(
|
|
|
|
|
broadcast_shape, final_shape, index_tensor_new_shape, tuple_index_new[i])
|
|
|
|
|
transform_tensor = _transform_indexing_tensor(broadcast_shape, final_shape, index_tensor_new_shape,
|
|
|
|
|
tuple_index_new[i])
|
|
|
|
|
final_index_tensors.append(transform_tensor)
|
|
|
|
|
if i in slice_positions:
|
|
|
|
|
slice_tensor = const_utils.convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name)
|
|
|
|
@ -321,6 +348,7 @@ def _generate_updates_from_tuple(data, index, value, op_type):
|
|
|
|
|
value_types = hyper_map(F.typeof, value)
|
|
|
|
|
data_dtype = F.dtype(data)
|
|
|
|
|
value_elements_type = const_utils.check_value_elements(data_dtype, value_types)
|
|
|
|
|
|
|
|
|
|
if value_elements_type == const_utils.ALL_TENSOR:
|
|
|
|
|
value_shapes = hyper_map(F.shape, value)
|
|
|
|
|
shapes_same = const_utils.check_shapes_same(value_shapes, const_utils.TENSOR_SETITEM)
|
|
|
|
|