|
|
|
@ -73,6 +73,13 @@ def make_empty_slice():
|
|
|
|
|
return slice(None, None, None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@constexpr
|
|
|
|
|
def make_tensor(data, data_type, data_shape=None):
|
|
|
|
|
if data_shape:
|
|
|
|
|
return Tensor(np.zeros(data_shape), data_type)
|
|
|
|
|
return Tensor(data, data_type)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@constexpr
|
|
|
|
|
def check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size):
|
|
|
|
|
"""Checks the shape and size of the sensor and value."""
|
|
|
|
@ -158,6 +165,36 @@ def check_indexes_types_valid(dtypes, target_type, op_name):
|
|
|
|
|
f"but got {dtype}.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@constexpr
|
|
|
|
|
def get_pos_of_indexes_types(indexes_types, op_name):
|
|
|
|
|
"""Separate the position information of tensor and slice and ellipsis from the mixed tensors index."""
|
|
|
|
|
slice_positions, ellipsis_positions, none_positions, int_positions, bool_positions, tensor_positions, \
|
|
|
|
|
sequence_positions = [], [], [], [], [], [], []
|
|
|
|
|
for i, index_type in enumerate(indexes_types):
|
|
|
|
|
if isinstance(index_type, mstype.slice_type):
|
|
|
|
|
slice_positions.append(i)
|
|
|
|
|
elif isinstance(index_type, mstype.ellipsis_type):
|
|
|
|
|
ellipsis_positions.append(i)
|
|
|
|
|
elif isinstance(index_type, mstype.none_type):
|
|
|
|
|
none_positions.append(i)
|
|
|
|
|
elif isinstance(index_type, mstype.Int):
|
|
|
|
|
int_positions.append(i)
|
|
|
|
|
elif isinstance(index_type, mstype.bool_type):
|
|
|
|
|
bool_positions.append(i)
|
|
|
|
|
elif isinstance(index_type, mstype.tensor_type):
|
|
|
|
|
tensor_positions.append(i)
|
|
|
|
|
elif isinstance(index_type, (list, tuple)):
|
|
|
|
|
sequence_positions.append(i)
|
|
|
|
|
else:
|
|
|
|
|
raise IndexError(f"For '{op_name}', the index elements only support "
|
|
|
|
|
f"'Tensor', 'int32', 'int64', 'Slice', 'Ellipsis', but got {index_type}.")
|
|
|
|
|
if len(ellipsis_positions) > 1:
|
|
|
|
|
raise IndexError(f"For '{op_name}, an index can only have a single ellipsis('...')")
|
|
|
|
|
|
|
|
|
|
return slice_positions, ellipsis_positions, none_positions, int_positions, bool_positions, \
|
|
|
|
|
tensor_positions, sequence_positions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def slice_expand(input_slices, shape):
|
|
|
|
|
"""
|
|
|
|
|
Converts slice to indices.
|
|
|
|
@ -293,13 +330,6 @@ def tuple_element_is_int(indexs):
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@constexpr
|
|
|
|
|
def tuple_index_tensor_cnt(types, op_name):
|
|
|
|
|
"""count the tensor type of types which contains the tuple elements' type."""
|
|
|
|
|
tensor_cnt = sum(isinstance(ele, mstype.tensor_type) for ele in types)
|
|
|
|
|
return ALL_TENSOR if tensor_cnt == len(types) else NO_TENSOR if tensor_cnt == 0 else CONTAIN_TENSOR
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ constexpr
|
|
|
|
|
def tuple_index_int_cnt(types, op_name):
|
|
|
|
|
"""count the int type of types which contains the tuple elements' type."""
|
|
|
|
@ -344,6 +374,8 @@ def check_value_elements(data_dtype, types):
|
|
|
|
|
raise TypeError(
|
|
|
|
|
f"For '{TENSOR_SETITEM}', the value does not support scalar and tensor mixing, but got {types}.")
|
|
|
|
|
|
|
|
|
|
# TODO to del
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ constexpr
|
|
|
|
|
def get_index_tensor_dtype(dtype):
|
|
|
|
@ -356,6 +388,7 @@ def get_index_tensor_dtype(dtype):
|
|
|
|
|
f"For '{TENSOR_SETITEM}', the index tensor data type '{dtype}' is not supported.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# TODO to del
|
|
|
|
|
@ constexpr
|
|
|
|
|
def check_index_tensors_dtype(indexes_types, op_name):
|
|
|
|
|
"""Check a tuple of tensor data type."""
|
|
|
|
@ -366,6 +399,7 @@ def check_index_tensors_dtype(indexes_types, op_name):
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# TODO to del
|
|
|
|
|
@ constexpr
|
|
|
|
|
def check_index_tensor_dtype(index_type, op_name):
|
|
|
|
|
"""Check a tensor data type."""
|
|
|
|
@ -375,6 +409,7 @@ def check_index_tensor_dtype(index_type, op_name):
|
|
|
|
|
f"but got {index_type}.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# TODO to del
|
|
|
|
|
@ constexpr
|
|
|
|
|
def check_tensors_dtype_same(data_dtype, value_dtype, op_name):
|
|
|
|
|
"""Check tensors data type same."""
|
|
|
|
@ -645,36 +680,6 @@ def _derive_result_shape_info_from_tuple_of_mixed_tensors(indexes_info, index_te
|
|
|
|
|
return broadcast_shape, tuple(final_shape), tuple(indexes_shapes_info)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@constexpr
|
|
|
|
|
def get_pos_of_indexes_types(indexes_types, op_name):
|
|
|
|
|
"""Separate the position information of tensor and slice and ellipsis from the mixed tensors index."""
|
|
|
|
|
slice_positions, ellipsis_positions, none_positions, int_positions, bool_positions, tensor_positions, \
|
|
|
|
|
sequence_positions = [], [], [], [], [], [], []
|
|
|
|
|
for i, index_type in enumerate(indexes_types):
|
|
|
|
|
if isinstance(index_type, mstype.slice_type):
|
|
|
|
|
slice_positions.append(i)
|
|
|
|
|
elif isinstance(index_type, mstype.ellipsis_type):
|
|
|
|
|
ellipsis_positions.append(i)
|
|
|
|
|
elif isinstance(index_type, mstype.none_type):
|
|
|
|
|
none_positions.append(i)
|
|
|
|
|
elif isinstance(index_type, mstype.Int):
|
|
|
|
|
int_positions.append(i)
|
|
|
|
|
elif isinstance(index_type, mstype.bool_type):
|
|
|
|
|
bool_positions.append(i)
|
|
|
|
|
elif isinstance(index_type, mstype.tensor_type):
|
|
|
|
|
tensor_positions.append(i)
|
|
|
|
|
elif isinstance(index_type, (list, tuple)):
|
|
|
|
|
sequence_positions.append(i)
|
|
|
|
|
else:
|
|
|
|
|
raise IndexError(f"For '{op_name}', the index elements only support "
|
|
|
|
|
f"'Tensor', 'int32', 'int64', 'Slice', 'Ellipsis', but got {index_type}.")
|
|
|
|
|
if len(ellipsis_positions) > 1:
|
|
|
|
|
raise IndexError(f"For '{op_name}, an index can only have a single ellipsis('...')")
|
|
|
|
|
|
|
|
|
|
return slice_positions, ellipsis_positions, none_positions, int_positions, bool_positions, \
|
|
|
|
|
tensor_positions, sequence_positions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ constexpr
|
|
|
|
|
def scalar_in_sequence(x, y):
|
|
|
|
|
"""Determine whether the scalar in the sequence."""
|
|
|
|
|