Support single bracket setitem

pull/14166/head
yanglf1121 4 years ago
parent b52f0ced25
commit 15820776fc

(File diff suppressed because it is too large.)

@@ -43,6 +43,7 @@ TENSOR_GETITEM = "tensor getitem"
SET_ITEM_BY_ONE_TENSOR = 0
SET_ITEM_BY_TUPLE_OF_TENSOR = 1
SET_ITEM_BY_NON_TENSOR = 2
@constexpr
@@ -74,10 +75,85 @@ def make_empty_slice():
@constexpr
-def make_tensor(data, data_type=mstype.int64, data_shape=None):
+def _deep_list(array_like):
+    """convert nested tuple/list mixtures to pure nested list"""
+    if isinstance(array_like, (list, tuple)):
+        return list(map(_deep_list, array_like))
+    return array_like
+
+
+@constexpr
+def deep_tuple(array_like):
+    """convert nested tuple/list mixtures to pure nested tuple"""
+    if isinstance(array_like, (list, tuple)):
+        return tuple(map(deep_tuple, array_like))
+    return array_like
+
+
+def _deep_tensor_to_nparray(array_like):
+    """
+    convert a nested list of tensor to nested list of np_array.
+
+    Args:
+        array_like(list(tensor)): In any format of nested lists that may contain
+            tensors.
+
+    Returns:
+        array_like(list(np_array)): Formatted array that can be directly processed
+            by numpy.array(), with all tensor elements converted to numpy_array.
+    """
+    # Recursively check whether each element is a tensor or not, if is tensor,
+    # convert it to a numpy array in place
+    if isinstance(array_like, Tensor):
+        return array_like.asnumpy()
+    if isinstance(array_like, list):
+        for idx, value in enumerate(array_like):
+            array_like[idx] = _deep_tensor_to_nparray(value)
+    return array_like
+
+
+@constexpr
+def make_tensor(a, dtype=mstype.int32, data_shape=None):
+    """
+    Converts the input to tensor.
+
+    This function converts tensors from an array-like object.
+
+    Args:
+        a (Union[int, float, bool, list, tuple]): Input data, in any form that can
+            be converted to a `Tensor`.
+        dtype (:class:`mindspore.dtype`): Designated tensor dtype.
+
+    Returns:
+        Tensor, generated tensor with the specified dtype.
+
+    Raises:
+        TypeError: If input arguments have types not specified above.
+        ValueError: If input `a` has different sizes at different dimensions.
+    """
    if data_shape:
-        return Tensor(np.zeros(data_shape), data_type)
-    return Tensor(data, data_type)
+        return Tensor(np.zeros(data_shape), dtype)
+
+    if not isinstance(a, (list, tuple, int, float, bool)):
+        raise TypeError("input data must be `int`, `float`, `bool`, `list` or `tuple`")
+
+    if isinstance(a, (list, tuple)):
+        # Convert all tuple/nested tuples to lists
+        a = _deep_list(a)
+        # Convert all tensor sub-elements to numpy arrays
+        a = _deep_tensor_to_nparray(a)
+        a = np.asarray(a)
+        if a.dtype is np.dtype('object'):
+            raise ValueError('Input array must have the same size across all dimensions.')
+
+    if isinstance(a, np.ndarray):
+        if a.dtype is np.dtype('object'):
+            raise TypeError(f"For Tensor conversion, the input_data is {a} that contains unsupported element.")
+
+    return Tensor(a, dtype)
@constexpr
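A MindSpore-free sketch of the conversion path the rewritten make_tensor takes for list/tuple input (the deep_list helper below is an illustrative stand-in for _deep_list, not part of the patch):

import numpy as np

# Hypothetical stand-in mirroring _deep_list; runnable without MindSpore.
def deep_list(array_like):
    """convert nested tuple/list mixtures to pure nested list"""
    if isinstance(array_like, (list, tuple)):
        return list(map(deep_list, array_like))
    return array_like

a = np.asarray(deep_list(((0, 1), [2, 3])))  # tuple/list mixture -> ndarray
assert a.shape == (2, 2)
if a.dtype is np.dtype('object'):  # ragged input would land here
    raise ValueError('Input array must have the same size across all dimensions.')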
@@ -88,12 +164,20 @@ def judge_data_rank(data_rank, min_data_rank=0, max_data_rank=8):
@constexpr
-def check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size):
-    """Checks the shape and size of the tensor and value."""
-    if data_shape == value_shape or data_size == value_size or value_size == 1:
-        return True
-    raise ValueError("The value(shape={}), can not assign to tensor(shape={}).".format(
-        value_shape, data_shape))
+def get_source_shape(data_shape, value_shape):
+    """Returns the shape of value that will be used to broadcast against data."""
+    cannot_broadcast = False
+    source_shape = value_shape
+    for i, j in zip(reversed(data_shape), reversed(value_shape)):
+        if j not in (1, i):
+            cannot_broadcast = True
+    for i in range(len(value_shape) - len(data_shape)):
+        source_shape = data_shape
+        if value_shape[i] != 1:
+            cannot_broadcast = True
+    if cannot_broadcast:
+        raise ValueError(f'could not broadcast input array from shape {value_shape} to {data_shape}')
+    return source_shape
@constexpr
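get_source_shape hand-rolls the NumPy broadcast rule (trailing dimensions must match or be 1); an illustrative sanity check against NumPy itself:

import numpy as np

data_shape = (2, 3, 4, 5)
np.broadcast_to(np.empty((1, 4, 5)), data_shape)  # ok: trailing dims are 1 or equal

try:
    np.broadcast_to(np.empty((4, 1, 3)), data_shape)  # 3 != 5 -> rejected
except ValueError as err:
    print(err)  # could not broadcast ...

# When value has MORE dims than data and the leading extras are all 1,
# get_source_shape returns data_shape, i.e. the value gets squeezed down.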
@@ -288,8 +372,10 @@ def slice2indices(input_slices, shape):
    begin, end, strides = slice_expand(input_slices, shape)
    np_r = []
    for i, element in enumerate(shape):
-        s = begin[i] if (begin[i] >= 0) else (element + begin[i])
-        e = end[i] if (end[i] >= 0) else (element + end[i])
+        s = normalize_start(begin[i], element)
+        e = normalize_stop(end[i], element)
+        if s >= e:
+            return False
        np_r.append(np.r_[s:e:strides[i]])
    # Reference: np.ravel_multi_index((np.ix_(np.r_[1:3:1], np.r_[0:4:1], np.r_[4:0:-1])), a.shape)
    np_ix = np.ix_(*np_r)
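The switch to normalize_start/normalize_stop makes out-of-range slice bounds clamp the way NumPy's do, instead of producing negative raw indices; a small illustration:

import numpy as np

# Old handling for dim_size=5, begin=-10: s = 5 + (-10) = -5 (an invalid index).
# The new normalization clamps to [0, dim_size], matching NumPy:
x = np.arange(5)
assert np.array_equal(x[-10:3], x[0:3])  # NumPy clamps -10 to 0 the same way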
@@ -364,29 +450,17 @@ def tuple_index_type_cnt(types, op_name):
@constexpr
-def check_value_elements(data_dtype, types):
+def check_value_elements(types):
    """Judges the type of all elements of the tuple."""
-    tensors_number = 0
-    scalars_number = 0
-    for i, ele in enumerate(types):
+    tensor_number = 0
+    for ele in types:
        if isinstance(ele, mstype.tensor_type):
-            ele_dtype = ele.element_type()
-            if data_dtype == ele_dtype:
-                tensors_number += 1
-            else:
-                raise TypeError(f"For '{TENSOR_SETITEM}', the data type of {i}th tensor '{ele_dtype}' "
-                                f"in value tuple is not consistent with assigned tensor data type '{data_dtype}'.")
-        elif mstype.dtype_to_pytype(ele) == mstype.dtype_to_pytype(data_dtype):
-            scalars_number += 1
-        else:
-            raise TypeError(f"For '{TENSOR_SETITEM}', the {i}th element type '{ele}' in "
-                            f"value tuple is not consistent with assigned tensor data type '{data_dtype}'.")
-    if tensors_number == len(types):
+            tensor_number += 1
+    if tensor_number == 0:
+        return NO_TENSOR
+    if tensor_number == len(types):
        return ALL_TENSOR
-    if scalars_number == len(types):
-        return ALL_SCALAR
-    raise TypeError(
-        f"For '{TENSOR_SETITEM}', the value does not support scalar and tensor mixing, but got {types}.")
+    return CONTAIN_TENSOR
@constexpr
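The reworked check_value_elements no longer validates element dtypes against the target tensor; it only counts tensor elements and classifies the tuple. A hypothetical, dependency-free sketch of that classification:

# Stand-in constants mirroring NO_TENSOR / ALL_TENSOR / CONTAIN_TENSOR.
NO_TENSOR, ALL_TENSOR, CONTAIN_TENSOR = 0, 1, 2

def classify(is_tensor_flags):
    tensor_number = sum(is_tensor_flags)
    if tensor_number == 0:
        return NO_TENSOR
    if tensor_number == len(is_tensor_flags):
        return ALL_TENSOR
    return CONTAIN_TENSOR

assert classify([False, False]) == NO_TENSOR
assert classify([True, True]) == ALL_TENSOR
assert classify([True, False]) == CONTAIN_TENSOR  # mixing is now allowed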
@@ -528,10 +602,7 @@ def convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_ty
        updates_shape = indices_shape + data_shape[1:]
    else:
        updates_shape = indices_shape[:-1] + data_shape[indices_shape[-1]:]
-    if isinstance(value, mstype.dtype_to_pytype(data_dtype)):
-        return Tensor(np.full(updates_shape, value), dtype=data_dtype)
-    raise TypeError(f"For '{TENSOR_SETITEM}', the value type '{value.__class__.__name__}'"
-                    f" is not consistent with the assigned tensor data type {data_dtype}.")
+    return Tensor(np.full(updates_shape, value), dtype=data_dtype)
@constexpr
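Dropping the isinstance check works because np.full casts the fill value to the requested dtype; a quick demonstration:

import numpy as np

# A Python bool/int/float can now fill an update tensor of any numeric dtype.
updates = np.full((2, 3), True, dtype=np.float32)
assert updates.dtype == np.float32
assert updates[0, 0] == 1.0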
@@ -716,3 +787,46 @@ def mstype_eq(x, y):
def scalar_to_tensor(x):
    """Convert a scalar to a tensor"""
    return Tensor(x)
+
+
+@constexpr
+def unpack(x):
+    if isinstance(x, (tuple, list)) and len(x) == 1:
+        return unpack(x[0])
+    return x
+
+
+@constexpr
+def slice_to_tuple(s):
+    return (s.start, s.stop, s.step)
+
+
+@constexpr
+def normalize_start(start, dim_size):
+    """
+    Normalize `start` to a valid index for a dimension of size `dim_size`:
+    `None` becomes 0, in-range negative values are wrapped around, and
+    out-of-range values are clamped to the range [0, dim_size].
+    """
+    if start is None:
+        return 0
+    if start < 0:
+        return 0 if start < -dim_size else start % dim_size
+    return start if start < dim_size else dim_size
+
+
+@constexpr
+def normalize_stop(stop, dim_size):
+    """
+    Normalize `stop` to a valid index for a dimension of size `dim_size`:
+    `None` becomes `dim_size`, in-range negative values are wrapped around,
+    and out-of-range values are clamped to the range [0, dim_size].
+    """
+    if stop is None:
+        return dim_size
+    if stop < 0:
+        return 0 if stop < -dim_size else stop % dim_size
+    return stop if stop < dim_size else dim_size
+
+
+@constexpr
+def is_ellipsis(x):
+    return x is Ellipsis
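For intuition: for unit-step slices, the two normalizers agree with Python's slice.indices. A property sketch (the functions are re-declared here without @constexpr so it runs standalone):

def normalize_start(start, dim_size):
    if start is None:
        return 0
    if start < 0:
        return 0 if start < -dim_size else start % dim_size
    return start if start < dim_size else dim_size

def normalize_stop(stop, dim_size):
    if stop is None:
        return dim_size
    if stop < 0:
        return 0 if stop < -dim_size else stop % dim_size
    return stop if stop < dim_size else dim_size

for dim_size in (1, 5, 8):
    for a in (None, -12, -3, 0, 2, 7, 99):
        for b in (None, -12, -3, 0, 2, 7, 99):
            s, e, _ = slice(a, b, 1).indices(dim_size)
            assert (s, e) == (normalize_start(a, dim_size), normalize_stop(b, dim_size))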

(File diff suppressed because it is too large.)

@@ -321,7 +321,7 @@ def test_setitem_by_mixed_tensors_2():
    assert np.all(out.asnumpy() == (input_np + const))
-class TensorGetItemByMixedTensorsTypeError(Cell):
+class TensorGetItemByMixedTensorsIndexError(Cell):
    def construct(self, x, index_0, index_1):
        ret = x[index_0, index_1, 0:3, ..., 0:5, [1, 2, 3, 4]]
        return ret
@@ -331,8 +331,8 @@ def test_getitem_by_mixedtensor_exception():
    input_ms = Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.int32)
    index_0 = Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32)
    index_1 = Tensor(np.random.randint(4, size=(3, 4, 5)), mstype.int32)
-    net1 = TensorGetItemByMixedTensorsTypeError()
-    with pytest.raises(TypeError):
+    net1 = TensorGetItemByMixedTensorsIndexError()
+    with pytest.raises(IndexError):
        net1(input_ms, index_0, index_1)

@@ -0,0 +1,215 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_tensor_setitem """
import numpy as onp
import pytest

from mindspore import Tensor, context
from mindspore.nn import Cell


def setup_module():
    context.set_context(mode=context.GRAPH_MODE)


def setup_testcase(input_np, case_fn):
    input_ms = Tensor(input_np)

    class TensorSetItem(Cell):
        def construct(self, x):
            return case_fn(x)

    class NumpySetItem():
        def __call__(self, x):
            return case_fn(x)

    out_ms = TensorSetItem()(input_ms)
    out_np = NumpySetItem()(input_np)
    assert onp.all(out_ms.asnumpy() == out_np)


class TensorSetItemByList(Cell):
    def construct(self, x):
        x[[0, 1], [1, 2], [1, 3]] = [3, 4]
        x[([0, 1], [0, 2], [1, 1])] = [10, 5]
        x[[0, 1], ..., [0, 1]] = 4
        return x


class NumpySetItemByList():
    def __call__(self, x):
        x[[0, 1], [1, 2], [1, 3]] = [3, 4]
        x[([0, 1], [0, 2], [1, 1])] = [10, 5]
        x[[0, 1], ..., [0, 1]] = 4
        return x


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_by_list():
    x = onp.ones((2, 3, 4), dtype=onp.float32)

    def cases(x):
        x[[0, 1], [1, 2], [1, 3]] = [3, 4]
        x[([0, 1], [0, 2], [1, 1])] = [10, 5]
        x[[0, 1], ..., [0, 1]] = 4
        return x

    setup_testcase(x, cases)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_with_sequence():
    x = onp.ones((2, 3, 4), dtype=onp.float32)

    def cases(x):
        x[...] = [3]
        x[..., 1] = ([1, 2, 3], [4, 5, 6])
        x[0] = ((0, 1, 2, 3), (4, 5, 6, 7), [8, 9, 10, 11])
        x[1:2] = ((0, 1, 2, 3), (4, 5, 6, 7), [8, 9, 10, 11])
        return x

    setup_testcase(x, cases)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_dtype():
    x = onp.ones((2, 3, 4), dtype=onp.float32)

    def cases(x):
        x[...] = 3
        x[..., 1] = 3.0
        x[0] = True
        x[1:2] = ((0, False, 2, 3), (4.0, 5, 6, 7), [True, 9, 10, 11])
        return x

    setup_testcase(x, cases)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_by_tuple_with_int():
    x = onp.arange(24).reshape(2, 3, 4).astype(onp.float32)

    def cases(x):
        x[..., 2, False, 1] = -1
        x[0, True, 0, None, True] = -2
        x[0, ..., None] = -3
        x[..., 0, None, 1, True, True, None] = -4
        return x

    setup_testcase(x, cases)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_by_tuple_with_list():
    x = onp.arange(24).reshape(2, 3, 4).astype(onp.float32)

    def cases(x):
        x[..., 2, False, 1] = [-1]
        x[0, True, 0, None, True] = [-2, -2, -2, -2]
        x[0, ..., None] = [[-3], [-3], [-3], [-3]]
        x[..., 0, None, 1, True, True, None] = [[[-4]], [[-4]]]
        return x

    setup_testcase(x, cases)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_by_nested_unit_list():
    x = onp.arange(24).reshape(2, 3, 4).astype(onp.float32)

    def cases(x):
        x[[[[0]]], True] = -1
        x[[1], ..., [[[[2]]]]] = -2
        x[0, [[[2]]], [1]] = -3
        return x

    setup_testcase(x, cases)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_with_broadcast():
    x = onp.arange(2*3*4*5*6).reshape(2, 3, 4, 5, 6).astype(onp.float32)
    v1 = onp.full((1, 4, 5), -1).tolist()
    v2 = onp.full((4, 1, 6), -2).tolist()

    def cases(x):
        x[..., 4] = v1
        x[0, 2] = v2
        x[1, 0, ..., 3] = [[-3], [-3], [-3], [-3]]
        x[0, ..., 1, 3, 5] = -4
        return x

    setup_testcase(x, cases)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_mul_by_scalar():
    x = onp.ones((4, 5), dtype=onp.float32)

    def cases(x):
        x[1, :] = x[1, :]*2
        x[:, 2] = x[:, 3]*3.0
        return x

    setup_testcase(x, cases)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_by_slice():
    x = onp.ones((3, 4, 5), dtype=onp.float32)

    def cases(x):
        x[1:2] = 2
        x[-3:1] = 3
        x[-10:3:2] = 4
        x[5:0:3] = 5
        x[5:5:5] = 6
        x[-1:2] = 7
        return x

    setup_testcase(x, cases)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_by_tuple_of_slices():
    x = onp.ones((3, 4, 5), dtype=onp.float32)

    def cases(x):
        x[1:2, 2] = 2
        x[0, -4:1] = 3
        x[1, -10:3:2] = 4
        x[5:0:3, 3] = 5
        x[1:1, 2:2] = 6
        return x

    setup_testcase(x, cases)
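Tying the tests back to the PR title: the feature under test is NumPy-style single-bracket assignment on Tensor in graph mode. A minimal standalone usage sketch, mirroring test_setitem_by_list (assumes a MindSpore install):

import numpy as onp
from mindspore import Tensor, context
from mindspore.nn import Cell

context.set_context(mode=context.GRAPH_MODE)

class SetByList(Cell):
    def construct(self, x):
        x[[0, 1], [1, 2], [1, 3]] = [3, 4]  # single-bracket fancy-index assignment
        return x

x_np = onp.ones((2, 3, 4), dtype=onp.float32)
expected = x_np.copy()
expected[[0, 1], [1, 2], [1, 3]] = [3, 4]

out = SetByList()(Tensor(x_np))
assert onp.all(out.asnumpy() == expected)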