!595 Tensor assign with int or tuple(int) index

Merge pull request !595 from candanzg/tensor_assign_with_integer
pull/595/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit a02eb240e9

@ -15,6 +15,7 @@
"""constexpr util""" """constexpr util"""
from functools import reduce
import numpy as np import numpy as np
from ...primitive import constexpr from ...primitive import constexpr
from ....common.tensor import Tensor from ....common.tensor import Tensor
@ -23,26 +24,27 @@ from ...._extends.utils import Slice
@constexpr @constexpr
def check_equal(param1, param2, msg="{},{}"): def check_equal(param1, param2, msg="{},{}"):
"""Checks whether the two parameters are equal or not."""
if param1 != param2: if param1 != param2:
raise ValueError(msg.format(param1, param2)) raise ValueError(msg.format(param1, param2))
return param1 return param1
@constexpr @constexpr
def check_tensor_setitem_index(index, element_type=None): def check_tensor_setitem_index(index, element_type=None):
"""Check tuple index type of tensor assignment.""" """Checks tuple index type of tensor assignment."""
if index is None: if index is None:
raise ValueError("Tensor's index cannot be None.") raise ValueError("Tensor's index cannot be None.")
# eg. Tensor[Slice] = u # eg. Tensor[Slice] = u
if isinstance(index, Slice): if isinstance(index, Slice):
return True return True
# eg. Tensor[Tuple] = u # eg. Tensor[tuple] = u
if isinstance(index, tuple): if isinstance(index, tuple):
if not index: if not index:
raise ValueError("Tensor's index cannot be empty.") raise ValueError("Tensor's index cannot be empty.")
# eg. Tensor[Tuple(Slice...)] = u # eg. Tensor[tuple(Slice...)] = u
if not isinstance(index[0], Slice): if isinstance(index[0], (Slice, int)):
raise ValueError("Index of type '{}' is not supported yet.".format(type(index[0]))) return True
return True raise ValueError("Index of type '{}' is not supported yet.".format(type(index[0])))
# eg. Tensor[Tensor[dtype=bool]] = u # eg. Tensor[Tensor[dtype=bool]] = u
if index == mstype.tensor: if index == mstype.tensor:
if element_type is None or element_type != mstype.bool_: if element_type is None or element_type != mstype.bool_:
@ -57,7 +59,7 @@ def check_tensor_setitem_index(index, element_type=None):
@constexpr @constexpr
def is_same_type(inst, type_): def is_same_type(inst, type_):
""" """
Check whether an object is an instance of a target type. Checks whether an object is an instance of a target type.
Inputs: Inputs:
inst (mindspore.dtype): Inspected type. inst (mindspore.dtype): Inspected type.
@ -69,34 +71,23 @@ def is_same_type(inst, type_):
return inst == type_ return inst == type_
@constexpr
def error_msg(msg="", format_values=""):
"""
Used to throw exception information.
Inputs:
msg (str): information content.
"""
raise ValueError(msg.format(*format_values))
def slice_expand(input_slices, shape): def slice_expand(input_slices, shape):
""" """
Convert slice to indices. Converts slice to indices.
Inputs: Inputs:
slices (List or Tuple(List, ...)): Slice tuple or slice. slices (Union[Slice, tuple[Slice]]): Slice tuple or slice.
shape (Tuple): The shape of a sensor is an integer element tuple. shape (tuple): The shape of a sensor is an integer element tuple.
Outputs: Outputs:
(List, List, List), This is expressed as (begins, ends, strides). tuple[list], This is expressed as (begins, ends, strides).
""" """
begin = [] begin = []
end = [] end = []
strides = [] strides = []
index = 0 index = 0
slices = None slices = None
# Slice or Tuple(Slice...) # Slice or tuple(Slice...)
if isinstance(input_slices, Slice): if isinstance(input_slices, Slice):
slices = (input_slices,) slices = (input_slices,)
elif isinstance(input_slices, (tuple, list)) and input_slices and isinstance(input_slices[0], Slice): elif isinstance(input_slices, (tuple, list)) and input_slices and isinstance(input_slices[0], Slice):
@ -119,14 +110,15 @@ def slice_expand(input_slices, shape):
index += 1 index += 1
return begin, end, strides return begin, end, strides
@constexpr @constexpr
def slice2indices(input_slices, shape): def slice2indices(input_slices, shape):
""" """
Convert slice to indices. Converts slice to indices.
Inputs: Inputs:
slices (List or Tuple(List, ...)): Slice tuple or slice. slices (Union[Slice, tuple[Slice]]): Slice tuple or slice.
shape (Tuple): The shape of a sensor is an integer element tuple. shape (tuple): The shape of a tensor is an integer element tuple.
Outputs: Outputs:
Tensor, the shape is (n, 1). Tensor, the shape is (n, 1).
@ -145,6 +137,7 @@ def slice2indices(input_slices, shape):
@constexpr @constexpr
def check_indices(indices_size, index): def check_indices(indices_size, index):
"""Checks indices whether is empty."""
if indices_size < 1: if indices_size < 1:
raise ValueError("The tensor's index is unreasonable. index:{}".format(index)) raise ValueError("The tensor's index is unreasonable. index:{}".format(index))
return indices_size return indices_size
@ -152,6 +145,7 @@ def check_indices(indices_size, index):
@constexpr @constexpr
def check_indices_value_size(indices_size, value_size): def check_indices_value_size(indices_size, value_size):
"""Checks if the sizes are already matched."""
if value_size < 1: if value_size < 1:
raise ValueError("The value assigned to tensor cannot be empty.") raise ValueError("The value assigned to tensor cannot be empty.")
if value_size > 1: if value_size > 1:
@ -160,3 +154,30 @@ def check_indices_value_size(indices_size, value_size):
"The value given to tensor does not match the index size. \ "The value given to tensor does not match the index size. \
value size:{}, indics size:{}".format(value_size, indices_size)) value size:{}, indics size:{}".format(value_size, indices_size))
return value_size return value_size
@constexpr
def integer_to_indices(index, shape):
"""Converts int or tuple[int] to indices."""
size = reduce(lambda x, y: x * y, shape)
range_ = np.arange(size).reshape(shape)
value = range_[index]
value = value.reshape(-1, 1)
return Tensor(value, dtype=mstype.int32)
@constexpr
def tuple_element_is_slice(indexs):
"""Judges tuple element type."""
if not indexs:
raise ValueError("Tensor's index cannot be empty.")
if isinstance(indexs, tuple) and isinstance(indexs[0], Slice):
return True
return False
@constexpr
def tuple_element_is_int(indexs):
"""Judges tuple element type."""
if not indexs:
raise ValueError("Tensor's index cannot be empty.")
if isinstance(indexs, tuple) and isinstance(indexs[0], int):
return True
return False

File diff suppressed because it is too large Load Diff

@ -139,7 +139,7 @@ class TensorAssignWithSlice(Cell):
z = a z = a
return z return z
def test_tensor_assign_with_slice(): def test_tensor_assign():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True) context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
net = TensorAssignWithSlice() net = TensorAssignWithSlice()
net2= TensorAssignWithSlice2() net2= TensorAssignWithSlice2()
@ -148,6 +148,7 @@ def test_tensor_assign_with_slice():
a = np.arange(60).reshape(3,4,5) a = np.arange(60).reshape(3,4,5)
b = Tensor([1]) b = Tensor([1])
Ta = Tensor(a) Ta = Tensor(a)
Ta4d = Tensor(a.reshape(1,3,4,5))
Tb= Tensor([1,3]) Tb= Tensor([1,3])
Tc= Tensor([]) Tc= Tensor([])
t = Tensor([1, 2, 3, 4, 5, 6, 7, 8]) t = Tensor([1, 2, 3, 4, 5, 6, 7, 8])
@ -185,6 +186,47 @@ def test_tensor_assign_with_slice():
with pytest.raises(ValueError): with pytest.raises(ValueError):
net_e1(Ta, 2) net_e1(Ta, 2)
net = TensorAssignWithInteger()
# Error for A[Number] = scalar/Tensor
# 1. A[Number] = U, U is a Tensor, u.size not match
with pytest.raises(ValueError):
net(Ta, Tb)
with pytest.raises(ValueError):
net(Ta, Tc)
# 2. A[Number] = U, the number index error
with pytest.raises(IndexError):
net(Ta4d, b)
# Error for A[(n,m)] = scalar/Tensor
# 1. A[(n,m)] = U, U is a tensor. u.size not match
net = TensorAssignWithTupleInteger()
with pytest.raises(ValueError):
net(Ta, Tc)
with pytest.raises(ValueError):
net(Ta, Tb)
# 2. A[(n,m)] = U, the number index error
with pytest.raises(IndexError):
net(Ta4d, b)
class TensorAssignWithInteger(Cell):
def __init__(self):
super(TensorAssignWithInteger, self).__init__()
def construct(self, a, b):
a[1] = 1
a[0] = b
return a
class TensorAssignWithTupleInteger(Cell):
def __init__(self):
super(TensorAssignWithTupleInteger, self).__init__()
def construct(self, a, b):
a[(1)] = 1
a[(1)] = b
a[(1,1)] = b
a[(1,1)] = 1
return a
class TensorAssignWithBoolTensorIndex(Cell): class TensorAssignWithBoolTensorIndex(Cell):
def __init__(self): def __init__(self):
@ -274,6 +316,14 @@ def test_tensor_assign_bool_index():
net4(Ta, u_scalar) net4(Ta, u_scalar)
test_cases = [ test_cases = [
('TensorAssignWithTupleInteger', {
'block': TensorAssignWithTupleInteger(),
'desc_inputs': [Ta, u_tensor],
}),
('TensorAssignWithInteger', {
'block': TensorAssignWithInteger(),
'desc_inputs': [Ta, u_tensor],
}),
('TensorAssignWithSlice', { ('TensorAssignWithSlice', {
'block': TensorAssignWithSlice(), 'block': TensorAssignWithSlice(),
'desc_inputs': [Ta, u_tensor], 'desc_inputs': [Ta, u_tensor],

Loading…
Cancel
Save