diff --git a/mindspore/ops/composite/multitype_ops/_compile_utils.py b/mindspore/ops/composite/multitype_ops/_compile_utils.py index 9d5ef43d8a..13c5932184 100644 --- a/mindspore/ops/composite/multitype_ops/_compile_utils.py +++ b/mindspore/ops/composite/multitype_ops/_compile_utils.py @@ -577,3 +577,12 @@ def tensor_setitem_by_ellipsis_with_tensor(data, index, value): param2 = F.cast(value, data_dtype) result = F.tensor_mul(param1, param2) return result + + +def tensor_in_sequence(x, y): + """Assigns whether a sequence contains the given tensor""" + for i in y: + if isinstance(i, mstype.tensor) and x.shape == i.shape and x.dtype == i.dtype: + if F.equal(x, i).all(): + return const_utils.scalar_to_tensor(True) + return const_utils.scalar_to_tensor(False) diff --git a/mindspore/ops/composite/multitype_ops/_constexpr_utils.py b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py index 9489991abf..6c666af7d9 100644 --- a/mindspore/ops/composite/multitype_ops/_constexpr_utils.py +++ b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py @@ -39,14 +39,17 @@ TENSOR_GETITEM = "tensor getitem" SET_ITEM_BY_ONE_TENSOR = 0 SET_ITEM_BY_TUPLE_OF_TENSOR = 1 + @constexpr def raise_value_error(msg): raise ValueError(msg) + @constexpr def raise_index_error(msg): raise IndexError(msg) + @constexpr def raise_type_error(msg): raise TypeError(msg) @@ -704,7 +707,7 @@ def get_stride_info_from_slice(data_shape, slice_index): def get_stride_info_from_integer(data_shape, number): """Get stride info from a integer""" begin_strides = [number] - end_strides = [number+1] + end_strides = [number + 1] step_strides = [1] for end in data_shape[1:]: begin_strides.append(0) @@ -720,7 +723,7 @@ def get_slice_stride(dim_size, index_slice): stop_default = dim_size if step < 0: start_default = -1 - stop_default = -(dim_size+1) + stop_default = -(dim_size + 1) start = start_default if index_slice.start is None else index_slice.start stop = stop_default if index_slice.stop is None else index_slice.stop return start, stop, step @@ -775,3 +778,9 @@ def mstype_eq(x, y): if x == y: return True return False + + +@constexpr +def scalar_to_tensor(x): + """Convert a scalar to a tensor""" + return Tensor(x) diff --git a/mindspore/ops/composite/multitype_ops/in_impl.py b/mindspore/ops/composite/multitype_ops/in_impl.py index 22264e41cd..89776fe0c7 100644 --- a/mindspore/ops/composite/multitype_ops/in_impl.py +++ b/mindspore/ops/composite/multitype_ops/in_impl.py @@ -16,6 +16,7 @@ """Implementation for internal polymorphism `in` operations.""" from . import _constexpr_utils as const_utils +from . import _compile_utils as compile_utils from ... import functional as F from ...composite import base @@ -99,3 +100,33 @@ def _str_in_dict(x, y): bool, if x in y return true, x not in y return false. """ return F.in_dict(x, y) + + +@in_.register("Tensor", "List") +def _tensor_in_list(x, y): + """ + Determine if a tensor in a list. + + Args: + x: Tensor + y: List + + Returns: + bool, if x in y return true, x not in y return false. + """ + return compile_utils.tensor_in_sequence(x, y) + + +@in_.register("Tensor", "Tuple") +def _tensor_in_tuple(x, y): + """ + Determine if a tensor in a tuple. + + Args: + x: Tensor + y: Tuple + + Returns: + bool, if x in y return true, x not in y return false. + """ + return compile_utils.tensor_in_sequence(x, y) diff --git a/mindspore/ops/composite/multitype_ops/not_in_impl.py b/mindspore/ops/composite/multitype_ops/not_in_impl.py index c5fd494bc9..4499ccf095 100644 --- a/mindspore/ops/composite/multitype_ops/not_in_impl.py +++ b/mindspore/ops/composite/multitype_ops/not_in_impl.py @@ -16,6 +16,7 @@ """Implementation for internal polymorphism `not in` operations.""" from . import _constexpr_utils as const_utils +from . import _compile_utils as compile_utils from ... import functional as F from ...composite import base @@ -99,3 +100,33 @@ def _str_not_in_dict(x, y): bool, if x not in y return true, x in y return false. """ return F.not_in_dict(x, y) + + +@not_in_.register("Tensor", "List") +def _tensor_not_in_list(x, y): + """ + Determine if a tensor not in a list. + + Args: + x: Tensor + y: List + + Returns: + bool, if x not in y return true, x in y return false. + """ + return not compile_utils.tensor_in_sequence(x, y) + + +@not_in_.register("Tensor", "Tuple") +def _tensor_not_in_tuple(x, y): + """ + Determine if a tensor not in a tuple. + + Args: + x: Tensor + y: Tuple + + Returns: + bool, if x not in y return true, x in y return false. + """ + return not compile_utils.tensor_in_sequence(x, y) diff --git a/tests/ut/python/dtype/test_list.py b/tests/ut/python/dtype/test_list.py index 13460e03ba..3ec19a9b1f 100644 --- a/tests/ut/python/dtype/test_list.py +++ b/tests/ut/python/dtype/test_list.py @@ -30,7 +30,6 @@ from tests.mindspore_test_framework.pipeline.forward.compile_forward \ context.set_context(mode=context.GRAPH_MODE) - grad_all = C.GradOperation(get_all=True) @@ -258,6 +257,34 @@ class AxisListDefaultNet(nn.Cell): return self.reduce_sum(x) +class TensorInList(nn.Cell): + def __init__(self): + super(TensorInList, self).__init__() + self.t1 = Tensor(1, mstype.float32) + self.t2 = Tensor(2, mstype.float32) + + def construct(self, x): + ret = x + list_ = [1, [2, 3], "str", self.t1, self.t2, x] + if x in list_: + ret = x + x + return ret + + +class TensorNotInList(nn.Cell): + def __init__(self): + super(TensorNotInList, self).__init__() + self.t1 = Tensor(1, mstype.float32) + self.t2 = Tensor(2, mstype.float32) + + def construct(self, x): + ret = x + list_ = [self.t2, x] + if self.t1 not in list_: + ret = x + x + return ret + + test_case_ops = [ ('ListOperate', { 'block': ListOperate(), @@ -275,6 +302,12 @@ test_case_ops = [ ('InList', { 'block': InListNet(), 'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))]}), + ('TensorInList', { + 'block': TensorInList(), + 'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))]}), + ('TensorNotInList', { + 'block': TensorNotInList(), + 'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))]}), ] test_case_lists = [test_case_ops] diff --git a/tests/ut/python/dtype/test_tuple.py b/tests/ut/python/dtype/test_tuple.py index 2f4ef7e508..3b0a1693d1 100644 --- a/tests/ut/python/dtype/test_tuple.py +++ b/tests/ut/python/dtype/test_tuple.py @@ -53,7 +53,7 @@ class NestTupleGraphNet(nn.Cell): class InTupleNet(nn.Cell): - def __init__(self,): + def __init__(self): super(InTupleNet, self).__init__() self.tuple_ = (1, 2, 3, 4, 5, "ok") @@ -66,6 +66,34 @@ class InTupleNet(nn.Cell): return ret +class TensorInTuple(nn.Cell): + def __init__(self): + super(TensorInTuple, self).__init__() + self.t1 = Tensor(1, mstype.float32) + self.t2 = Tensor(2, mstype.float32) + self.tuple_ = (self.t1, self.t2) + + def construct(self, x): + ret = x + if self.t1 in self.tuple_: + ret = x + x + return ret + + +class TensorNotInTuple(nn.Cell): + def __init__(self): + super(TensorNotInTuple, self).__init__() + self.t1 = Tensor(1, mstype.float32) + self.t2 = Tensor(2, mstype.float32) + self.tuple_ = (self.t1, self.t2) + + def construct(self, x): + ret = x + if self.t1 not in self.tuple_: + ret = x + x + return ret + + test_case_ops = [ ('TupleGraph', { 'block': TupleGraphNet(), @@ -75,7 +103,13 @@ test_case_ops = [ 'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]}), ('InTuple', { 'block': InTupleNet(), - 'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]}) + 'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]}), + ('TensorInTuple', { + 'block': TensorInTuple(), + 'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]}), + ('TensorNotInTuple', { + 'block': TensorNotInTuple(), + 'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]}), ] test_case_lists = [test_case_ops]