|
|
|
@ -15,6 +15,7 @@
|
|
|
|
|
|
|
|
|
|
"""constexpr util"""
|
|
|
|
|
|
|
|
|
|
from functools import reduce
|
|
|
|
|
import numpy as np
|
|
|
|
|
from ...primitive import constexpr
|
|
|
|
|
from ....common.tensor import Tensor
|
|
|
|
@ -23,26 +24,27 @@ from ...._extends.utils import Slice
|
|
|
|
|
|
|
|
|
|
@constexpr
|
|
|
|
|
def check_equal(param1, param2, msg="{},{}"):
|
|
|
|
|
"""Checks whether the two parameters are equal or not."""
|
|
|
|
|
if param1 != param2:
|
|
|
|
|
raise ValueError(msg.format(param1, param2))
|
|
|
|
|
return param1
|
|
|
|
|
|
|
|
|
|
@constexpr
|
|
|
|
|
def check_tensor_setitem_index(index, element_type=None):
|
|
|
|
|
"""Check tuple index type of tensor assignment."""
|
|
|
|
|
"""Checks tuple index type of tensor assignment."""
|
|
|
|
|
if index is None:
|
|
|
|
|
raise ValueError("Tensor's index cannot be None.")
|
|
|
|
|
# eg. Tensor[Slice] = u
|
|
|
|
|
if isinstance(index, Slice):
|
|
|
|
|
return True
|
|
|
|
|
# eg. Tensor[Tuple] = u
|
|
|
|
|
# eg. Tensor[tuple] = u
|
|
|
|
|
if isinstance(index, tuple):
|
|
|
|
|
if not index:
|
|
|
|
|
raise ValueError("Tensor's index cannot be empty.")
|
|
|
|
|
# eg. Tensor[Tuple(Slice...)] = u
|
|
|
|
|
if not isinstance(index[0], Slice):
|
|
|
|
|
raise ValueError("Index of type '{}' is not supported yet.".format(type(index[0])))
|
|
|
|
|
return True
|
|
|
|
|
# eg. Tensor[tuple(Slice...)] = u
|
|
|
|
|
if isinstance(index[0], (Slice, int)):
|
|
|
|
|
return True
|
|
|
|
|
raise ValueError("Index of type '{}' is not supported yet.".format(type(index[0])))
|
|
|
|
|
# eg. Tensor[Tensor[dtype=bool]] = u
|
|
|
|
|
if index == mstype.tensor:
|
|
|
|
|
if element_type is None or element_type != mstype.bool_:
|
|
|
|
@ -57,7 +59,7 @@ def check_tensor_setitem_index(index, element_type=None):
|
|
|
|
|
@constexpr
|
|
|
|
|
def is_same_type(inst, type_):
|
|
|
|
|
"""
|
|
|
|
|
Check whether an object is an instance of a target type.
|
|
|
|
|
Checks whether an object is an instance of a target type.
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
inst (mindspore.dtype): Inspected type.
|
|
|
|
@ -69,34 +71,23 @@ def is_same_type(inst, type_):
|
|
|
|
|
return inst == type_
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@constexpr
|
|
|
|
|
def error_msg(msg="", format_values=""):
|
|
|
|
|
"""
|
|
|
|
|
Used to throw exception information.
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
msg (str): information content.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
raise ValueError(msg.format(*format_values))
|
|
|
|
|
|
|
|
|
|
def slice_expand(input_slices, shape):
|
|
|
|
|
"""
|
|
|
|
|
Convert slice to indices.
|
|
|
|
|
Converts slice to indices.
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
slices (List or Tuple(List, ...)): Slice tuple or slice.
|
|
|
|
|
shape (Tuple): The shape of a sensor is an integer element tuple.
|
|
|
|
|
slices (Union[Slice, tuple[Slice]]): Slice tuple or slice.
|
|
|
|
|
shape (tuple): The shape of a sensor is an integer element tuple.
|
|
|
|
|
|
|
|
|
|
Outputs:
|
|
|
|
|
(List, List, List), This is expressed as (begins, ends, strides).
|
|
|
|
|
tuple[list], This is expressed as (begins, ends, strides).
|
|
|
|
|
"""
|
|
|
|
|
begin = []
|
|
|
|
|
end = []
|
|
|
|
|
strides = []
|
|
|
|
|
index = 0
|
|
|
|
|
slices = None
|
|
|
|
|
# Slice or Tuple(Slice...)
|
|
|
|
|
# Slice or tuple(Slice...)
|
|
|
|
|
if isinstance(input_slices, Slice):
|
|
|
|
|
slices = (input_slices,)
|
|
|
|
|
elif isinstance(input_slices, (tuple, list)) and input_slices and isinstance(input_slices[0], Slice):
|
|
|
|
@ -119,14 +110,15 @@ def slice_expand(input_slices, shape):
|
|
|
|
|
index += 1
|
|
|
|
|
return begin, end, strides
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@constexpr
|
|
|
|
|
def slice2indices(input_slices, shape):
|
|
|
|
|
"""
|
|
|
|
|
Convert slice to indices.
|
|
|
|
|
Converts slice to indices.
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
slices (List or Tuple(List, ...)): Slice tuple or slice.
|
|
|
|
|
shape (Tuple): The shape of a sensor is an integer element tuple.
|
|
|
|
|
slices (Union[Slice, tuple[Slice]]): Slice tuple or slice.
|
|
|
|
|
shape (tuple): The shape of a tensor is an integer element tuple.
|
|
|
|
|
|
|
|
|
|
Outputs:
|
|
|
|
|
Tensor, the shape is (n, 1).
|
|
|
|
@ -145,6 +137,7 @@ def slice2indices(input_slices, shape):
|
|
|
|
|
|
|
|
|
|
@constexpr
|
|
|
|
|
def check_indices(indices_size, index):
|
|
|
|
|
"""Checks indices whether is empty."""
|
|
|
|
|
if indices_size < 1:
|
|
|
|
|
raise ValueError("The tensor's index is unreasonable. index:{}".format(index))
|
|
|
|
|
return indices_size
|
|
|
|
@ -152,6 +145,7 @@ def check_indices(indices_size, index):
|
|
|
|
|
|
|
|
|
|
@constexpr
|
|
|
|
|
def check_indices_value_size(indices_size, value_size):
|
|
|
|
|
"""Checks if the sizes are already matched."""
|
|
|
|
|
if value_size < 1:
|
|
|
|
|
raise ValueError("The value assigned to tensor cannot be empty.")
|
|
|
|
|
if value_size > 1:
|
|
|
|
@ -160,3 +154,30 @@ def check_indices_value_size(indices_size, value_size):
|
|
|
|
|
"The value given to tensor does not match the index size. \
|
|
|
|
|
value size:{}, indics size:{}".format(value_size, indices_size))
|
|
|
|
|
return value_size
|
|
|
|
|
|
|
|
|
|
@constexpr
|
|
|
|
|
def integer_to_indices(index, shape):
|
|
|
|
|
"""Converts int or tuple[int] to indices."""
|
|
|
|
|
size = reduce(lambda x, y: x * y, shape)
|
|
|
|
|
range_ = np.arange(size).reshape(shape)
|
|
|
|
|
value = range_[index]
|
|
|
|
|
value = value.reshape(-1, 1)
|
|
|
|
|
return Tensor(value, dtype=mstype.int32)
|
|
|
|
|
|
|
|
|
|
@constexpr
|
|
|
|
|
def tuple_element_is_slice(indexs):
|
|
|
|
|
"""Judges tuple element type."""
|
|
|
|
|
if not indexs:
|
|
|
|
|
raise ValueError("Tensor's index cannot be empty.")
|
|
|
|
|
if isinstance(indexs, tuple) and isinstance(indexs[0], Slice):
|
|
|
|
|
return True
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
@constexpr
|
|
|
|
|
def tuple_element_is_int(indexs):
|
|
|
|
|
"""Judges tuple element type."""
|
|
|
|
|
if not indexs:
|
|
|
|
|
raise ValueError("Tensor's index cannot be empty.")
|
|
|
|
|
if isinstance(indexs, tuple) and isinstance(indexs[0], int):
|
|
|
|
|
return True
|
|
|
|
|
return False
|
|
|
|
|