From 5c9982729d5414774e02b32665e571e91b85aadd Mon Sep 17 00:00:00 2001 From: Payne Date: Tue, 22 Dec 2020 17:54:45 +0800 Subject: [PATCH] fancy index getitem --- .../composite/multitype_ops/_compile_utils.py | 83 +++++++++++++++++-- .../multitype_ops/_constexpr_utils.py | 54 ++++++++++++ .../composite/multitype_ops/getitem_impl.py | 30 +++---- ...cy_index.py => test_tensor_fancy_index.py} | 73 ++++++++-------- 4 files changed, 184 insertions(+), 56 deletions(-) rename tests/ut/python/ops/{ test_tensor_fancy_index.py => test_tensor_fancy_index.py} (52%) diff --git a/mindspore/ops/composite/multitype_ops/_compile_utils.py b/mindspore/ops/composite/multitype_ops/_compile_utils.py index 6f8186c627..93e1f8e17a 100644 --- a/mindspore/ops/composite/multitype_ops/_compile_utils.py +++ b/mindspore/ops/composite/multitype_ops/_compile_utils.py @@ -57,6 +57,68 @@ def _generate_indices_from_tuple_of_tensor(data, tuple_index, op_name): return indices +def _generate_indices_from_tuple(data, tuple_index, op_name): + """Generate an indices tensor from a tuple that contains slice, int, ellipsis, tensor.""" + data_shape = F.shape(data) + indexes_types = hyper_map(F.typeof, tuple_index) + int_positions, sequence_positions = const_utils.get_pos_of_int_sequence(indexes_types) + tuple_index_new = () + tuple_len = len(tuple_index) + for i in range(tuple_len): + index = tuple_index[i] + shape = data_shape[i] + if i in int_positions: + int_index = const_utils.check_and_transform_int_index(index, shape, op_name) + tensor_index = F.scalar_to_tensor(int_index, mstype.int32) + tuple_index_new += (tensor_index,) + elif i in sequence_positions: + sequence_index = const_utils.transform_sequence_index(index, shape, op_name) + tensor_index = F.tuple_to_array(sequence_index) + tuple_index_new += (tensor_index,) + else: + tuple_index_new += (index,) + indexes_types = hyper_map(F.typeof, tuple_index_new) + tensor_positions, slice_positions, ellipsis_position = \ + const_utils.separate_mixed_tensors_index(indexes_types, op_name) + tensor_indexes = [] + slice_indexes = [] + for i in tensor_positions: + tensor_indexes.append(tuple_index_new[i]) + for j in slice_positions: + slice_indexes.append(tuple_index_new[j]) + tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes) + tensor_indexes_dtypes = hyper_map(F.dtype, tensor_indexes) + broadcast_shape, final_shape, indexes_shapes_info, ellipsis_occupied_dims = \ + const_utils.generate_index_info_from_tuple_of_mixed_tensors(data_shape, + indexes_types, + tensor_indexes_shapes, + tensor_indexes_dtypes, + slice_indexes, + op_name) + + slice_number = 0 + final_index_tensors = [] + tuple_index_size = len(tuple_index_new) + index_tensor_new_shape = const_utils.compute_new_shape(broadcast_shape, indexes_shapes_info) + for i in range(tuple_index_size): + if i in tensor_positions: + transform_tensor = _transform_indexing_tensor( + broadcast_shape, final_shape, index_tensor_new_shape, tuple_index_new[i]) + final_index_tensors.append(transform_tensor) + if i in slice_positions: + slice_tensor = const_utils.convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name) + final_index_tensors.append(slice_tensor) + slice_number += 1 + if i == ellipsis_position: + ellipsis_tensors = const_utils.convert_ellipsis_to_tensors( + slice_number, ellipsis_occupied_dims, final_shape, indexes_shapes_info, op_name) + for ele in ellipsis_tensors: + final_index_tensors.append(ele) + slice_number += ellipsis_occupied_dims + indices = pack(final_index_tensors) + return indices + + def _generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name): """Generate an indices tensor from a tuple that contains slice, int, ellipsis, tensor.""" data_shape = F.shape(data) @@ -160,6 +222,8 @@ def _tensor_getitem(self, index): return tensor_index_by_tensor(self, index) if isinstance(index, tuple): return tensor_index_by_tuple(self, index) + if isinstance(index, list): + return tensor_index_by_list(self, index) # bool type should be judged before int if isinstance(index, bool): return _tensor_index_by_bool(self, index) @@ -187,6 +251,13 @@ def _tensor_getitem_by_tuple_of_tensor(data, tuple_index): return result +def _tensor_getitem_by_tuple(data, tuple_index): + """Tensor getitem by a tuple of mixed tensor.""" + indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_GETITEM) + result = F.gather_nd(data, indices) + return result + + def _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index): """Tensor getitem by a tuple of mixed tensor.""" indices = _generate_indices_from_tuple_of_mixed_tensors(data, @@ -273,12 +344,12 @@ def tensor_index_by_tuple(data, tuple_index): if len(tuple_index) == 1: return data[tuple_index_without_none[0]] indexes_types = hyper_map(F.typeof, tuple_index_without_none) - tensor_cnt = const_utils.tuple_index_tensor_cnt(indexes_types, const_utils.TENSOR_GETITEM) - if tensor_cnt == const_utils.NO_TENSOR: - return _tensor_index_by_tuple_slice(data, tuple_index_without_none) - if tensor_cnt == const_utils.ALL_TENSOR: - return _tensor_getitem_by_tuple_of_tensor(data, tuple_index_without_none) - return _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index_without_none) + contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_GETITEM) + if contain_type == const_utils.ALL_TENSOR: + return _tensor_getitem_by_tuple_of_tensor(data, tuple_index) + if contain_type == const_utils.ALL_BASIC: + return _tensor_index_by_tuple_slice(data, tuple_index) + return _tensor_getitem_by_tuple(data, tuple_index_without_none) def _tensor_setitem(self, index, value): diff --git a/mindspore/ops/composite/multitype_ops/_constexpr_utils.py b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py index ddfdfc6927..441b307f49 100644 --- a/mindspore/ops/composite/multitype_ops/_constexpr_utils.py +++ b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py @@ -31,6 +31,8 @@ ALL_SCALAR = 3 ALL_INT = 4 NO_INT = 5 CONTAIN_INT = 6 +ALL_BASIC = 7 +MIXED = 8 INT_ = 0 BOOL_ = 1 @@ -307,6 +309,18 @@ def tuple_index_int_cnt(types, op_name): return ALL_INT if int_cnt == len(types) else NO_INT if int_cnt == 0 else CONTAIN_INT +@constexpr +def tuple_index_type_cnt(types, op_name): + """count the tensor type of types which contains the tuple elements' type.""" + tensor_cnt = sum(isinstance(ele, mstype.tensor_type) for ele in types) + basic_cnt = sum(isinstance(ele, (mstype.Int, mstype.ellipsis_type, mstype.slice_type)) for ele in types) + if tensor_cnt == len(types): + return ALL_TENSOR + if basic_cnt == len(types): + return ALL_BASIC + return MIXED + + @constexpr def check_value_elements(data_dtype, types): """Judges the type of all elements of the tuple.""" @@ -501,6 +515,34 @@ def convert_ellipsis_to_tensors(slice_number, return tensor_list +@constexpr +def check_and_transform_int_index(index, shape, op_name): + if index < -shape or index >= shape: + raise IndexError(f"In the \"{op_name}\", the index should in the range [-{shape}, {shape-1}] to fit " + f"the corresponding dim length, but get {index}.") + if index < 0: + index += shape + return index + + +@constexpr +def transform_sequence_index(sequence_index, shape, op_name): + """transform list or tuple with integer and boolean to tuple with integer index""" + bool_count = len(list(filter(lambda index: isinstance(index, bool), sequence_index))) + int_count = len(list(filter(lambda index: isinstance(index, int), sequence_index)))-bool_count + if int_count == 0: + if bool_count == shape: + list_index = list(filter(lambda i: sequence_index[i], range(bool_count))) + else: + raise IndexError("The boolean array should have the same length with the corresponding dimensiton") + else: + list_index = [int(index) for index in sequence_index] + for i, index in enumerate(list_index): + list_index[i] = check_and_transform_int_index(index, shape, op_name) + sub_tuple_index = tuple(list_index) + return sub_tuple_index + + @constexpr def convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name): """Convert a slice to a tensor.""" @@ -702,6 +744,18 @@ def get_pos_of_int_index(indexes_types): return int_positions +@constexpr +def get_pos_of_int_sequence(indexes_types): + """Get int and sequence index positions from the mixed tensors index.""" + int_positions, sequence_positions = [], [] + for i, index_type in enumerate(indexes_types): + if isinstance(index_type, mstype.Int): + int_positions.append(i) + elif isinstance(index_type, (tuple, list)): + sequence_positions.append(i) + return int_positions, sequence_positions + + @constexpr def separate_mixed_tensors_index(indexes_types, op_name): """Separate the position information of tensor and slice and ellipsis from the mixed tensors index.""" diff --git a/mindspore/ops/composite/multitype_ops/getitem_impl.py b/mindspore/ops/composite/multitype_ops/getitem_impl.py index 20d4386849..8ec376162a 100644 --- a/mindspore/ops/composite/multitype_ops/getitem_impl.py +++ b/mindspore/ops/composite/multitype_ops/getitem_impl.py @@ -206,21 +206,6 @@ def _tensor_getitem_by_tensor(data, tensor_index): return compile_utils.tensor_index_by_tensor(data, tensor_index) -@getitem.register("Tensor", "Tuple") -def _tensor_getitem_by_tuple(data, tuple_index): - """ - Getting item of tensor by tuple. - - Inputs: - data (Tensor): A tensor. - tuple_index (tuple): Index in tuple which include ellipsis, slice, int, Tensor, None, list, tuple. - - Outputs: - Tensor, element type is the same as the element type of data. - """ - return compile_utils.tensor_index_by_tuple(data, tuple_index) - - @getitem.register("Tensor", "Ellipsis") def _tensor_getitem_by_ellipsis(data, ellipsis_index): """ @@ -249,3 +234,18 @@ def _tensor_getitem_by_list(data, list_index): Tensor ,same as data. """ return compile_utils.tensor_index_by_list(data, list_index) + + +@getitem.register("Tensor", "Tuple") +def _tensor_getitem_by_tuple(data, tuple_index): + """ + Getting item of tensor by tuple. + + Inputs: + data (Tensor): A tensor. + tuple_index (tuple): Index in tuple which include ellipsis, slice, int, Tensor, None, list, tuple. + + Outputs: + Tensor, element type is the same as the element type of data. + """ + return compile_utils.tensor_index_by_tuple(data, tuple_index) diff --git a/tests/ut/python/ops/ test_tensor_fancy_index.py b/tests/ut/python/ops/test_tensor_fancy_index.py similarity index 52% rename from tests/ut/python/ops/ test_tensor_fancy_index.py rename to tests/ut/python/ops/test_tensor_fancy_index.py index 1d883a9770..70fafde75a 100644 --- a/tests/ut/python/ops/ test_tensor_fancy_index.py +++ b/tests/ut/python/ops/test_tensor_fancy_index.py @@ -21,61 +21,64 @@ from mindspore import dtype as mstype from mindspore.nn import Cell -class NetWorkFancyIndexBoolean(Cell): +class NetWorkFancyIndex(Cell): def __init__(self, index): - super(NetWorkFancyIndexBoolean, self).__init__() + super(NetWorkFancyIndex, self).__init__() self.index = index def construct(self, tensor): return tensor[self.index] -class NetWorkFancyIndexInterger(Cell): - def __init__(self, index): - super(NetWorkFancyIndexInterger, self).__init__() - self.index = index - - def construct(self, tensor): - return tensor[self.index] - - -class NetWorkFancyIndexIntergerBooleanMixed(Cell): - def __init__(self, index): - super(NetWorkFancyIndexIntergerBooleanMixed, self).__init__() - self.index = index - - def construct(self, tensor): - return tensor[self.index] - - -def test_tensor_fancy_index_integer_list(): +def test_tensor_fancy_index_integer_list_graph(): context.set_context(mode=context.GRAPH_MODE, save_graphs=True) index = [0, 2, 1] - net = NetWorkFancyIndexBoolean(index) + net = NetWorkFancyIndex(index) input_np = np.arange(60).reshape(3, 4, 5) input_me = Tensor(input_np, dtype=mstype.float32) - output_me = net(input_me).asnumpy() - output_np = input_np[index] - assert np.allclose(output_np, output_me, 0, 0) + net(input_me) -def test_tensor_fancy_boolean_list(): +def test_tensor_fancy_boolean_list_graph(): context.set_context(mode=context.GRAPH_MODE, save_graphs=True) index = [True, True, False] - net = NetWorkFancyIndexInterger(index) + net = NetWorkFancyIndex(index) input_np = np.arange(60).reshape(3, 4, 5) input_me = Tensor(input_np, dtype=mstype.float32) - output_me = net(input_me).asnumpy() - output_np = input_np[index] - assert np.allclose(output_np, output_me, 0, 0) + net(input_me) -def test_tensor_fancy_integer_boolean_list(): +def test_tensor_fancy_integer_boolean_list_graph(): context.set_context(mode=context.GRAPH_MODE, save_graphs=True) index = [1, 2, True, False] - net = NetWorkFancyIndexIntergerBooleanMixed(index) + net = NetWorkFancyIndex(index) input_np = np.arange(60).reshape(3, 4, 5) input_me = Tensor(input_np, dtype=mstype.float32) - output_me = net(input_me).asnumpy() - output_np = input_np[index] - assert np.allclose(output_np, output_me, 0, 0) + net(input_me) + + +def test_tensor_fancy_integer_list_mixed_graph(): + context.set_context(mode=context.GRAPH_MODE, save_graphs=True) + index = (1, [2, 1, 3], slice(1, 3, 1), ..., 4) + net = NetWorkFancyIndex(index) + input_np = np.arange(3*4*5*6*7*8).reshape(3, 4, 5, 6, 7, 8) + input_me = Tensor(input_np, dtype=mstype.float32) + net(input_me) + + +def test_tensor_fancy_integer_tuple_mixed_graph(): + context.set_context(mode=context.GRAPH_MODE, save_graphs=True) + index = (1, (2, 1, 3), slice(1, 3, 1), ..., 4) + net = NetWorkFancyIndex(index) + input_np = np.arange(3*4*5*6*7*8).reshape(3, 4, 5, 6, 7, 8) + input_me = Tensor(input_np, dtype=mstype.float32) + net(input_me) + + +def test_tensor_fancy_integer_list_tuple_mixed_graph(): + context.set_context(mode=context.GRAPH_MODE, save_graphs=True) + index = (1, [2, 1, 3], (3, 2, 1), slice(1, 3, 1), ..., 4) + net = NetWorkFancyIndex(index) + input_np = np.arange(3*4*5*6*7*8).reshape(3, 4, 5, 6, 7, 8) + input_me = Tensor(input_np, dtype=mstype.float32) + net(input_me)