!12127 getitem support shape 0

From: @yepei6
Reviewed-by: @kisnwang,@kingxian
Signed-off-by: @kingxian
pull/12127/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit b6183f718f

@ -129,7 +129,7 @@ def _transform_ellipsis_to_slice(data, tuple_index, op_name):
return tuple_index_new return tuple_index_new
def _expand_data_dims(data, tuple_index, op_name): def _expand_data_dims(data, tuple_index):
"""expand the data's dim with 'None' and 'Boolean' in tuple_index""" """expand the data's dim with 'None' and 'Boolean' in tuple_index"""
indexes_types = hyper_map(F.typeof, tuple_index) indexes_types = hyper_map(F.typeof, tuple_index)
expand_positions, tuple_index_new = (), () expand_positions, tuple_index_new = (), ()
@ -203,8 +203,14 @@ def tensor_index_by_list(data, list_index):
indexes_types = hyper_map(F.typeof, list_index) indexes_types = hyper_map(F.typeof, list_index)
if const_utils.judge_indexes_types(indexes_types, mstype.int_type + (mstype.bool_,)): if const_utils.judge_indexes_types(indexes_types, mstype.int_type + (mstype.bool_,)):
sub_tuple_index = const_utils.transform_sequence_index(list_index, data_shape[0], const_utils.TENSOR_GETITEM) sub_tuple_index = const_utils.transform_sequence_index(list_index, data_shape[0], const_utils.TENSOR_GETITEM)
if not sub_tuple_index:
data_rank = len(data_shape)
if data_rank == 1:
return const_utils.make_tensor([], data.dtype, ())
return const_utils.make_tensor([], data.dtype, data_shape[1:])
tensor_index = const_utils.make_tensor(sub_tuple_index, mstype.int64) tensor_index = const_utils.make_tensor(sub_tuple_index, mstype.int64)
return F.gather(data, tensor_index, 0) return F.gather(data, tensor_index, 0)
tuple_index_new = () tuple_index_new = ()
for index in list_index: for index in list_index:
tuple_index_new += (index,) tuple_index_new += (index,)
@ -219,7 +225,7 @@ def tensor_index_by_tuple(data, tuple_index):
op_name = const_utils.TENSOR_GETITEM op_name = const_utils.TENSOR_GETITEM
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name) tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
data, tuple_index = _expand_data_dims(data, tuple_index, op_name) data, tuple_index = _expand_data_dims(data, tuple_index)
data_shape = F.shape(data) data_shape = F.shape(data)
data_rank = len(data_shape) data_rank = len(data_shape)
@ -228,6 +234,7 @@ def tensor_index_by_tuple(data, tuple_index):
indexes_types = hyper_map(F.typeof, tuple_index) indexes_types = hyper_map(F.typeof, tuple_index)
contain_type = const_utils.tuple_index_type_cnt(indexes_types, op_name) contain_type = const_utils.tuple_index_type_cnt(indexes_types, op_name)
if contain_type == const_utils.ALL_TENSOR: if contain_type == const_utils.ALL_TENSOR:
return _tensor_getitem_by_tuple_of_tensor(data, tuple_index, op_name) return _tensor_getitem_by_tuple_of_tensor(data, tuple_index, op_name)
if contain_type == const_utils.ALL_BASIC: if contain_type == const_utils.ALL_BASIC:
@ -245,7 +252,9 @@ def _tensor_getitem_by_tuple_of_tensor(data, tuple_index, op_name):
tensor_index_shape = hyper_map(F.shape, tuple_index) tensor_index_shape = hyper_map(F.shape, tuple_index)
broadcast_shape = const_utils.generate_broadcast_shape(tensor_index_shape, op_name) broadcast_shape = const_utils.generate_broadcast_shape(tensor_index_shape, op_name)
if 0 in broadcast_shape: if 0 in broadcast_shape:
res_shape = broadcast_shape + data_shape[tuple_index_len:] res_shape = broadcast_shape
if tuple_index_len < len(data_shape):
res_shape += data_shape[tuple_index_len:]
res = const_utils.make_tensor([], data.dtype, res_shape) res = const_utils.make_tensor([], data.dtype, res_shape)
return res return res
@ -268,12 +277,68 @@ def _tensor_getitem_by_tuple_slice(data, tuple_index):
def _tensor_getitem_by_tuple(data, tuple_index, op_name): def _tensor_getitem_by_tuple(data, tuple_index, op_name):
"""Tensor getitem by a tuple of mixed tensor.""" """Tensor getitem by a tuple of mixed tensor."""
indices = _generate_indices_from_tuple(data, tuple_index, op_name) data_shape = F.shape(data)
data_rank = len(data_shape)
tuple_index_len = len(tuple_index)
tensor_indexes, slice_indexes = [], []
indexes_types = hyper_map(F.typeof, tuple_index)
slice_positions, _, _, int_positions, _, tensor_positions, sequence_positions = \
const_utils.get_pos_of_indexes_types(indexes_types, op_name)
tuple_index_new = ()
for i, (index, dim_size) in enumerate(zip(tuple_index, data_shape)):
if i in int_positions:
int_index = const_utils.check_and_transform_int_index(index, dim_size, op_name)
tensor_index = F.scalar_to_tensor(int_index, mstype.int64)
tuple_index_new += (tensor_index,)
tensor_indexes.append(tensor_index)
tensor_positions.append(i)
elif i in sequence_positions:
sequence_index = const_utils.transform_sequence_index(index, dim_size, op_name)
tensor_index = const_utils.make_tensor(sequence_index)
tensor_index = F.cast(tensor_index, mstype.int64)
tuple_index_new += (tensor_index,)
tensor_indexes.append(tensor_index)
tensor_positions.append(i)
elif i in tensor_positions:
const_utils.check_index_type_valid(F.dtype(index), mstype.int_type, op_name)
tensor_index = F.cast(index, mstype.int64)
tuple_index_new += (tensor_index,)
tensor_indexes.append(tensor_index)
elif i in slice_positions:
slice_indexes.append(index)
tuple_index_new += (index,)
tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes)
tensor_indexes_dtypes = hyper_map(F.dtype, tensor_indexes)
indexes_types = hyper_map(F.typeof, tuple_index_new)
broadcast_shape, final_shape, indexes_shapes_info = const_utils.generate_index_info_from_tuple_of_mixed_tensors(
data_shape, indexes_types, tensor_indexes_shapes, tensor_indexes_dtypes, slice_indexes, op_name)
if 0 in final_shape:
if tuple_index_len < data_rank:
final_shape = final_shape + data_shape[tuple_index_len:]
return const_utils.make_tensor([], data.dtype, final_shape)
slice_number = 0
final_index_tensors = []
index_tensor_new_shape = const_utils.compute_new_shape(broadcast_shape, indexes_shapes_info)
for i in range(tuple_index_len):
if i in tensor_positions:
transform_tensor = _transform_indexing_tensor(broadcast_shape, final_shape, index_tensor_new_shape,
tuple_index_new[i])
final_index_tensors.append(transform_tensor)
if i in slice_positions:
slice_tensor = const_utils.convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name)
final_index_tensors.append(slice_tensor)
slice_number += 1
indices = pack(final_index_tensors)
result = F.gather_nd(data, indices) result = F.gather_nd(data, indices)
return result return result
def _generate_indices_from_tuple_of_tensor(data, tuple_index, op_name): def _generate_indices_from_tuple_of_tensor(tuple_index, op_name):
"""Generate an indices tensor from a tuple of tensor.""" """Generate an indices tensor from a tuple of tensor."""
indices = None indices = None
indexes_types = hyper_map(F.dtype, tuple_index) indexes_types = hyper_map(F.dtype, tuple_index)
@ -510,13 +575,13 @@ def tensor_setitem_by_tuple_with_number(data, tuple_index, value):
return data return data
op_name = const_utils.TENSOR_GETITEM op_name = const_utils.TENSOR_GETITEM
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name) tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
data, tuple_index = _expand_data_dims(data, tuple_index, op_name) data, tuple_index = _expand_data_dims(data, tuple_index)
indexes_types = hyper_map(F.typeof, tuple_index) indexes_types = hyper_map(F.typeof, tuple_index)
contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM) contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM)
if contain_type == const_utils.ALL_TENSOR: if contain_type == const_utils.ALL_TENSOR:
indices = _generate_indices_from_tuple_of_tensor(data, tuple_index, const_utils.TENSOR_SETITEM) indices = _generate_indices_from_tuple_of_tensor(tuple_index, const_utils.TENSOR_SETITEM)
else: else:
int_cnt = const_utils.tuple_index_int_cnt(indexes_types, const_utils.TENSOR_SETITEM) int_cnt = const_utils.tuple_index_int_cnt(indexes_types, const_utils.TENSOR_SETITEM)
if int_cnt == const_utils.ALL_INT: if int_cnt == const_utils.ALL_INT:
@ -572,13 +637,13 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
return data return data
op_name = const_utils.TENSOR_GETITEM op_name = const_utils.TENSOR_GETITEM
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name) tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
data, tuple_index = _expand_data_dims(data, tuple_index, op_name) data, tuple_index = _expand_data_dims(data, tuple_index)
indexes_types = hyper_map(F.typeof, tuple_index) indexes_types = hyper_map(F.typeof, tuple_index)
contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM) contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM)
if contain_type == const_utils.ALL_TENSOR: if contain_type == const_utils.ALL_TENSOR:
indices = _generate_indices_from_tuple_of_tensor(data, tuple_index, const_utils.TENSOR_SETITEM) indices = _generate_indices_from_tuple_of_tensor(tuple_index, const_utils.TENSOR_SETITEM)
else: else:
int_cnt = const_utils.tuple_index_int_cnt(indexes_types, const_utils.TENSOR_SETITEM) int_cnt = const_utils.tuple_index_int_cnt(indexes_types, const_utils.TENSOR_SETITEM)
if int_cnt == const_utils.ALL_INT: if int_cnt == const_utils.ALL_INT:
@ -600,13 +665,13 @@ def tensor_setitem_by_tuple_with_tuple(data, tuple_index, value):
return data return data
op_name = const_utils.TENSOR_GETITEM op_name = const_utils.TENSOR_GETITEM
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name) tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
data, tuple_index = _expand_data_dims(data, tuple_index, op_name) data, tuple_index = _expand_data_dims(data, tuple_index)
indexes_types = hyper_map(F.typeof, tuple_index) indexes_types = hyper_map(F.typeof, tuple_index)
contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM) contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM)
if contain_type == const_utils.ALL_TENSOR: if contain_type == const_utils.ALL_TENSOR:
indices = _generate_indices_from_tuple_of_tensor(data, tuple_index, const_utils.TENSOR_SETITEM) indices = _generate_indices_from_tuple_of_tensor(tuple_index, const_utils.TENSOR_SETITEM)
else: else:
int_cnt = const_utils.tuple_index_int_cnt(indexes_types, const_utils.TENSOR_SETITEM) int_cnt = const_utils.tuple_index_int_cnt(indexes_types, const_utils.TENSOR_SETITEM)
if int_cnt == const_utils.ALL_INT: if int_cnt == const_utils.ALL_INT:

Loading…
Cancel
Save