diff --git a/mindspore/ops/composite/multitype_ops/_compile_utils.py b/mindspore/ops/composite/multitype_ops/_compile_utils.py index 86100818d1..d6ce3ac6f6 100644 --- a/mindspore/ops/composite/multitype_ops/_compile_utils.py +++ b/mindspore/ops/composite/multitype_ops/_compile_utils.py @@ -267,6 +267,8 @@ def _tensor_index_by_tuple_slice(data, t): def tensor_index_by_tuple(data, tuple_index): """Tensor getitem by tuple of various types""" + if len(tuple_index) == 1: + return data[tuple_index[0]] indexes_types = hyper_map(F.typeof, tuple_index) index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_GETITEM) if index_elements_type == const_utils.NO_TENSOR: @@ -430,6 +432,9 @@ def tensor_setitem_by_slice_with_number(data, input_slice, value): def tensor_setitem_by_tuple_with_number(data, tuple_index, value): """Assigns the tensor by tuple with number value.""" + if len(tuple_index) == 1: + data[tuple_index[0]] = value + return data indexes_types = hyper_map(F.typeof, tuple_index) index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM) @@ -489,6 +494,9 @@ def tensor_setitem_by_slice_with_tensor(data, input_slice, value): def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value): """Assigns the tensor by tuple with tensor value.""" + if len(tuple_index) == 1: + data[tuple_index[0]] = value + return data indexes_types = hyper_map(F.typeof, tuple_index) index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM) @@ -509,6 +517,9 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value): def tensor_setitem_by_tuple_with_tuple(data, tuple_index, value): """Assigns the tensor by tuple with tuple of value.""" + if len(tuple_index) == 1: + data[tuple_index[0]] = value + return data indexes_types = hyper_map(F.typeof, tuple_index) index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM)