diff --git a/mindspore/ops/composite/multitype_ops/_compile_utils.py b/mindspore/ops/composite/multitype_ops/_compile_utils.py index 184e23e684..4848084f26 100644 --- a/mindspore/ops/composite/multitype_ops/_compile_utils.py +++ b/mindspore/ops/composite/multitype_ops/_compile_utils.py @@ -23,7 +23,7 @@ from ....common import dtype as mstype from ....common._register_for_tensor import tensor_operator_registry hyper_map = base.HyperMap() -pack = P.Stack(axis=-1) +stack = P.Stack(axis=-1) def _tensor_getitem(self, index): @@ -36,44 +36,35 @@ def _tensor_getitem(self, index): return tensor_index_by_tuple(self, (index,)) raise IndexError(f"Only support integers, slices(`:`), ellipsis(`...`), None, bool, tensor with int, " f"list and tuple ,but got {index} with type {type(index)}.") - +tensor_operator_registry.register("__getitem__", _tensor_getitem) def _tensor_setitem(self, index, value): - """Handle tensor getitem""" + """Handle tensor setitem""" + if not isinstance(value, (int, float, bool, list, tuple, Tensor)): + raise ValueError(f"only support numbers, Tensor, tuple, list as value," + f"but got {value} with type {type(value)}.") + + if isinstance(index, list): + index = format_list_indices(index, self.shape[0]) if isinstance(index, Tensor): - if isinstance(value, (int, float, bool)): - return tensor_setitem_by_tensor_with_number(self, index, value) - if isinstance(value, Tensor): - return tensor_setitem_by_tensor_with_tensor(self, index, value) - if isinstance(value, tuple): - return tensor_setitem_by_tensor_with_tuple(self, index, value) + return tensor_setitem_by_tensor(self, index, value) if isinstance(index, tuple): - if isinstance(value, (int, float, bool)): - return tensor_setitem_by_tuple_with_number(self, index, value) - if isinstance(value, Tensor): - return tensor_setitem_by_tuple_with_tensor(self, index, value) - if isinstance(value, tuple): - return tensor_setitem_by_tuple_with_tuple(self, index, value) + if tuple_indices_have_false(index): + return self + index = format_tuple_indices(index) + return tensor_setitem_by_tuple(self, index, value) + if isinstance(index, bool): + return tensor_setitem_by_bool(self, index, value) if isinstance(index, int): - if isinstance(value, (int, float, bool)): - return tensor_setitem_by_number_with_number(self, index, value) - if isinstance(value, Tensor): - return tensor_setitem_by_number_with_tensor(self, index, value) + return tensor_setitem_by_number(self, index, value) if isinstance(index, slice): - if isinstance(value, (int, float, bool)): - return tensor_setitem_by_slice_with_number(self, index, value) - if isinstance(value, Tensor): - return tensor_setitem_by_slice_with_tensor(self, index, value) - if isinstance(index, bool): - return _tensor_index_by_bool(self, index) + return tensor_setitem_by_slice(self, index, value) if index is ...: - if isinstance(value, (int, float, bool)): - return tensor_setitem_by_ellipsis_with_number(self, index, value) - if isinstance(value, Tensor): - return tensor_setitem_by_ellipsis_with_tensor(self, index, value) + return tensor_setitem_by_ellipsis(self, index, value) + raise IndexError("Tensor setitem index only support integers, slices(`:`), ellipsis(`...`), None, bool\ and tensor with int32, got {} with type{}".format(index, type(index))) - +tensor_operator_registry.register("__setitem__", _tensor_setitem) def _broadcast(broadcast_shape, x): """Broadcast tensor to the required shape.""" @@ -103,7 +94,8 @@ def _transform_ellipsis_to_slice(data, tuple_index, op_name): ellipsis_occupy_dims = data_rank - (len(slice_positions) + len(int_positions) + len(tensor_positions) + len(sequence_positions)) ellipsis_cnt = len(ellipsis_positions) - if (ellipsis_cnt == 0 and ellipsis_occupy_dims < 0) or (ellipsis_cnt > 0 and ellipsis_occupy_dims < 1): + # pylint: disable=chained-comparison + if ellipsis_occupy_dims < 0 and ellipsis_cnt >= 0: const_utils.raise_index_error("For the 'getitem Operator', the data_shape should be no less than the " "tuple index dims") @@ -155,14 +147,6 @@ def tensor_index_by_number(data, number_index): return const_utils.raise_index_error("Only support integers, slices(`:`), ellipsis(`...`), None and bool.") -# TODO wait to remove after setitem by Yang Linfeng -def _tensor_index_by_bool(data, bool_index): - """Tensor getitem by a single bool value""" - if bool_index: - return F.expand_dims(data, 0) - return const_utils.make_tensor([], data.dtype, (0,) + F.shape(data)) - - def _tensor_index_by_integer(data, int_index): """Tensor getitem by a single integer number""" data_shape = F.shape(data) @@ -218,6 +202,31 @@ def tensor_index_by_tuple(data, tuple_index): return _tensor_getitem_by_tuple(data, tuple_index, op_name) +def _tensor_getitem_by_tuple_of_tensor(data, tuple_index, op_name): + """Tensor getitem by a tuple of tensor.""" + data_shape = F.shape(data) + tuple_index_len = len(tuple_index) + + indexes_types = hyper_map(F.dtype, tuple_index) + const_utils.check_indexes_types_valid(indexes_types, mstype.int_type, op_name) + tensor_index_shape = hyper_map(F.shape, tuple_index) + broadcast_shape = const_utils.generate_broadcast_shape(tensor_index_shape, op_name) + if 0 in broadcast_shape: + res_shape = broadcast_shape + if tuple_index_len < len(data_shape): + res_shape += data_shape[tuple_index_len:] + res = const_utils.make_tensor([], data.dtype, res_shape) + return res + + broadcast_tensors = hyper_map(F.partial(_broadcast, broadcast_shape), tuple_index) + new_broadcast_tensors = () + for tensor in broadcast_tensors: + new_broadcast_tensors += (F.cast(tensor, mstype.int64),) + indices = stack(new_broadcast_tensors) + result = F.gather_nd(data, indices) + return result + + def _tensor_getitem_by_tuple_slice(data, tuple_index): """Tensor getitem by a tuple of slice""" data_shape = F.shape(data) @@ -284,8 +293,9 @@ def _tensor_getitem_by_tuple(data, tuple_index, op_name): final_index_tensors.append(slice_index_tensor) slice_cnt += 1 - indices = pack(final_index_tensors) - return F.gather_nd(data, indices) + indices = stack(final_index_tensors) + result = F.gather_nd(data, indices) + return result def _generate_indices_from_tuple_of_tensor(tuple_index, op_name): @@ -299,7 +309,7 @@ def _generate_indices_from_tuple_of_tensor(tuple_index, op_name): new_broadcast_tensors = () for tensor in broadcast_tensors: new_broadcast_tensors += (F.cast(tensor, mstype.int64),) - indices = pack(new_broadcast_tensors) + indices = stack(new_broadcast_tensors) return indices @@ -332,6 +342,11 @@ def _generate_indices_from_tuple(data, tuple_index, op_name): tuple_index_new += (tensor_index,) tensor_indexes.append(tensor_index) elif i in slice_positions: + start, stop, _ = const_utils.slice_to_tuple(index) + start = const_utils.normalize_start(start, dim_size) + stop = const_utils.normalize_stop(stop, dim_size) + if start >= stop: + return None slice_ele_list_index = const_utils.transform_slice_to_ele_list(index, dim_size) slice_shapes += (len(slice_ele_list_index),) tuple_index_new += (slice_ele_list_index,) @@ -354,7 +369,7 @@ def _generate_indices_from_tuple(data, tuple_index, op_name): final_index_tensors.append(slice_index_tensor) slice_cnt += 1 - indices = pack(final_index_tensors) + indices = stack(final_index_tensors) return indices @@ -366,44 +381,76 @@ def _generate_updates_from_scalar(data, indices, value, op_type): return const_utils.convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type) -def _generate_updates_from_tuple(data, index, value, op_type): - """Generate an updates tensor from a tuple.""" +def _generate_updates_from_sequence(data, index, value, op_type): + """Generate an updates tensor from a tuple, can only handle 1-D tensor/non-tensor mixtures.""" value_types = hyper_map(F.typeof, value) - data_dtype = F.dtype(data) - value_elements_type = const_utils.check_value_elements(data_dtype, value_types) + value_elements_type = const_utils.check_value_elements(value_types) if value_elements_type == const_utils.ALL_TENSOR: - value_shapes = hyper_map(F.shape, value) - shapes_same = const_utils.check_shapes_same(value_shapes, const_utils.TENSOR_SETITEM) - if shapes_same: - value = F.stack(value) - return _generate_updates_from_tensor(data, index, value, op_type) - - data_shape = F.shape(data) - index_shape = F.shape(index) - return const_utils.convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type) + value = F.stack(value).astype(data.dtype) + elif value_elements_type == const_utils.NO_TENSOR: + value = const_utils.make_tensor(value, data.dtype) + else: + new_value = () + for ele in value: + ele = ele if isinstance(ele, Tensor) else const_utils.make_tensor(ele) + new_value += (ele,) + value = F.stack(new_value).astype(data.dtype) + if op_type == const_utils.SET_ITEM_BY_NON_TENSOR: + return value + return _generate_updates_from_tensor(data, index, value, op_type) def _generate_updates_from_tensor(data, index, value, op_type): """Generate an updates tensor from a tensor.""" - data_shape = F.shape(data) - index_shape = F.shape(index) - value_shape = F.shape(value) - data_dtype = F.dtype(data) - value_dtype = F.dtype(value) - updates_shape = value_shape - check_dtype_same = const_utils.check_tensors_dtype_same(data_dtype, value_dtype, const_utils.TENSOR_SETITEM) - if check_dtype_same: - updates_shape = const_utils.generate_updates_shape(data_shape, index_shape, op_type) - need_broadcast = const_utils.check_two_shapes_need_broadcast(updates_shape, value_shape) + value = value.astype(data.dtype) + updates_shape = const_utils.generate_updates_shape(data.shape, index.shape, op_type) + need_broadcast = const_utils.check_two_shapes_need_broadcast(updates_shape, value.shape) if need_broadcast: return _broadcast(updates_shape, value) return value -tensor_operator_registry.register("__getitem__", _tensor_getitem) +# Tensor getitem implementations are above this line, setitem implementations below. -tensor_operator_registry.register("__setitem__", _tensor_setitem) +def tensor_setitem_by_tensor(self, index, value): + if isinstance(value, (int, float, bool)): + return tensor_setitem_by_tensor_with_number(self, index, value) + if isinstance(value, Tensor): + return tensor_setitem_by_tensor_with_tensor(self, index, value) + return tensor_setitem_by_tensor_with_sequence(self, index, value) + + +def tensor_setitem_by_tuple(self, index, value): + if isinstance(value, (int, float, bool)): + return tensor_setitem_by_tuple_with_number(self, index, value) + if isinstance(value, Tensor): + return tensor_setitem_by_tuple_with_tensor(self, index, value) + return tensor_setitem_by_tuple_with_sequence(self, index, value) + + +def tensor_setitem_by_number(self, index, value): + if isinstance(value, (int, float, bool)): + return tensor_setitem_by_number_with_number(self, index, value) + if isinstance(value, Tensor): + return tensor_setitem_by_number_with_tensor(self, index, value) + return tensor_setitem_by_number_with_sequence(self, index, value) + + +def tensor_setitem_by_slice(self, index, value): + if isinstance(value, (int, float, bool)): + return tensor_setitem_by_slice_with_number(self, index, value) + if isinstance(value, Tensor): + return tensor_setitem_by_slice_with_tensor(self, index, value) + return tensor_setitem_by_slice_with_sequence(self, index, value) + + +def tensor_setitem_by_ellipsis(self, index, value): + if isinstance(value, (int, float, bool)): + return tensor_setitem_by_ellipsis_with_number(self, value) + if isinstance(value, Tensor): + return tensor_setitem_by_ellipsis_with_tensor(self, value) + return tensor_setitem_by_ellipsis_with_sequence(self, value) def _tensor_setitem_by_int_tensor_with_tensor(data, index, value): @@ -469,17 +516,16 @@ def tensor_setitem_by_tensor_with_number(data, index, value): return const_utils.raise_index_error("For tensor setitem, indexing tensor dtype only supports bool/int") -def tensor_setitem_by_tensor_with_tuple(data, index, value): +def tensor_setitem_by_tensor_with_sequence(data, index, value): """Assigns the tensor by tensor with tuple value.""" index_dtype = F.dtype(index) const_utils.check_type_valid(index_dtype, (mstype.int32, mstype.int64), const_utils.TENSOR_SETITEM) - result = _tensor_setitem_by_tensor_with_tuple(data, index, value) - return result + return _tensor_setitem_by_tensor_with_sequence(data, index, value) def _tensor_indices_number(data, data_shape, index, indices, value): """Assigns a scalar value to the tensor.""" - data_size = F.size(data) + data_size = F.shape_mul(data.shape) data_dtype = F.dtype(data) indices_size = F.size(indices) indices_size = const_utils.check_indices(indices_size, index) @@ -493,9 +539,9 @@ def _tensor_indices_number(data, data_shape, index, indices, value): return F.select(condition, u, data) -def _tensor_setitem_by_tensor_with_tuple(data, index, value): +def _tensor_setitem_by_tensor_with_sequence(data, index, value): """Set a tensor item by a tensor with a tuple.""" - updates = _generate_updates_from_tuple(data, index, value, const_utils.SET_ITEM_BY_ONE_TENSOR) + updates = _generate_updates_from_sequence(data, index, value, const_utils.SET_ITEM_BY_ONE_TENSOR) index = F.expand_dims(index, -1) return P.TensorScatterUpdate()(data, index, updates) @@ -507,6 +553,8 @@ def tensor_setitem_by_slice_with_number(data, input_slice, value): if check_result: data_shape = F.shape(data) indices = const_utils.slice2indices(input_slice, data_shape) + if indices is False: + return data is_tuple_int = const_utils.tuple_element_is_int(input_slice) if is_tuple_int: indices = const_utils.integer_to_indices(input_slice, data_shape) @@ -516,6 +564,8 @@ def tensor_setitem_by_slice_with_number(data, input_slice, value): def tensor_setitem_by_tuple_with_number(data, tuple_index, value): """Assigns the tensor by tuple with number value.""" + tuple_index = ignore_dim_expand(tuple_index) + if len(tuple_index) == 1: data[tuple_index[0]] = value return data @@ -533,13 +583,15 @@ def tensor_setitem_by_tuple_with_number(data, tuple_index, value): if int_cnt == const_utils.ALL_INT: tuple_index = const_utils.convert_int_to_slice(tuple_index) indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_SETITEM) + if indices is None: + return data updates = _generate_updates_from_scalar(data, indices, value, const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) return P.TensorScatterUpdate()(data, indices, updates) def _tensor_indices_tensor(data, data_shape, index, indices, value): """Assigns a tensor value to the tensor.""" - data_size = F.size(data) + data_size = F.shape_mul(data.shape) data_dtype = F.dtype(data) indices_size = F.size(indices) indices_size = const_utils.check_indices(indices_size, index) @@ -548,7 +600,7 @@ def _tensor_indices_tensor(data, data_shape, index, indices, value): condition = F.reshape(condition_1d, data_shape) condition = F.cast(condition, mstype.bool_) value_fill = None - value_size = F.size(value) + value_size = value.size value_size = const_utils.check_indices_value_size(indices_size, value_size) if value_size == 1: @@ -559,7 +611,7 @@ def _tensor_indices_tensor(data, data_shape, index, indices, value): value_fill = F.reshape(value, (indices_size,)) value_1d = F.scatter_nd(indices, value_fill, (data_size,)) u = F.reshape(value_1d, data_shape) - return F.select(condition, u, data) + return F.select(condition, u.astype(data_dtype), data) def tensor_setitem_by_slice_with_tensor(data, input_slice, value): @@ -569,6 +621,8 @@ def tensor_setitem_by_slice_with_tensor(data, input_slice, value): if check_result: data_shape = F.shape(data) indices = const_utils.slice2indices(input_slice, data_shape) + if indices is False: + return data is_tuple_int = const_utils.tuple_element_is_int(input_slice) if is_tuple_int: indices = const_utils.integer_to_indices(input_slice, data_shape) @@ -576,8 +630,18 @@ def tensor_setitem_by_slice_with_tensor(data, input_slice, value): return result +def tensor_setitem_by_slice_with_sequence(data, input_slice, value): + """Assigns a list/tuple value to the tensor by slice.""" + value = _generate_updates_from_sequence(data, input_slice, value, const_utils.SET_ITEM_BY_NON_TENSOR) + return tensor_setitem_by_slice_with_tensor(data, input_slice, value) + + def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value): """Assigns the tensor by tuple with tensor value.""" + value_shape = remove_ignored_dim(tuple_index, F.shape(value), F.rank(data)) + value = F.reshape(value, value_shape) + tuple_index = ignore_dim_expand(tuple_index) + if len(tuple_index) == 1: data[tuple_index[0]] = value return data @@ -600,31 +664,15 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value): new_shape += value.shape value = F.reshape(value, new_shape) indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_SETITEM) + if indices is None: + return data updates = _generate_updates_from_tensor(data, indices, value, const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) return P.TensorScatterUpdate()(data, indices, updates) -def tensor_setitem_by_tuple_with_tuple(data, tuple_index, value): - """Assigns the tensor by tuple with tuple of value.""" - if len(tuple_index) == 1: - data[tuple_index[0]] = value - return data - op_name = const_utils.TENSOR_GETITEM - tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name) - data, tuple_index = _expand_data_dims(data, tuple_index) - - indexes_types = hyper_map(F.typeof, tuple_index) - contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM) - - if contain_type == const_utils.ALL_TENSOR: - indices = _generate_indices_from_tuple_of_tensor(tuple_index, const_utils.TENSOR_SETITEM) - else: - int_cnt = const_utils.tuple_index_int_cnt(indexes_types, const_utils.TENSOR_SETITEM) - if int_cnt == const_utils.ALL_INT: - tuple_index = const_utils.convert_int_to_slice(tuple_index) - indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_SETITEM) - updates = _generate_updates_from_tuple(data, indices, value, const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) - return P.TensorScatterUpdate()(data, indices, updates) +def tensor_setitem_by_tuple_with_sequence(data, tuple_index, value): + value = _generate_updates_from_sequence(data, tuple_index, value, const_utils.SET_ITEM_BY_NON_TENSOR) + return tensor_setitem_by_tuple_with_tensor(data, tuple_index, value) def tensor_setitem_by_number_with_number(data, index, value): @@ -634,6 +682,12 @@ def tensor_setitem_by_number_with_number(data, index, value): return _tensor_indices_number(data, data_shape, index, indices, value) +def tensor_setitem_by_number_with_sequence(data, index, value): + """Assigns a list/tuple value to the tensor by slice.""" + value = _generate_updates_from_sequence(data, index, value, const_utils.SET_ITEM_BY_NON_TENSOR) + return tensor_setitem_by_number_with_tensor(data, index, value) + + def tensor_setitem_by_number_with_tensor(data, index, value): """Assigns the tensor by number with tensor value.""" data_shape = F.shape(data) @@ -641,31 +695,46 @@ def tensor_setitem_by_number_with_tensor(data, index, value): return _tensor_indices_tensor(data, data_shape, index, indices, value) -def tensor_setitem_by_ellipsis_with_number(data, index, value): +def tensor_setitem_by_ellipsis_with_number(data, value): """Assigns the tensor by ellipsis with number value.""" data_shape = F.shape(data) data_dtype = F.dtype(data) return F.fill(data_dtype, data_shape, value) -def tensor_setitem_by_ellipsis_with_tensor(data, index, value): +def tensor_setitem_by_ellipsis_with_tensor(data, value): """Assigns the tensor by ellipsis with tensor value.""" - result = None data_shape = F.shape(data) data_dtype = F.dtype(data) - data_size = F.size(data) + value = value.astype(data_dtype) value_shape = F.shape(value) - value_size = F.size(value) - check_result = const_utils.check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size) - if check_result: - if data_size == value_size: - result = F.reshape(value, data_shape) - result = F.cast(result, data_dtype) - elif value_size == 1: - param1 = F.fill(data_dtype, data_shape, 1) - param2 = F.cast(value, data_dtype) - result = F.tensor_mul(param1, param2) - return result + source_shape = const_utils.get_source_shape(data_shape, value_shape) + value = F.reshape(value, source_shape) + value = _broadcast(data_shape, value) + data = F.cast(value, data_dtype) + return data + + +def tensor_setitem_by_ellipsis_with_sequence(data, value): + """Assigns a list/tuple value to the tensor by ellipsis.""" + value = _generate_updates_from_sequence(data, None, value, const_utils.SET_ITEM_BY_NON_TENSOR) + return tensor_setitem_by_ellipsis_with_tensor(data, value) + + +def tensor_setitem_by_bool(data, index, value): + """Assigns a value to the tensor by boolean.""" + data_shape = F.shape(data) + if not index: + data_shape = (0,) + data_shape + if not isinstance(value, Tensor): + value = _generate_updates_from_sequence(data, index, value, const_utils.SET_ITEM_BY_NON_TENSOR) + value_shape = F.shape(value) + source_shape = const_utils.get_source_shape(data_shape, value_shape) + if index: + value = F.reshape(value, source_shape) + value = _broadcast(data_shape, value) + data = value + return data def tensor_in_sequence(x, y): @@ -675,3 +744,79 @@ def tensor_in_sequence(x, y): if isinstance(i, Tensor) and x.shape == i.shape and x.dtype == i.dtype: result = F.logical_or(F.equal(x, i).all(), result) return result + + +def format_list_indices(list_indices, length): + """Convert list indices to tensor or tuple indices based on its contents.""" + indices_types = hyper_map(F.typeof, list_indices) + # If eyery element in list is bool, it's treated as 1-D bool tensor. + # If every element in list is int(not all bool), it's treated as int tensor. + if const_utils.judge_indexes_types(indices_types, mstype.int_type+(mstype.bool_,)): + list_indices = const_utils.transform_sequence_index(list_indices, length, const_utils.TENSOR_SETITEM) + return const_utils.make_tensor(list_indices) + # If list contains other types(.../list/tuple/None), it's treated as a tuple + return const_utils.deep_tuple(list_indices) + + +def format_tuple_indices(tuple_indices): + """ + Format tuple indices by unpacking high-dimension tuple and removing expand + dimension signs(Bool and None). + """ + res = () + for i in tuple_indices: + if isinstance(i, (list, tuple)): + res += (const_utils.unpack(i),) + else: + res += (i,) + return res + + +def tuple_indices_have_false(tuple_indices): + """Returns True if tuple_indices contains False.""" + for i in tuple_indices: + if i is False: + return True + return False + + +def ignore_dim_expand(idx): + """Filters flags for dimension expansion from idx.""" + res = () + for i in idx: + if not i is True and not i is None: + res += (i,) + if not res: + res = (True,) + return res + + +def remove_ignored_dim(idx, value_shape, data_rank): + """Removes dimensions in value that correspond to dimension expansion flags in index.""" + has_ellipsis = False + has_true = False + cnt_trailing_expanded = 0 + cnt_not_dim_expand = 0 + for i in idx: + if not i is True and not i is None: + cnt_not_dim_expand += 1 + if const_utils.is_ellipsis(i): + has_ellipsis = True + elif has_ellipsis: + if i is None: + cnt_trailing_expanded += 1 + elif i is True and not has_true: + has_true = True + if has_true and cnt_not_dim_expand + 1 < data_rank: + cnt_trailing_expanded += 1 + + if cnt_trailing_expanded == 0: + return value_shape + value_expanded_pos = len(value_shape) - cnt_trailing_expanded + value_expanded_not_unit = False + for i in value_shape[value_expanded_pos:]: + if i != 1: + value_expanded_not_unit = True + if value_expanded_pos < 0 or value_expanded_not_unit: + const_utils.raise_value_error('shape mismatch') + return value_shape[:value_expanded_pos] diff --git a/mindspore/ops/composite/multitype_ops/_constexpr_utils.py b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py index a85984f647..598df06c49 100644 --- a/mindspore/ops/composite/multitype_ops/_constexpr_utils.py +++ b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py @@ -43,6 +43,7 @@ TENSOR_GETITEM = "tensor getitem" SET_ITEM_BY_ONE_TENSOR = 0 SET_ITEM_BY_TUPLE_OF_TENSOR = 1 +SET_ITEM_BY_NON_TENSOR = 2 @constexpr @@ -74,10 +75,85 @@ def make_empty_slice(): @constexpr -def make_tensor(data, data_type=mstype.int64, data_shape=None): +def _deep_list(array_like): + """convert nested tuple/list mixtures to pure nested list""" + if isinstance(array_like, (list, tuple)): + return list(map(_deep_list, array_like)) + return array_like + + +@constexpr +def deep_tuple(array_like): + """convert nested tuple/list mixtures to pure nested tuple""" + if isinstance(array_like, (list, tuple)): + return tuple(map(deep_tuple, array_like)) + return array_like + + +def _deep_tensor_to_nparray(array_like): + """ + convert a nested list of tensor to nested list of np_array. + + Args: + array_like(list(tensor)): In any format of nested lists that may contain + tensors. + + Returns: + array_like(list(np_array)): Formatted array that can be directly processed + by numpy.array(), with all tensor elements converted to numpy_array. + """ + # Recursively check whether each element is a tensor or not, if is tensor, + # convert it to a numpy array in place + if isinstance(array_like, Tensor): + return array_like.asnumpy() + + if isinstance(array_like, list): + for idx, value in enumerate(array_like): + array_like[idx] = _deep_tensor_to_nparray(value) + + return array_like + + +@constexpr +def make_tensor(a, dtype=mstype.int32, data_shape=None): + """ + Converts the input to tensor. + + This function converts tensors from an array-like object. + + Args: + a (Union[int, float, bool, list, tuple]): Input data, in any form that can + be converted to a `Tensor`. + dtype (:class:`mindspore.dtype`): Designated tensor dtype. + + Returns: + Tensor, generated tensor with the specified dtype. + + Raises: + TypeError: If input arguments have types not specified above. + ValueError: If input `a` has different sizes at different dimensions. + """ + if data_shape: - return Tensor(np.zeros(data_shape), data_type) - return Tensor(data, data_type) + return Tensor(np.zeros(data_shape), dtype) + + if not isinstance(a, (list, tuple, int, float, bool)): + raise TypeError("input data must be `int`, `float`, `bool`, `list` or `tuple`") + + if isinstance(a, (list, tuple)): + # Convert all tuple/nested tuples to lists + a = _deep_list(a) + # Convert all tensor sub-elements to numpy arrays + a = _deep_tensor_to_nparray(a) + a = np.asarray(a) + if a.dtype is np.dtype('object'): + raise ValueError('Input array must have the same size across all dimensions.') + + if isinstance(a, np.ndarray): + if a.dtype is np.dtype('object'): + raise TypeError(f"For Tensor conversion, the input_data is {a} that contains unsupported element.") + + return Tensor(a, dtype) @constexpr @@ -88,12 +164,20 @@ def judge_data_rank(data_rank, min_data_rank=0, max_data_rank=8): @constexpr -def check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size): - """Checks the shape and size of the sensor and value.""" - if data_shape == value_shape or data_size == value_size or value_size == 1: - return True - raise ValueError("The value(shape={}), can not assign to tensor(shape={}).".format( - value_shape, data_shape)) +def get_source_shape(data_shape, value_shape): + """Returns the shape of value that will be used to broadcast against data.""" + cannot_broadcast = False + source_shape = value_shape + for i, j in zip(reversed(data_shape), reversed(value_shape)): + if j not in (1, i): + cannot_broadcast = True + for i in range(len(value_shape) - len(data_shape)): + source_shape = data_shape + if value_shape[i] != 1: + cannot_broadcast = True + if cannot_broadcast: + raise ValueError(f'could not broadcast input array from shape {value_shape} to {data_shape}') + return source_shape @constexpr @@ -288,8 +372,10 @@ def slice2indices(input_slices, shape): begin, end, strides = slice_expand(input_slices, shape) np_r = [] for i, element in enumerate(shape): - s = begin[i] if (begin[i] >= 0) else (element + begin[i]) - e = end[i] if (end[i] >= 0) else (element + end[i]) + s = normalize_start(begin[i], element) + e = normalize_stop(end[i], element) + if s >= e: + return False np_r.append(np.r_[s:e:strides[i]]) # Reference: np.ravel_multi_index((np.ix_(np.r_[1:3:1], np.r_[0:4:1], np.r_[4:0:-1])), a.shape) np_ix = np.ix_(*np_r) @@ -364,29 +450,17 @@ def tuple_index_type_cnt(types, op_name): @constexpr -def check_value_elements(data_dtype, types): +def check_value_elements(types): """Judges the type of all elements of the tuple.""" - tensors_number = 0 - scalars_number = 0 - for i, ele in enumerate(types): + tensor_number = 0 + for ele in types: if isinstance(ele, mstype.tensor_type): - ele_dtype = ele.element_type() - if data_dtype == ele_dtype: - tensors_number += 1 - else: - raise TypeError(f"For '{TENSOR_SETITEM}', the data type of {i}th tensor '{ele_dtype}' " - f"in value tuple is not consistent with assigned tensor data type '{data_dtype}'.") - elif mstype.dtype_to_pytype(ele) == mstype.dtype_to_pytype(data_dtype): - scalars_number += 1 - else: - raise TypeError(f"For '{TENSOR_SETITEM}', the {i}th element type '{ele}' in " - f"value tuple is not consistent with assigned tensor data type '{data_dtype}'.") - if tensors_number == len(types): + tensor_number += 1 + if tensor_number == 0: + return NO_TENSOR + if tensor_number == len(types): return ALL_TENSOR - if scalars_number == len(types): - return ALL_SCALAR - raise TypeError( - f"For '{TENSOR_SETITEM}', the value does not support scalar and tensor mixing, but got {types}.") + return CONTAIN_TENSOR @constexpr @@ -528,10 +602,7 @@ def convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_ty updates_shape = indices_shape + data_shape[1:] else: updates_shape = indices_shape[:-1] + data_shape[indices_shape[-1]:] - if isinstance(value, mstype.dtype_to_pytype(data_dtype)): - return Tensor(np.full(updates_shape, value), dtype=data_dtype) - raise TypeError(f"For '{TENSOR_SETITEM}', the value type '{value.__class__.__name__}'" - f" is not consistent with the assigned tensor data type {data_dtype}.") + return Tensor(np.full(updates_shape, value), dtype=data_dtype) @constexpr @@ -716,3 +787,46 @@ def mstype_eq(x, y): def scalar_to_tensor(x): """Convert a scalar to a tensor""" return Tensor(x) + + +@constexpr +def unpack(x): + if isinstance(x, (tuple, list)) and len(x) == 1: + return unpack(x[0]) + return x + + +@constexpr +def slice_to_tuple(s): + return (s.start, s.stop, s.step) + + +@constexpr +def normalize_start(start, dim_size): + """ + Normalize `start` according to the number of dimensions (`dim_size`). + If the number of dimensions is not given, return the original input directly. + """ + if start is None: + return 0 + if start < 0: + return 0 if start < -dim_size else start % dim_size + return start if start < dim_size else dim_size + + +@constexpr +def normalize_stop(stop, dim_size): + """ + Normalize `stop` according to the number of dimensions (`dim_size`). + If the number of dimensions is not given, return the original input directly. + """ + if stop is None: + return dim_size + if stop < 0: + return 0 if stop < -dim_size else stop % dim_size + return stop if stop < dim_size else dim_size + + +@constexpr +def is_ellipsis(x): + return x is Ellipsis diff --git a/mindspore/ops/composite/multitype_ops/setitem_impl.py b/mindspore/ops/composite/multitype_ops/setitem_impl.py index 2b90ef5be8..5cdfefd8c6 100644 --- a/mindspore/ops/composite/multitype_ops/setitem_impl.py +++ b/mindspore/ops/composite/multitype_ops/setitem_impl.py @@ -18,6 +18,7 @@ from . import _compile_utils as compile_utils from ... import functional as F from ...composite import base +from ....common import Tensor setitem = base.MultitypeFuncGraph('setitem') @@ -213,6 +214,9 @@ def _tensor_setitem_by_tuple_with_number(data, tuple_index, value): Outputs: Tensor, element type and shape is same as data. """ + if compile_utils.tuple_indices_have_false(tuple_index): + return data + tuple_index = compile_utils.format_tuple_indices(tuple_index) return compile_utils.tensor_setitem_by_tuple_with_number(data, tuple_index, value) @@ -234,6 +238,9 @@ def _tensor_setitem_by_tuple_with_tensor(data, tuple_index, value): Outputs: Tensor, element type and shape is same as data. """ + if compile_utils.tuple_indices_have_false(tuple_index): + return data + tuple_index = compile_utils.format_tuple_indices(tuple_index) return compile_utils.tensor_setitem_by_tuple_with_tensor(data, tuple_index, value) @@ -246,21 +253,49 @@ def _tensor_setitem_by_tuple_with_tuple(data, tuple_index, value): Syntax support: A[B, C, D] = U. Restraint condition: 1) A is a Tensor, and B, C, D are index Tensors. 2) A B and C could be broadcast. - 3) U is a Tensor. + 3) U is a Tuple. Inputs: data (Tensor): Assigned tensor. index (Tuple): A tuple of tensor, these tensor could be broadcast. - value (Tensor): Assignment tensor, should has the same data type as 'data'. + value (Tuple): Assignment tuple. Outputs: Tensor, element type and shape is same as data. """ - return compile_utils.tensor_setitem_by_tuple_with_tuple(data, tuple_index, value) + if compile_utils.tuple_indices_have_false(tuple_index): + return data + tuple_index = compile_utils.format_tuple_indices(tuple_index) + return compile_utils.tensor_setitem_by_tuple_with_sequence(data, tuple_index, value) + + +@setitem.register("Tensor", "Tuple", "List") +def _tensor_setitem_by_tuple_with_list(data, tuple_index, value): + """ + Tensor assignment. + + Note: + Syntax support: A[B, C, D] = U. + Restraint condition: 1) A is a Tensor, and B, C, D are index Tensors. + 2) A B and C could be broadcast. + 3) U is a List. + + Inputs: + data (Tensor): Assigned tensor. + index (Tuple): A tuple of tensor, these tensor could be broadcast. + value (List): Assignment tuple. + + Outputs: + Tensor, element type and shape is same as data. + """ + if compile_utils.tuple_indices_have_false(tuple_index): + return data + tuple_index = compile_utils.format_tuple_indices(tuple_index) + return compile_utils.tensor_setitem_by_tuple_with_sequence(data, tuple_index, value) @setitem.register("Tensor", "Tensor", "Tuple") -def _tensor_setitem_by_tensor_v2(data, index, value): +def _tensor_setitem_by_tensor_with_tuple(data, index, value): """ Tensor assignment. @@ -272,11 +307,27 @@ def _tensor_setitem_by_tensor_v2(data, index, value): Outputs: Tensor, element type and shape is same as data. """ - return compile_utils.tensor_setitem_by_tensor_with_tuple(data, index, value) + return compile_utils.tensor_setitem_by_tensor_with_sequence(data, index, value) + + +@setitem.register("Tensor", "Tensor", "List") +def _tensor_setitem_by_tensor_with_list(data, index, value): + """ + Tensor assignment. + + Inputs: + data (Tensor): Assigned tensor. + index (Tensor): Tensor of bool type. + value (List): Assignment value. + + Outputs: + Tensor, element type and shape is same as data. + """ + return compile_utils.tensor_setitem_by_tensor_with_sequence(data, index, value) @setitem.register("Tensor", "Slice", "Tensor") -def _tensor_setitem_with_slice_v3(data, input_slice, value): +def _tensor_setitem_by_slice_with_tensor(data, input_slice, value): """ Tensor assignment. @@ -298,7 +349,7 @@ def _tensor_setitem_with_slice_v3(data, input_slice, value): @setitem.register("Tensor", "Slice", "Number") -def _tensor_setitem_with_slice_v1(data, input_slice, value): +def _tensor_setitem_by_slice_with_number(data, input_slice, value): """ Tensor assignment. @@ -319,21 +370,326 @@ def _tensor_setitem_with_slice_v1(data, input_slice, value): return compile_utils.tensor_setitem_by_slice_with_number(data, input_slice, value) +@setitem.register("Tensor", "Slice", "List") +def _tensor_setitem_by_slice_with_list(data, input_slice, value): + """ + Tensor assignment. + + Note: + Syntax support: A[Slice] = u + Restraint condition: A is a Tensor. + Slice like "1:3" + u is a list + + Inputs: + data (Tensor): Assigned tensor. + input_slice (Slice): slice expression. + value (List): Assignment value. + + Outputs: + Tensor, element type and shape is same as data. + """ + return compile_utils.tensor_setitem_by_slice_with_sequence(data, input_slice, value) + + +@setitem.register("Tensor", "Slice", "Tuple") +def _tensor_setitem_by_slice_with_tuple(data, input_slice, value): + """ + Tensor assignment. + + Note: + Syntax support: A[Slice] = u + Restraint condition: A is a Tensor. + Slice like "1:3" + u is a tuple + + Inputs: + data (Tensor): Assigned tensor. + input_slice (Slice): slice expression. + value (Tuple): Assignment value. + + Outputs: + Tensor, element type and shape is same as data. + """ + return compile_utils.tensor_setitem_by_slice_with_sequence(data, input_slice, value) + + + @setitem.register("Tensor", "Number", "Number") -def _tensor_setitem_with_int_v1(data, index, value): +def _tensor_setitem_by_number_with_number(data, index, value): + """ + Tensor assignment. + + Note: + Syntax support: A[Number] = u + Restraint condition: A is a Tensor. + u is a Number. + + Inputs: + data (Tensor): Assigned tensor. + index (Number): An integer index. + value (Tuple): Assignment value. + + Outputs: + Tensor, element type and shape is same as data. + """ + if isinstance(index, bool): + return compile_utils.tensor_setitem_by_bool(data, index, value) return compile_utils.tensor_setitem_by_number_with_number(data, index, value) @setitem.register("Tensor", "Number", "Tensor") -def _tensor_setitem_with_int_v2(data, index, value): +def _tensor_setitem_by_number_with_tensor(data, index, value): + """ + Tensor assignment. + + Note: + Syntax support: A[Number] = u + Restraint condition: A is a Tensor. + u is a Tensor. + + Inputs: + data (Tensor): Assigned tensor. + index (Number): An integer index. + value (Tensor): Assignment value. + + Outputs: + Tensor, element type and shape is same as data. + """ + if isinstance(index, bool): + return compile_utils.tensor_setitem_by_bool(data, index, value) return compile_utils.tensor_setitem_by_number_with_tensor(data, index, value) +@setitem.register("Tensor", "Number", "Tuple") +def _tensor_setitem_by_number_with_tuple(data, index, value): + """ + Tensor assignment. + + Note: + Syntax support: A[Number] = u + Restraint condition: A is a Tensor. + u is a Tuple, with all elements equal in length. + + Inputs: + data (Tensor): Assigned tensor. + index (Number): An integer index. + value (Tuple): Assignment value. + + Outputs: + Tensor, element type and shape is same as data. + """ + if isinstance(index, bool): + return compile_utils.tensor_setitem_by_bool(data, index, value) + return compile_utils.tensor_setitem_by_number_with_sequence(data, index, value) + + +@setitem.register("Tensor", "Number", "List") +def _tensor_setitem_by_number_with_list(data, index, value): + """ + Tensor assignment. + + Note: + Syntax support: A[Number] = u + Restraint condition: A is a Tensor. + u is a List, with all elements equal in length. + + Inputs: + data (Tensor): Assigned tensor. + index (Number): An integer index. + value (List): Assignment value. + + Outputs: + Tensor, element type and shape is same as data. + """ + if isinstance(index, bool): + return compile_utils.tensor_setitem_by_bool(data, index, value) + return compile_utils.tensor_setitem_by_number_with_sequence(data, index, value) + + @setitem.register("Tensor", "Ellipsis", "Number") -def _tensor_setitem_with_ellipsis_v1(data, index, value): - return compile_utils.tensor_setitem_by_ellipsis_with_number(data, index, value) +def _tensor_setitem_by_ellipsis_with_number(data, index, value): + """ + Tensor assignment. + + Note: + Syntax support: A[...] = u + Restraint condition: A is a Tensor. + u is a Number. + Inputs: + data (Tensor): Assigned tensor. + index (Ellipsis): Index is ``...``. + value (Number): Assignment value. + + Outputs: + Tensor, element type and shape is same as data. + """ + return compile_utils.tensor_setitem_by_ellipsis_with_number(data, value) @setitem.register("Tensor", "Ellipsis", "Tensor") -def _tensor_setitem_with_ellipsis_v2(data, index, value): - return compile_utils.tensor_setitem_by_ellipsis_with_tensor(data, index, value) +def _tensor_setitem_by_ellipsis_with_tensor(data, index, value): + """ + Tensor assignment. + + Note: + Syntax support: A[...] = u + Restraint condition: A is a Tensor. + u is a Tensor. + Inputs: + data (Tensor): Assigned tensor. + index (Ellipsis): Index is ``...``. + value (Tensor): Assignment value. + + Outputs: + Tensor, element type and shape is same as data. + """ + return compile_utils.tensor_setitem_by_ellipsis_with_tensor(data, value) + + +@setitem.register("Tensor", "Ellipsis", "List") +def _tensor_setitem_by_ellipsis_with_list(data, index, value): + """ + Tensor assignment. + + Note: + Syntax support: A[...] = u + Restraint condition: A is a Tensor. + u is a List, with all elements equal in length. + Inputs: + data (Tensor): Assigned tensor. + index (Ellipsis): Index is ``...``. + value (Number): Assignment value. + + Outputs: + Tensor, element type and shape is same as data. + """ + return compile_utils.tensor_setitem_by_ellipsis_with_sequence(data, value) + + +@setitem.register("Tensor", "Ellipsis", "Tuple") +def _tensor_setitem_by_ellipsis_with_tuple(data, index, value): + """ + Tensor assignment. + + Note: + Syntax support: A[...] = u + Restraint condition: A is a Tensor. + u is a Tuple, with all elements equal in length. + Inputs: + data (Tensor): Assigned tensor. + index (Ellipsis): Index is ``...``. + value (Number): Assignment value. + + Outputs: + Tensor, element type and shape is same as data. + """ + return compile_utils.tensor_setitem_by_ellipsis_with_sequence(data, value) + + +@setitem.register("Tensor", "List", "Number") +def _tensor_setitem_by_list_with_number(data, index, value): + """ + Tensor assignment. + + Note: + Syntax support: A[List] = u + Restraint condition: A is a Tensor. + u is a Number. + Inputs: + data (Tensor): Assigned tensor. + index (List). + value (Number): Assignment value. + + Outputs: + Tensor, element type and shape is same as data. + """ + # list indices will be converted to tuple or tensor based on its contents. + index = compile_utils.format_list_indices(index, data.shape[0]) + if isinstance(index, Tensor): + return compile_utils.tensor_setitem_by_tensor_with_number(data, index, value) + if compile_utils.tuple_indices_have_false(index): + return data + index = compile_utils.format_tuple_indices(index) + return compile_utils.tensor_setitem_by_tuple_with_number(data, index, value) + + +@setitem.register("Tensor", "List", "Tensor") +def _tensor_setitem_by_list_with_tensor(data, index, value): + """ + Tensor assignment. + + Note: + Syntax support: A[List] = u + Restraint condition: A is a Tensor. + u is a Tensor. + Inputs: + data (Tensor): Assigned tensor. + index (List). + value (Tensor): Assignment value. + + Outputs: + Tensor, element type and shape is same as data. + """ + # list indices will be converted to tuple or tensor based on its contents. + index = compile_utils.format_list_indices(index, data.shape[0]) + if isinstance(index, Tensor): + return compile_utils.tensor_setitem_by_tensor_with_tensor(data, index, value) + if compile_utils.tuple_indices_have_false(index): + return data + index = compile_utils.format_tuple_indices(index) + return compile_utils.tensor_setitem_by_tuple_with_tensor(data, index, value) + + +@setitem.register("Tensor", "List", "Tuple") +def _tensor_setitem_by_list_with_tuple(data, index, value): + """ + Tensor assignment. + + Note: + Syntax support: A[List] = u + Restraint condition: A is a Tensor. + u is a Tuple, with all elements equal in length. + Inputs: + data (Tensor): Assigned tensor. + index (List). + value (Tuple): Assignment value. + + Outputs: + Tensor, element type and shape is same as data. + """ + # list indices will be converted to tuple or tensor based on its contents. + index = compile_utils.format_list_indices(index, data.shape[0]) + if isinstance(index, Tensor): + return compile_utils.tensor_setitem_by_tensor_with_sequence(data, index, value) + if compile_utils.tuple_indices_have_false(index): + return data + index = compile_utils.format_tuple_indices(index) + return compile_utils.tensor_setitem_by_tuple_with_sequence(data, index, value) + + +@setitem.register("Tensor", "List", "List") +def _tensor_setitem_by_list_with_list(data, index, value): + """ + Tensor assignment. + + Note: + Syntax support: A[List] = u + Restraint condition: A is a Tensor. + u is a List, with all elements equal in length. + Inputs: + data (Tensor): Assigned tensor. + index (List). + value (List): Assignment value. + + Outputs: + Tensor, element type and shape is same as data. + """ + # list indices will be converted to tuple or tensor based on its contents. + index = compile_utils.format_list_indices(index, data.shape[0]) + if isinstance(index, Tensor): + return compile_utils.tensor_setitem_by_tensor_with_sequence(data, index, value) + if compile_utils.tuple_indices_have_false(index): + return data + index = compile_utils.format_tuple_indices(index) + return compile_utils.tensor_setitem_by_tuple_with_sequence(data, index, value) diff --git a/tests/st/pynative/test_tensor_index.py b/tests/st/pynative/test_tensor_index.py index 8221fa54ef..ac30b6ae15 100644 --- a/tests/st/pynative/test_tensor_index.py +++ b/tests/st/pynative/test_tensor_index.py @@ -321,7 +321,7 @@ def test_setitem_by_mixed_tensors_2(): assert np.all(out.asnumpy() == (input_np + const)) -class TensorGetItemByMixedTensorsTypeError(Cell): +class TensorGetItemByMixedTensorsIndexError(Cell): def construct(self, x, index_0, index_1): ret = x[index_0, index_1, 0:3, ..., 0:5, [1, 2, 3, 4]] return ret @@ -331,8 +331,8 @@ def test_getitem_by_mixedtensor_exception(): input_ms = Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.int32) index_0 = Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32) index_1 = Tensor(np.random.randint(4, size=(3, 4, 5)), mstype.int32) - net1 = TensorGetItemByMixedTensorsTypeError() - with pytest.raises(TypeError): + net1 = TensorGetItemByMixedTensorsIndexError() + with pytest.raises(IndexError): net1(input_ms, index_0, index_1) diff --git a/tests/st/pynative/test_tensor_setitem.py b/tests/st/pynative/test_tensor_setitem.py new file mode 100644 index 0000000000..b272371ecd --- /dev/null +++ b/tests/st/pynative/test_tensor_setitem.py @@ -0,0 +1,215 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test_tensor_setitem """ +import numpy as onp +import pytest + +from mindspore import Tensor, context +from mindspore.nn import Cell + + +def setup_module(): + context.set_context(mode=context.GRAPH_MODE) + + +def setup_testcase(input_np, case_fn): + input_ms = Tensor(input_np) + + class TensorSetItem(Cell): + def construct(self, x): + return case_fn(x) + + class NumpySetItem(): + def __call__(self, x): + return case_fn(x) + + out_ms = TensorSetItem()(input_ms) + out_np = NumpySetItem()(input_np) + assert onp.all(out_ms.asnumpy() == out_np) + + +class TensorSetItemByList(Cell): + def construct(self, x): + x[[0, 1], [1, 2], [1, 3]] = [3, 4] + x[([0, 1], [0, 2], [1, 1])] = [10, 5] + x[[0, 1], ..., [0, 1]] = 4 + return x + +class NumpySetItemByList(): + def __call__(self, x): + x[[0, 1], [1, 2], [1, 3]] = [3, 4] + x[([0, 1], [0, 2], [1, 1])] = [10, 5] + x[[0, 1], ..., [0, 1]] = 4 + return x + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_setitem_by_list(): + x = onp.ones((2, 3, 4), dtype=onp.float32) + def cases(x): + x[[0, 1], [1, 2], [1, 3]] = [3, 4] + x[([0, 1], [0, 2], [1, 1])] = [10, 5] + x[[0, 1], ..., [0, 1]] = 4 + return x + setup_testcase(x, cases) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_setitem_with_sequence(): + x = onp.ones((2, 3, 4), dtype=onp.float32) + def cases(x): + x[...] = [3] + x[..., 1] = ([1, 2, 3], [4, 5, 6]) + x[0] = ((0, 1, 2, 3), (4, 5, 6, 7), [8, 9, 10, 11]) + x[1:2] = ((0, 1, 2, 3), (4, 5, 6, 7), [8, 9, 10, 11]) + return x + setup_testcase(x, cases) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_setitem_dtype(): + x = onp.ones((2, 3, 4), dtype=onp.float32) + def cases(x): + x[...] = 3 + x[..., 1] = 3.0 + x[0] = True + x[1:2] = ((0, False, 2, 3), (4.0, 5, 6, 7), [True, 9, 10, 11]) + return x + setup_testcase(x, cases) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_setitem_by_tuple_with_int(): + x = onp.arange(24).reshape(2, 3, 4).astype(onp.float32) + def cases(x): + x[..., 2, False, 1] = -1 + x[0, True, 0, None, True] = -2 + x[0, ..., None] = -3 + x[..., 0, None, 1, True, True, None] = -4 + return x + setup_testcase(x, cases) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_setitem_by_tuple_with_list(): + x = onp.arange(24).reshape(2, 3, 4).astype(onp.float32) + def cases(x): + x[..., 2, False, 1] = [-1] + x[0, True, 0, None, True] = [-2, -2, -2, -2] + x[0, ..., None] = [[-3], [-3], [-3], [-3]] + x[..., 0, None, 1, True, True, None] = [[[-4]], [[-4]]] + return x + setup_testcase(x, cases) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_setitem_by_nested_unit_list(): + x = onp.arange(24).reshape(2, 3, 4).astype(onp.float32) + def cases(x): + x[[[[0]]], True] = -1 + x[[1], ..., [[[[2]]]]] = -2 + x[0, [[[2]]], [1]] = -3 + return x + setup_testcase(x, cases) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_setitem_with_broadcast(): + x = onp.arange(2*3*4*5*6).reshape(2, 3, 4, 5, 6).astype(onp.float32) + v1 = onp.full((1, 4, 5), -1).tolist() + v2 = onp.full((4, 1, 6), -2).tolist() + def cases(x): + x[..., 4] = v1 + x[0, 2] = v2 + x[1, 0, ..., 3] = [[-3], [-3], [-3], [-3]] + x[0, ..., 1, 3, 5] = -4 + return x + setup_testcase(x, cases) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_setitem_mul_by_scalar(): + x = onp.ones((4, 5), dtype=onp.float32) + def cases(x): + x[1, :] = x[1, :]*2 + x[:, 2] = x[:, 3]*3.0 + return x + setup_testcase(x, cases) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_setitem_by_slice(): + x = onp.ones((3, 4, 5), dtype=onp.float32) + def cases(x): + x[1:2] = 2 + x[-3:1] = 3 + x[-10:3:2] = 4 + x[5:0:3] = 5 + x[5:5:5] = 6 + x[-1:2] = 7 + return x + setup_testcase(x, cases) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_setitem_by_tuple_of_slices(): + x = onp.ones((3, 4, 5), dtype=onp.float32) + def cases(x): + x[1:2, 2] = 2 + x[0, -4:1] = 3 + x[1, -10:3:2] = 4 + x[5:0:3, 3] = 5 + x[1:1, 2:2] = 6 + return x + setup_testcase(x, cases)