|
|
|
@ -53,7 +53,7 @@ class NestTupleGraphNet(nn.Cell):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class InTupleNet(nn.Cell):
|
|
|
|
|
def __init__(self,):
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super(InTupleNet, self).__init__()
|
|
|
|
|
self.tuple_ = (1, 2, 3, 4, 5, "ok")
|
|
|
|
|
|
|
|
|
@ -66,6 +66,34 @@ class InTupleNet(nn.Cell):
|
|
|
|
|
return ret
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TensorInTuple(nn.Cell):
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super(TensorInTuple, self).__init__()
|
|
|
|
|
self.t1 = Tensor(1, mstype.float32)
|
|
|
|
|
self.t2 = Tensor(2, mstype.float32)
|
|
|
|
|
self.tuple_ = (self.t1, self.t2)
|
|
|
|
|
|
|
|
|
|
def construct(self, x):
|
|
|
|
|
ret = x
|
|
|
|
|
if self.t1 in self.tuple_:
|
|
|
|
|
ret = x + x
|
|
|
|
|
return ret
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TensorNotInTuple(nn.Cell):
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super(TensorNotInTuple, self).__init__()
|
|
|
|
|
self.t1 = Tensor(1, mstype.float32)
|
|
|
|
|
self.t2 = Tensor(2, mstype.float32)
|
|
|
|
|
self.tuple_ = (self.t1, self.t2)
|
|
|
|
|
|
|
|
|
|
def construct(self, x):
|
|
|
|
|
ret = x
|
|
|
|
|
if self.t1 not in self.tuple_:
|
|
|
|
|
ret = x + x
|
|
|
|
|
return ret
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_case_ops = [
|
|
|
|
|
('TupleGraph', {
|
|
|
|
|
'block': TupleGraphNet(),
|
|
|
|
@ -75,7 +103,13 @@ test_case_ops = [
|
|
|
|
|
'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]}),
|
|
|
|
|
('InTuple', {
|
|
|
|
|
'block': InTupleNet(),
|
|
|
|
|
'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]})
|
|
|
|
|
'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]}),
|
|
|
|
|
('TensorInTuple', {
|
|
|
|
|
'block': TensorInTuple(),
|
|
|
|
|
'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]}),
|
|
|
|
|
('TensorNotInTuple', {
|
|
|
|
|
'block': TensorNotInTuple(),
|
|
|
|
|
'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]}),
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
test_case_lists = [test_case_ops]
|
|
|
|
|