|
|
|
@ -59,13 +59,15 @@ def _generate_indices_from_tuple_of_tensor(data, tuple_index, op_name):
|
|
|
|
|
|
|
|
|
|
def _generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name):
|
|
|
|
|
"""Generate an indices tensor from a tuple that contains slice, int, ellipsis, tensor."""
|
|
|
|
|
data_shape = F.shape(data)
|
|
|
|
|
indexes_types = hyper_map(F.typeof, tuple_index)
|
|
|
|
|
int_positions = const_utils.get_pos_of_int_index(indexes_types)
|
|
|
|
|
tuple_index_new = ()
|
|
|
|
|
tuple_len = len(tuple_index)
|
|
|
|
|
for i in range(tuple_len):
|
|
|
|
|
if i in int_positions:
|
|
|
|
|
tuple_index_new += (F.scalar_to_tensor(tuple_index[i], mstype.int32),)
|
|
|
|
|
tuple_index_new += (F.scalar_to_tensor(tuple_index[i] if tuple_index[i] >= 0 else tuple_index[i] + \
|
|
|
|
|
data_shape[i], mstype.int32),)
|
|
|
|
|
else:
|
|
|
|
|
tuple_index_new += (tuple_index[i],)
|
|
|
|
|
indexes_types = hyper_map(F.typeof, tuple_index_new)
|
|
|
|
@ -77,7 +79,6 @@ def _generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name):
|
|
|
|
|
tensor_indexes.append(tuple_index_new[i])
|
|
|
|
|
for j in slice_positions:
|
|
|
|
|
slice_indexes.append(tuple_index_new[j])
|
|
|
|
|
data_shape = F.shape(data)
|
|
|
|
|
tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes)
|
|
|
|
|
tensor_indexes_dtypes = hyper_map(F.dtype, tensor_indexes)
|
|
|
|
|
broadcast_shape, final_shape, indexes_shapes_info, ellipsis_occupied_dims = \
|
|
|
|
|