diff --git a/mindspore/ops/composite/multitype_ops/_compile_utils.py b/mindspore/ops/composite/multitype_ops/_compile_utils.py index fbce7007a3..5a90575e96 100644 --- a/mindspore/ops/composite/multitype_ops/_compile_utils.py +++ b/mindspore/ops/composite/multitype_ops/_compile_utils.py @@ -129,7 +129,7 @@ def _transform_ellipsis_to_slice(data, tuple_index, op_name): return tuple_index_new -def _expand_data_dims(data, tuple_index, op_name): +def _expand_data_dims(data, tuple_index): """expand the data's dim with 'None' and 'Boolean' in tuple_index""" indexes_types = hyper_map(F.typeof, tuple_index) expand_positions, tuple_index_new = (), () @@ -203,8 +203,14 @@ def tensor_index_by_list(data, list_index): indexes_types = hyper_map(F.typeof, list_index) if const_utils.judge_indexes_types(indexes_types, mstype.int_type + (mstype.bool_,)): sub_tuple_index = const_utils.transform_sequence_index(list_index, data_shape[0], const_utils.TENSOR_GETITEM) + if not sub_tuple_index: + data_rank = len(data_shape) + if data_rank == 1: + return const_utils.make_tensor([], data.dtype, ()) + return const_utils.make_tensor([], data.dtype, data_shape[1:]) tensor_index = const_utils.make_tensor(sub_tuple_index, mstype.int64) return F.gather(data, tensor_index, 0) + tuple_index_new = () for index in list_index: tuple_index_new += (index,) @@ -219,7 +225,7 @@ def tensor_index_by_tuple(data, tuple_index): op_name = const_utils.TENSOR_GETITEM tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name) - data, tuple_index = _expand_data_dims(data, tuple_index, op_name) + data, tuple_index = _expand_data_dims(data, tuple_index) data_shape = F.shape(data) data_rank = len(data_shape) @@ -228,6 +234,7 @@ def tensor_index_by_tuple(data, tuple_index): indexes_types = hyper_map(F.typeof, tuple_index) contain_type = const_utils.tuple_index_type_cnt(indexes_types, op_name) + if contain_type == const_utils.ALL_TENSOR: return _tensor_getitem_by_tuple_of_tensor(data, tuple_index, op_name) if contain_type == const_utils.ALL_BASIC: @@ -245,7 +252,9 @@ def _tensor_getitem_by_tuple_of_tensor(data, tuple_index, op_name): tensor_index_shape = hyper_map(F.shape, tuple_index) broadcast_shape = const_utils.generate_broadcast_shape(tensor_index_shape, op_name) if 0 in broadcast_shape: - res_shape = broadcast_shape + data_shape[tuple_index_len:] + res_shape = broadcast_shape + if tuple_index_len < len(data_shape): + res_shape += data_shape[tuple_index_len:] res = const_utils.make_tensor([], data.dtype, res_shape) return res @@ -268,12 +277,68 @@ def _tensor_getitem_by_tuple_slice(data, tuple_index): def _tensor_getitem_by_tuple(data, tuple_index, op_name): """Tensor getitem by a tuple of mixed tensor.""" - indices = _generate_indices_from_tuple(data, tuple_index, op_name) + data_shape = F.shape(data) + data_rank = len(data_shape) + tuple_index_len = len(tuple_index) + tensor_indexes, slice_indexes = [], [] + indexes_types = hyper_map(F.typeof, tuple_index) + slice_positions, _, _, int_positions, _, tensor_positions, sequence_positions = \ + const_utils.get_pos_of_indexes_types(indexes_types, op_name) + tuple_index_new = () + + for i, (index, dim_size) in enumerate(zip(tuple_index, data_shape)): + if i in int_positions: + int_index = const_utils.check_and_transform_int_index(index, dim_size, op_name) + tensor_index = F.scalar_to_tensor(int_index, mstype.int64) + tuple_index_new += (tensor_index,) + tensor_indexes.append(tensor_index) + tensor_positions.append(i) + elif i in sequence_positions: + sequence_index = const_utils.transform_sequence_index(index, dim_size, op_name) + tensor_index = const_utils.make_tensor(sequence_index) + tensor_index = F.cast(tensor_index, mstype.int64) + tuple_index_new += (tensor_index,) + tensor_indexes.append(tensor_index) + tensor_positions.append(i) + elif i in tensor_positions: + const_utils.check_index_type_valid(F.dtype(index), mstype.int_type, op_name) + tensor_index = F.cast(index, mstype.int64) + tuple_index_new += (tensor_index,) + tensor_indexes.append(tensor_index) + elif i in slice_positions: + slice_indexes.append(index) + tuple_index_new += (index,) + + tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes) + tensor_indexes_dtypes = hyper_map(F.dtype, tensor_indexes) + indexes_types = hyper_map(F.typeof, tuple_index_new) + broadcast_shape, final_shape, indexes_shapes_info = const_utils.generate_index_info_from_tuple_of_mixed_tensors( + data_shape, indexes_types, tensor_indexes_shapes, tensor_indexes_dtypes, slice_indexes, op_name) + + if 0 in final_shape: + if tuple_index_len < data_rank: + final_shape = final_shape + data_shape[tuple_index_len:] + return const_utils.make_tensor([], data.dtype, final_shape) + + slice_number = 0 + final_index_tensors = [] + index_tensor_new_shape = const_utils.compute_new_shape(broadcast_shape, indexes_shapes_info) + for i in range(tuple_index_len): + if i in tensor_positions: + transform_tensor = _transform_indexing_tensor(broadcast_shape, final_shape, index_tensor_new_shape, + tuple_index_new[i]) + final_index_tensors.append(transform_tensor) + if i in slice_positions: + slice_tensor = const_utils.convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name) + final_index_tensors.append(slice_tensor) + slice_number += 1 + + indices = pack(final_index_tensors) result = F.gather_nd(data, indices) return result -def _generate_indices_from_tuple_of_tensor(data, tuple_index, op_name): +def _generate_indices_from_tuple_of_tensor(tuple_index, op_name): """Generate an indices tensor from a tuple of tensor.""" indices = None indexes_types = hyper_map(F.dtype, tuple_index) @@ -510,13 +575,13 @@ def tensor_setitem_by_tuple_with_number(data, tuple_index, value): return data op_name = const_utils.TENSOR_GETITEM tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name) - data, tuple_index = _expand_data_dims(data, tuple_index, op_name) + data, tuple_index = _expand_data_dims(data, tuple_index) indexes_types = hyper_map(F.typeof, tuple_index) contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM) if contain_type == const_utils.ALL_TENSOR: - indices = _generate_indices_from_tuple_of_tensor(data, tuple_index, const_utils.TENSOR_SETITEM) + indices = _generate_indices_from_tuple_of_tensor(tuple_index, const_utils.TENSOR_SETITEM) else: int_cnt = const_utils.tuple_index_int_cnt(indexes_types, const_utils.TENSOR_SETITEM) if int_cnt == const_utils.ALL_INT: @@ -572,13 +637,13 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value): return data op_name = const_utils.TENSOR_GETITEM tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name) - data, tuple_index = _expand_data_dims(data, tuple_index, op_name) + data, tuple_index = _expand_data_dims(data, tuple_index) indexes_types = hyper_map(F.typeof, tuple_index) contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM) if contain_type == const_utils.ALL_TENSOR: - indices = _generate_indices_from_tuple_of_tensor(data, tuple_index, const_utils.TENSOR_SETITEM) + indices = _generate_indices_from_tuple_of_tensor(tuple_index, const_utils.TENSOR_SETITEM) else: int_cnt = const_utils.tuple_index_int_cnt(indexes_types, const_utils.TENSOR_SETITEM) if int_cnt == const_utils.ALL_INT: @@ -600,13 +665,13 @@ def tensor_setitem_by_tuple_with_tuple(data, tuple_index, value): return data op_name = const_utils.TENSOR_GETITEM tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name) - data, tuple_index = _expand_data_dims(data, tuple_index, op_name) + data, tuple_index = _expand_data_dims(data, tuple_index) indexes_types = hyper_map(F.typeof, tuple_index) contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM) if contain_type == const_utils.ALL_TENSOR: - indices = _generate_indices_from_tuple_of_tensor(data, tuple_index, const_utils.TENSOR_SETITEM) + indices = _generate_indices_from_tuple_of_tensor(tuple_index, const_utils.TENSOR_SETITEM) else: int_cnt = const_utils.tuple_index_int_cnt(indexes_types, const_utils.TENSOR_SETITEM) if int_cnt == const_utils.ALL_INT: