From 6d41c00907da4406fa81c2c4595aa582be50299c Mon Sep 17 00:00:00 2001 From: chenfei Date: Mon, 14 Dec 2020 16:42:44 +0800 Subject: [PATCH] use select implement tensor in list --- .../composite/multitype_ops/_compile_utils.py | 6 +- tests/st/control/test_tensor_in_list.py | 58 +++++++++++++++++++ 2 files changed, 61 insertions(+), 3 deletions(-) create mode 100644 tests/st/control/test_tensor_in_list.py diff --git a/mindspore/ops/composite/multitype_ops/_compile_utils.py b/mindspore/ops/composite/multitype_ops/_compile_utils.py index aea85bb8aa..a8e44fff55 100644 --- a/mindspore/ops/composite/multitype_ops/_compile_utils.py +++ b/mindspore/ops/composite/multitype_ops/_compile_utils.py @@ -618,8 +618,8 @@ def tensor_setitem_by_ellipsis_with_tensor(data, index, value): def tensor_in_sequence(x, y): """Assigns whether a sequence contains the given tensor""" + result = const_utils.scalar_to_tensor(False) for i in y: if isinstance(i, mstype.tensor) and x.shape == i.shape and x.dtype == i.dtype: - if F.equal(x, i).all(): - return const_utils.scalar_to_tensor(True) - return const_utils.scalar_to_tensor(False) + result = F.logical_or(F.equal(x, i).all(), result) + return result diff --git a/tests/st/control/test_tensor_in_list.py b/tests/st/control/test_tensor_in_list.py new file mode 100644 index 0000000000..e1d423e719 --- /dev/null +++ b/tests/st/control/test_tensor_in_list.py @@ -0,0 +1,58 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test_tensor_in_list """ +import pytest +from mindspore import nn, Tensor, context +from mindspore import dtype as mstype + + +def setup_module(): + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.list = [Tensor([1], mstype.int32), Tensor([2], mstype.int32), Tensor([3], mstype.int32)] + + def construct(self, c): + if c in self.list: + out = c + c + else: + out = c + 0 + return out + + +@pytest.mark.level0 +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.env_onecard +def test_tensor_in_list(): + net = Net() + output = net(Tensor([1], mstype.int32)) + expect = Tensor([2], mstype.int32) + assert output == expect + + output = net(Tensor([2], mstype.int32)) + expect = Tensor([4], mstype.int32) + assert output == expect + + output = net(Tensor([3], mstype.int32)) + expect = Tensor([6], mstype.int32) + assert output == expect + + output = net(Tensor([4], mstype.int32)) + expect = Tensor([4], mstype.int32) + assert output == expect