From c94dea6a512eddb6cbe8b591268d82d7b9aa3209 Mon Sep 17 00:00:00 2001 From: zhoufeng <zhoufeng54@huawei.com> Date: Wed, 1 Jul 2020 21:50:33 +0800 Subject: [PATCH] Modify nested while testcase Signed-off-by: zhoufeng <zhoufeng54@huawei.com> --- tests/st/control/test_ascend_control_sink.py | 54 ++++++++++++++++---- 1 file changed, 43 insertions(+), 11 deletions(-) diff --git a/tests/st/control/test_ascend_control_sink.py b/tests/st/control/test_ascend_control_sink.py index 2c206c9768..b38668cd25 100644 --- a/tests/st/control/test_ascend_control_sink.py +++ b/tests/st/control/test_ascend_control_sink.py @@ -102,16 +102,41 @@ class ControlIfbyIfbyIf(nn.Cell): class ControlMixedWhileIf(nn.Cell): def __init__(self): super().__init__() + self.assign = op.Assign() + self.var = Parameter(initializer(1, (1), mstype.float32), name="var") + + def construct(self, x, y, z, c2, c4): + out = self.assign(self.var, c4) + while x < c2: + y = self.assign(self.var, c4) + while y < c2 and x < c2: + if 2 * y < c2: + y = y + 2 + else: + y = y + 1 + out = out + y + z = self.assign(self.var, c4) + while z < c2: + z = z + 1 + out = out + z + x = x + 1 + out = out + x + while x < 2 * c2: + y = self.assign(self.var, c4) + x = x + 1 + while y < c2: + z = self.assign(self.var, c4) + while z < c2: + z = z + 1 + if x < c2: + y = y - 1 + else: + y = y + 1 + out = out + z + out = out + y + out = out + x + return out - def construct(self, x, y): - y = y + 4 - while x < y: - if 2 * x < y: - x = x + 1 - else: - x = x + 2 - x = x + 3 - return x @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @@ -130,6 +155,7 @@ def test_simple_if(): expect = input2 * 3 * 3 * 2 + input1 assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001) + @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training @@ -145,6 +171,7 @@ def test_simple_if_with_assign(): expect = input_data assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001) + @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training @@ -158,6 +185,7 @@ def test_if_in_if(): expect = x + y + 3 assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001) + @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training @@ -175,6 +203,7 @@ def test_if_by_if_by_if(): expect = input_data * 3 * 2 * 2 * 2 assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001) + @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training @@ -183,7 +212,10 @@ def test_mixed_while_if(): context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") x = np.array(2).astype(np.int32) y = np.array(14).astype(np.int32) + z = np.array(1).astype(np.int32) + c2 = Tensor([14], mstype.int32) + c4 = Tensor([0], mstype.int32) net = ControlMixedWhileIf() - output = net(Tensor(x), Tensor(y)) - expect = np.array(22).astype(np.int32) + output = net(Tensor(x), Tensor(y), Tensor(z), c2, c4) + expect = np.array(3318).astype(np.int32) assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001)