Support determining whether a tensor is in (or not in) a list (or tuple)

pull/7438/head
huanghui 4 years ago
parent b4ce0aa933
commit a9e781921a

@ -577,3 +577,12 @@ def tensor_setitem_by_ellipsis_with_tensor(data, index, value):
param2 = F.cast(value, data_dtype)
result = F.tensor_mul(param1, param2)
return result
def tensor_in_sequence(x, y):
"""Assigns whether a sequence contains the given tensor"""
for i in y:
if isinstance(i, mstype.tensor) and x.shape == i.shape and x.dtype == i.dtype:
if F.equal(x, i).all():
return const_utils.scalar_to_tensor(True)
return const_utils.scalar_to_tensor(False)
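The helper above is graph-mode code, but its containment rule is plain: a sequence element matches only if it is a tensor of identical shape and dtype whose entries are all equal. A NumPy sketch of the same rule (the name `np_tensor_in_sequence` is hypothetical, not part of the patch):

```python
import numpy as np

def np_tensor_in_sequence(x, seq):
    """NumPy mirror of tensor_in_sequence: match on shape, dtype,
    and element-wise equality; non-array items never match."""
    for item in seq:
        if isinstance(item, np.ndarray) and x.shape == item.shape and x.dtype == item.dtype:
            if np.equal(x, item).all():
                return True
    return False

a = np.ones((2, 3), np.float32)
print(np_tensor_in_sequence(a, [1, "str", a.copy()]))           # True
print(np_tensor_in_sequence(a, [np.ones((2, 3), np.float64)]))  # False: dtype differs
```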

@ -39,14 +39,17 @@ TENSOR_GETITEM = "tensor getitem"
SET_ITEM_BY_ONE_TENSOR = 0
SET_ITEM_BY_TUPLE_OF_TENSOR = 1
@constexpr
def raise_value_error(msg):
raise ValueError(msg)
@constexpr
def raise_index_error(msg):
raise IndexError(msg)
@constexpr
def raise_type_error(msg):
raise TypeError(msg)
@ -704,7 +707,7 @@ def get_stride_info_from_slice(data_shape, slice_index):
def get_stride_info_from_integer(data_shape, number):
"""Get stride info from a integer"""
begin_strides = [number]
end_strides = [number+1]
end_strides = [number + 1]
step_strides = [1]
for end in data_shape[1:]:
begin_strides.append(0)
@ -720,7 +723,7 @@ def get_slice_stride(dim_size, index_slice):
stop_default = dim_size
if step < 0:
start_default = -1
stop_default = -(dim_size+1)
stop_default = -(dim_size + 1)
start = start_default if index_slice.start is None else index_slice.start
stop = stop_default if index_slice.stop is None else index_slice.stop
return start, stop, step
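For reference, the defaults filled in here follow Python slice semantics. A self-contained sketch (the head of the function is outside this hunk, so the `step` resolution is reconstructed and the name `slice_stride` is hypothetical):

```python
def slice_stride(dim_size, index_slice):
    """Fill in slice defaults the way get_slice_stride does: for a
    negative step, iteration runs from the last element (-1) down past
    the first, so the exclusive stop is -(dim_size + 1)."""
    step = 1 if index_slice.step is None else index_slice.step
    start_default, stop_default = 0, dim_size
    if step < 0:
        start_default = -1
        stop_default = -(dim_size + 1)
    start = start_default if index_slice.start is None else index_slice.start
    stop = stop_default if index_slice.stop is None else index_slice.stop
    return start, stop, step

print(slice_stride(5, slice(None, None, None)))  # (0, 5, 1)
print(slice_stride(5, slice(None, None, -1)))    # (-1, -6, -1)
```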
@ -775,3 +778,9 @@ def mstype_eq(x, y):
if x == y:
return True
return False
@constexpr
def scalar_to_tensor(x):
"""Convert a scalar to a tensor"""
return Tensor(x)
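The `@constexpr` helpers added in this file share one pattern: the decorated function runs at graph-compile time, so ordinary Python (raising exceptions, constructing a Tensor from a scalar) is legal inside it even when called from graph code. A minimal standalone sketch, assuming the public `mindspore.ops.constexpr` decorator; the demo names are not part of the patch:

```python
from mindspore import Tensor
from mindspore.ops import constexpr

@constexpr
def scalar_to_tensor_demo(x):
    # Runs during graph compilation, so plain Tensor construction is fine.
    return Tensor(x)

@constexpr
def raise_type_error_demo(msg):
    # A compile-time check failure surfaces as an ordinary TypeError.
    raise TypeError(msg)
```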

@ -16,6 +16,7 @@
"""Implementation for internal polymorphism `in` operations."""
from . import _constexpr_utils as const_utils
from . import _compile_utils as compile_utils
from ... import functional as F
from ...composite import base
@ -99,3 +100,33 @@ def _str_in_dict(x, y):
        bool, True if x is in y, otherwise False.
"""
return F.in_dict(x, y)
@in_.register("Tensor", "List")
def _tensor_in_list(x, y):
"""
    Determine whether a tensor is in a list.
Args:
x: Tensor
y: List
Returns:
        bool, True if x is in y, otherwise False.
"""
return compile_utils.tensor_in_sequence(x, y)
@in_.register("Tensor", "Tuple")
def _tensor_in_tuple(x, y):
"""
    Determine whether a tensor is in a tuple.
Args:
x: Tensor
y: Tuple
Returns:
        bool, True if x is in y, otherwise False.
"""
return compile_utils.tensor_in_sequence(x, y)
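With these overloads registered, the plain `in` operator works on tensors inside `construct`. A hypothetical quick check, mirroring the tests added further down (`TensorInListDemo` is not part of the patch):

```python
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.common import dtype as mstype

context.set_context(mode=context.GRAPH_MODE)

class TensorInListDemo(nn.Cell):
    def __init__(self):
        super(TensorInListDemo, self).__init__()
        self.t1 = Tensor(1, mstype.float32)

    def construct(self, x):
        # Dispatches to _tensor_in_list, which calls tensor_in_sequence.
        if x in [self.t1, x]:
            return x + x
        return x

net = TensorInListDemo()
out = net(Tensor(np.ones([2, 2], np.float32)))  # x matches itself, so out == 2 * x
```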

@ -16,6 +16,7 @@
"""Implementation for internal polymorphism `not in` operations."""
from . import _constexpr_utils as const_utils
from . import _compile_utils as compile_utils
from ... import functional as F
from ...composite import base
@ -99,3 +100,33 @@ def _str_not_in_dict(x, y):
        bool, True if x is not in y, otherwise False.
"""
return F.not_in_dict(x, y)
@not_in_.register("Tensor", "List")
def _tensor_not_in_list(x, y):
"""
    Determine whether a tensor is not in a list.
Args:
x: Tensor
y: List
Returns:
        bool, True if x is not in y, otherwise False.
"""
return not compile_utils.tensor_in_sequence(x, y)
@not_in_.register("Tensor", "Tuple")
def _tensor_not_in_tuple(x, y):
"""
    Determine whether a tensor is not in a tuple.
Args:
x: Tensor
y: Tuple
Returns:
        bool, True if x is not in y, otherwise False.
"""
return not compile_utils.tensor_in_sequence(x, y)
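The `not in` overloads are the exact negation of the `in` path; both reuse `tensor_in_sequence`. A matching sketch, again hypothetical and mirroring the `TensorNotInTuple` test below:

```python
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.common import dtype as mstype

context.set_context(mode=context.GRAPH_MODE)

class TensorNotInTupleDemo(nn.Cell):
    def __init__(self):
        super(TensorNotInTupleDemo, self).__init__()
        self.t1 = Tensor(1, mstype.float32)
        self.tuple_ = (Tensor(2, mstype.float32), Tensor(3, mstype.float32))

    def construct(self, x):
        # Dispatches to _tensor_not_in_tuple: negated tensor_in_sequence.
        if self.t1 not in self.tuple_:
            return x + x
        return x
```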

@ -30,7 +30,6 @@ from tests.mindspore_test_framework.pipeline.forward.compile_forward \
context.set_context(mode=context.GRAPH_MODE)
grad_all = C.GradOperation(get_all=True)
@ -258,6 +257,34 @@ class AxisListDefaultNet(nn.Cell):
return self.reduce_sum(x)
class TensorInList(nn.Cell):
def __init__(self):
super(TensorInList, self).__init__()
self.t1 = Tensor(1, mstype.float32)
self.t2 = Tensor(2, mstype.float32)
def construct(self, x):
ret = x
list_ = [1, [2, 3], "str", self.t1, self.t2, x]
if x in list_:
ret = x + x
return ret
class TensorNotInList(nn.Cell):
def __init__(self):
super(TensorNotInList, self).__init__()
self.t1 = Tensor(1, mstype.float32)
self.t2 = Tensor(2, mstype.float32)
def construct(self, x):
ret = x
list_ = [self.t2, x]
if self.t1 not in list_:
ret = x + x
return ret
test_case_ops = [
('ListOperate', {
'block': ListOperate(),
@ -275,6 +302,12 @@ test_case_ops = [
('InList', {
'block': InListNet(),
'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))]}),
('TensorInList', {
'block': TensorInList(),
'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))]}),
('TensorNotInList', {
'block': TensorNotInList(),
'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))]}),
]
test_case_lists = [test_case_ops]

@ -53,7 +53,7 @@ class NestTupleGraphNet(nn.Cell):
class InTupleNet(nn.Cell):
def __init__(self,):
def __init__(self):
super(InTupleNet, self).__init__()
self.tuple_ = (1, 2, 3, 4, 5, "ok")
@ -66,6 +66,34 @@ class InTupleNet(nn.Cell):
return ret
class TensorInTuple(nn.Cell):
def __init__(self):
super(TensorInTuple, self).__init__()
self.t1 = Tensor(1, mstype.float32)
self.t2 = Tensor(2, mstype.float32)
self.tuple_ = (self.t1, self.t2)
def construct(self, x):
ret = x
if self.t1 in self.tuple_:
ret = x + x
return ret
class TensorNotInTuple(nn.Cell):
def __init__(self):
super(TensorNotInTuple, self).__init__()
self.t1 = Tensor(1, mstype.float32)
self.t2 = Tensor(2, mstype.float32)
self.tuple_ = (self.t1, self.t2)
def construct(self, x):
ret = x
if self.t1 not in self.tuple_:
ret = x + x
return ret
test_case_ops = [
('TupleGraph', {
'block': TupleGraphNet(),
@ -75,7 +103,13 @@ test_case_ops = [
'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]}),
('InTuple', {
'block': InTupleNet(),
'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]})
'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]}),
('TensorInTuple', {
'block': TensorInTuple(),
'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]}),
('TensorNotInTuple', {
'block': TensorNotInTuple(),
'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]}),
]
test_case_lists = [test_case_ops]
