!595 Tensor assign with int or tuple(int) index

Merge pull request !595 from candanzg/tensor_assign_with_integer
pull/595/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit a02eb240e9

@ -15,6 +15,7 @@
"""constexpr util"""
from functools import reduce
import numpy as np
from ...primitive import constexpr
from ....common.tensor import Tensor
@ -23,26 +24,27 @@ from ...._extends.utils import Slice
@constexpr
def check_equal(param1, param2, msg="{},{}"):
"""Checks whether the two parameters are equal or not."""
if param1 != param2:
raise ValueError(msg.format(param1, param2))
return param1
@constexpr
def check_tensor_setitem_index(index, element_type=None):
"""Check tuple index type of tensor assignment."""
"""Checks tuple index type of tensor assignment."""
if index is None:
raise ValueError("Tensor's index cannot be None.")
# eg. Tensor[Slice] = u
if isinstance(index, Slice):
return True
# eg. Tensor[Tuple] = u
# eg. Tensor[tuple] = u
if isinstance(index, tuple):
if not index:
raise ValueError("Tensor's index cannot be empty.")
# eg. Tensor[Tuple(Slice...)] = u
if not isinstance(index[0], Slice):
raise ValueError("Index of type '{}' is not supported yet.".format(type(index[0])))
return True
# eg. Tensor[tuple(Slice...)] = u
if isinstance(index[0], (Slice, int)):
return True
raise ValueError("Index of type '{}' is not supported yet.".format(type(index[0])))
# eg. Tensor[Tensor[dtype=bool]] = u
if index == mstype.tensor:
if element_type is None or element_type != mstype.bool_:
@ -57,7 +59,7 @@ def check_tensor_setitem_index(index, element_type=None):
@constexpr
def is_same_type(inst, type_):
"""
Check whether an object is an instance of a target type.
Checks whether an object is an instance of a target type.
Inputs:
inst (mindspore.dtype): Inspected type.
@ -69,34 +71,23 @@ def is_same_type(inst, type_):
return inst == type_
@constexpr
def error_msg(msg="", format_values=""):
"""
Used to throw exception information.
Inputs:
msg (str): information content.
"""
raise ValueError(msg.format(*format_values))
def slice_expand(input_slices, shape):
"""
Convert slice to indices.
Converts slice to indices.
Inputs:
slices (List or Tuple(List, ...)): Slice tuple or slice.
shape (Tuple): The shape of a sensor is an integer element tuple.
slices (Union[Slice, tuple[Slice]]): Slice tuple or slice.
shape (tuple): The shape of a sensor is an integer element tuple.
Outputs:
(List, List, List), This is expressed as (begins, ends, strides).
tuple[list], This is expressed as (begins, ends, strides).
"""
begin = []
end = []
strides = []
index = 0
slices = None
# Slice or Tuple(Slice...)
# Slice or tuple(Slice...)
if isinstance(input_slices, Slice):
slices = (input_slices,)
elif isinstance(input_slices, (tuple, list)) and input_slices and isinstance(input_slices[0], Slice):
@ -119,14 +110,15 @@ def slice_expand(input_slices, shape):
index += 1
return begin, end, strides
@constexpr
def slice2indices(input_slices, shape):
"""
Convert slice to indices.
Converts slice to indices.
Inputs:
slices (List or Tuple(List, ...)): Slice tuple or slice.
shape (Tuple): The shape of a sensor is an integer element tuple.
slices (Union[Slice, tuple[Slice]]): Slice tuple or slice.
shape (tuple): The shape of a tensor is an integer element tuple.
Outputs:
Tensor, the shape is (n, 1).
@ -145,6 +137,7 @@ def slice2indices(input_slices, shape):
@constexpr
def check_indices(indices_size, index):
"""Checks indices whether is empty."""
if indices_size < 1:
raise ValueError("The tensor's index is unreasonable. index:{}".format(index))
return indices_size
@ -152,6 +145,7 @@ def check_indices(indices_size, index):
@constexpr
def check_indices_value_size(indices_size, value_size):
"""Checks if the sizes are already matched."""
if value_size < 1:
raise ValueError("The value assigned to tensor cannot be empty.")
if value_size > 1:
@ -160,3 +154,30 @@ def check_indices_value_size(indices_size, value_size):
"The value given to tensor does not match the index size. \
value size:{}, indics size:{}".format(value_size, indices_size))
return value_size
@constexpr
def integer_to_indices(index, shape):
"""Converts int or tuple[int] to indices."""
size = reduce(lambda x, y: x * y, shape)
range_ = np.arange(size).reshape(shape)
value = range_[index]
value = value.reshape(-1, 1)
return Tensor(value, dtype=mstype.int32)
@constexpr
def tuple_element_is_slice(indexs):
"""Judges tuple element type."""
if not indexs:
raise ValueError("Tensor's index cannot be empty.")
if isinstance(indexs, tuple) and isinstance(indexs[0], Slice):
return True
return False
@constexpr
def tuple_element_is_int(indexs):
"""Judges tuple element type."""
if not indexs:
raise ValueError("Tensor's index cannot be empty.")
if isinstance(indexs, tuple) and isinstance(indexs[0], int):
return True
return False

File diff suppressed because it is too large Load Diff

@ -139,7 +139,7 @@ class TensorAssignWithSlice(Cell):
z = a
return z
def test_tensor_assign_with_slice():
def test_tensor_assign():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
net = TensorAssignWithSlice()
net2= TensorAssignWithSlice2()
@ -148,6 +148,7 @@ def test_tensor_assign_with_slice():
a = np.arange(60).reshape(3,4,5)
b = Tensor([1])
Ta = Tensor(a)
Ta4d = Tensor(a.reshape(1,3,4,5))
Tb= Tensor([1,3])
Tc= Tensor([])
t = Tensor([1, 2, 3, 4, 5, 6, 7, 8])
@ -185,6 +186,47 @@ def test_tensor_assign_with_slice():
with pytest.raises(ValueError):
net_e1(Ta, 2)
net = TensorAssignWithInteger()
# Error for A[Number] = scalar/Tensor
# 1. A[Number] = U, U is a Tensor, u.size not match
with pytest.raises(ValueError):
net(Ta, Tb)
with pytest.raises(ValueError):
net(Ta, Tc)
# 2. A[Number] = U, the number index error
with pytest.raises(IndexError):
net(Ta4d, b)
# Error for A[(n,m)] = scalar/Tensor
# 1. A[(n,m)] = U, U is a tensor. u.size not match
net = TensorAssignWithTupleInteger()
with pytest.raises(ValueError):
net(Ta, Tc)
with pytest.raises(ValueError):
net(Ta, Tb)
# 2. A[(n,m)] = U, the number index error
with pytest.raises(IndexError):
net(Ta4d, b)
class TensorAssignWithInteger(Cell):
def __init__(self):
super(TensorAssignWithInteger, self).__init__()
def construct(self, a, b):
a[1] = 1
a[0] = b
return a
class TensorAssignWithTupleInteger(Cell):
def __init__(self):
super(TensorAssignWithTupleInteger, self).__init__()
def construct(self, a, b):
a[(1)] = 1
a[(1)] = b
a[(1,1)] = b
a[(1,1)] = 1
return a
class TensorAssignWithBoolTensorIndex(Cell):
def __init__(self):
@ -274,6 +316,14 @@ def test_tensor_assign_bool_index():
net4(Ta, u_scalar)
test_cases = [
('TensorAssignWithTupleInteger', {
'block': TensorAssignWithTupleInteger(),
'desc_inputs': [Ta, u_tensor],
}),
('TensorAssignWithInteger', {
'block': TensorAssignWithInteger(),
'desc_inputs': [Ta, u_tensor],
}),
('TensorAssignWithSlice', {
'block': TensorAssignWithSlice(),
'desc_inputs': [Ta, u_tensor],

Loading…
Cancel
Save