|
|
|
@ -20,7 +20,7 @@ import numpy as np
|
|
|
|
|
from ...primitive import constexpr
|
|
|
|
|
from ....common.tensor import Tensor
|
|
|
|
|
from ....common import dtype as mstype
|
|
|
|
|
from ...._extends.utils import Slice
|
|
|
|
|
from ...._extends.utils import Slice, Ellipsis_
|
|
|
|
|
|
|
|
|
|
@constexpr
|
|
|
|
|
def check_equal(param1, param2, msg="{},{}"):
|
|
|
|
@ -29,31 +29,40 @@ def check_equal(param1, param2, msg="{},{}"):
|
|
|
|
|
raise ValueError(msg.format(param1, param2))
|
|
|
|
|
return param1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@constexpr
|
|
|
|
|
def check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size):
|
|
|
|
|
"""Checks the shape and size of the sensor and value."""
|
|
|
|
|
if data_shape == value_shape or data_size == value_size or value_size == 1:
|
|
|
|
|
return True
|
|
|
|
|
raise ValueError("The value(shape={}), can not assign to tensor(shape={}).".format(value_shape, data_shape))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@constexpr
|
|
|
|
|
def check_tensor_setitem_index(index, element_type=None):
|
|
|
|
|
"""Checks tuple index type of tensor assignment."""
|
|
|
|
|
if index is None:
|
|
|
|
|
raise ValueError("Tensor's index cannot be None.")
|
|
|
|
|
raise IndexError("Tensor's index cannot be None.")
|
|
|
|
|
# eg. Tensor[Slice] = u
|
|
|
|
|
if isinstance(index, Slice):
|
|
|
|
|
return True
|
|
|
|
|
# eg. Tensor[tuple] = u
|
|
|
|
|
if isinstance(index, tuple):
|
|
|
|
|
if not index:
|
|
|
|
|
raise ValueError("Tensor's index cannot be empty.")
|
|
|
|
|
raise IndexError("Tensor's index cannot be empty.")
|
|
|
|
|
# eg. Tensor[tuple(Slice...)] = u
|
|
|
|
|
if isinstance(index[0], (Slice, int)):
|
|
|
|
|
if isinstance(index[0], (Slice, Ellipsis_, int)):
|
|
|
|
|
return True
|
|
|
|
|
raise ValueError("Index of type '{}' is not supported yet.".format(type(index[0])))
|
|
|
|
|
raise IndexError("Index of type '{}' is not supported yet.".format(type(index[0])))
|
|
|
|
|
# eg. Tensor[Tensor[dtype=bool]] = u
|
|
|
|
|
if index == mstype.tensor:
|
|
|
|
|
if element_type is None or element_type != mstype.bool_:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"The index of tensor should be a bool type tensor. \
|
|
|
|
|
{} type is not supported yet.".format(element_type))
|
|
|
|
|
raise TypeError(
|
|
|
|
|
"The index of tensor should be a bool type tensor. "
|
|
|
|
|
"{} type is not supported yet.".format(element_type))
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
raise ValueError("Index of type '{}' is not supported yet.".format(type(index)))
|
|
|
|
|
raise IndexError("Index of type '{}' is not supported yet.".format(type(index)))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@constexpr
|
|
|
|
@ -90,10 +99,18 @@ def slice_expand(input_slices, shape):
|
|
|
|
|
# Slice or tuple(Slice...)
|
|
|
|
|
if isinstance(input_slices, Slice):
|
|
|
|
|
slices = (input_slices,)
|
|
|
|
|
elif isinstance(input_slices, (tuple, list)) and input_slices and isinstance(input_slices[0], Slice):
|
|
|
|
|
slices = input_slices
|
|
|
|
|
elif isinstance(input_slices, (tuple, list)) and input_slices and isinstance(input_slices[0], (Slice, Ellipsis_)):
|
|
|
|
|
is_have_ellipsis = False
|
|
|
|
|
for _, element in enumerate(input_slices):
|
|
|
|
|
if isinstance(element, Ellipsis_):
|
|
|
|
|
is_have_ellipsis = True
|
|
|
|
|
break
|
|
|
|
|
if is_have_ellipsis:
|
|
|
|
|
slices = ellipsis2slice(input_slices, shape)
|
|
|
|
|
else:
|
|
|
|
|
slices = input_slices
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError("Tensor's index type is not supported yet.")
|
|
|
|
|
raise IndexError("Tensor's index type is not supported yet.")
|
|
|
|
|
|
|
|
|
|
for s in slices:
|
|
|
|
|
start = 0 if (s.start is None) else s.start
|
|
|
|
@ -111,6 +128,26 @@ def slice_expand(input_slices, shape):
|
|
|
|
|
return begin, end, strides
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ellipsis2slice(input_, shape):
|
|
|
|
|
"""Converts ellipsis to slice."""
|
|
|
|
|
input_slice = input_
|
|
|
|
|
result = []
|
|
|
|
|
if isinstance(input_, Ellipsis_):
|
|
|
|
|
input_slice = (input_,)
|
|
|
|
|
ell_count = 0
|
|
|
|
|
for _, element in enumerate(input_slice):
|
|
|
|
|
if not isinstance(element, Ellipsis_):
|
|
|
|
|
result.append(element)
|
|
|
|
|
continue
|
|
|
|
|
ell_count += 1
|
|
|
|
|
if ell_count > 1:
|
|
|
|
|
raise IndexError("There cannot be more than one ellisis (...) in the index of the tensor, "
|
|
|
|
|
"but it is currently {}".format(input_slice))
|
|
|
|
|
for _ in range(len(shape) - len(input_slice) + 1):
|
|
|
|
|
result.append(Slice(None, None, None))
|
|
|
|
|
return tuple(result)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@constexpr
|
|
|
|
|
def slice2indices(input_slices, shape):
|
|
|
|
|
"""
|
|
|
|
@ -139,7 +176,7 @@ def slice2indices(input_slices, shape):
|
|
|
|
|
def check_indices(indices_size, index):
|
|
|
|
|
"""Checks indices whether is empty."""
|
|
|
|
|
if indices_size < 1:
|
|
|
|
|
raise ValueError("The tensor's index is unreasonable. index:{}".format(index))
|
|
|
|
|
raise IndexError("The tensor's index is unreasonable. index:{}".format(index))
|
|
|
|
|
return indices_size
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -151,8 +188,8 @@ def check_indices_value_size(indices_size, value_size):
|
|
|
|
|
if value_size > 1:
|
|
|
|
|
if value_size != indices_size:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"The value given to tensor does not match the index size. \
|
|
|
|
|
value size:{}, indics size:{}".format(value_size, indices_size))
|
|
|
|
|
"The value given to tensor does not match the index size,"
|
|
|
|
|
" value size:{}, indics size:{}".format(value_size, indices_size))
|
|
|
|
|
return value_size
|
|
|
|
|
|
|
|
|
|
@constexpr
|
|
|
|
@ -168,8 +205,11 @@ def integer_to_indices(index, shape):
|
|
|
|
|
def tuple_element_is_slice(indexs):
|
|
|
|
|
"""Judges tuple element type."""
|
|
|
|
|
if not indexs:
|
|
|
|
|
raise ValueError("Tensor's index cannot be empty.")
|
|
|
|
|
if isinstance(indexs, tuple) and isinstance(indexs[0], Slice):
|
|
|
|
|
raise IndexError("Tensor's index cannot be empty.")
|
|
|
|
|
if isinstance(indexs, tuple):
|
|
|
|
|
for _, ele in enumerate(indexs):
|
|
|
|
|
if not isinstance(ele, Slice):
|
|
|
|
|
return False
|
|
|
|
|
return True
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
@ -177,7 +217,10 @@ def tuple_element_is_slice(indexs):
|
|
|
|
|
def tuple_element_is_int(indexs):
|
|
|
|
|
"""Judges tuple element type."""
|
|
|
|
|
if not indexs:
|
|
|
|
|
raise ValueError("Tensor's index cannot be empty.")
|
|
|
|
|
if isinstance(indexs, tuple) and isinstance(indexs[0], int):
|
|
|
|
|
raise IndexError("Tensor's index cannot be empty.")
|
|
|
|
|
if isinstance(indexs, tuple):
|
|
|
|
|
for _, ele in enumerate(indexs):
|
|
|
|
|
if not isinstance(ele, int):
|
|
|
|
|
return False
|
|
|
|
|
return True
|
|
|
|
|
return False
|
|
|
|
|