From c94dea6a512eddb6cbe8b591268d82d7b9aa3209 Mon Sep 17 00:00:00 2001
From: zhoufeng <zhoufeng54@huawei.com>
Date: Wed, 1 Jul 2020 21:50:33 +0800
Subject: [PATCH] Modify nested while testcase

Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
---
 tests/st/control/test_ascend_control_sink.py | 54 ++++++++++++++++----
 1 file changed, 43 insertions(+), 11 deletions(-)

diff --git a/tests/st/control/test_ascend_control_sink.py b/tests/st/control/test_ascend_control_sink.py
index 2c206c9768..b38668cd25 100644
--- a/tests/st/control/test_ascend_control_sink.py
+++ b/tests/st/control/test_ascend_control_sink.py
@@ -102,16 +102,41 @@ class ControlIfbyIfbyIf(nn.Cell):
 class ControlMixedWhileIf(nn.Cell):
     def __init__(self):
         super().__init__()
+        self.assign = op.Assign()
+        self.var = Parameter(initializer(1, (1), mstype.float32), name="var")
+
+    def construct(self, x, y, z, c2, c4):
+        out = self.assign(self.var, c4)
+        while x < c2:
+            y = self.assign(self.var, c4)
+            while y < c2 and x < c2:
+                if 2 * y < c2:
+                    y = y + 2
+                else:
+                    y = y + 1
+            out = out + y
+            z = self.assign(self.var, c4)
+            while z < c2:
+                z = z + 1
+            out = out + z
+            x = x + 1
+        out = out + x
+        while x < 2 * c2:
+            y = self.assign(self.var, c4)
+            x = x + 1
+            while y < c2:
+                z = self.assign(self.var, c4)
+                while z < c2:
+                    z = z + 1
+                if x < c2:
+                    y = y - 1
+                else:
+                    y = y + 1
+                out = out + z
+            out = out + y
+        out = out + x
+        return out
 
-    def construct(self, x, y):
-        y = y + 4
-        while x < y:
-            if 2 * x < y:
-                x = x + 1
-            else:
-                x = x + 2
-        x = x + 3
-        return x
 
 @pytest.mark.level0
 @pytest.mark.platform_arm_ascend_training
@@ -130,6 +155,7 @@ def test_simple_if():
     expect = input2 * 3 * 3 * 2 + input1
     assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001)
 
+
 @pytest.mark.level0
 @pytest.mark.platform_arm_ascend_training
 @pytest.mark.platform_x86_ascend_training
@@ -145,6 +171,7 @@ def test_simple_if_with_assign():
     expect = input_data
     assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001)
 
+
 @pytest.mark.level0
 @pytest.mark.platform_arm_ascend_training
 @pytest.mark.platform_x86_ascend_training
@@ -158,6 +185,7 @@ def test_if_in_if():
     expect = x + y + 3
     assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001)
 
+
 @pytest.mark.level0
 @pytest.mark.platform_arm_ascend_training
 @pytest.mark.platform_x86_ascend_training
@@ -175,6 +203,7 @@ def test_if_by_if_by_if():
     expect = input_data * 3 * 2 * 2 * 2
     assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001)
 
+
 @pytest.mark.level0
 @pytest.mark.platform_arm_ascend_training
 @pytest.mark.platform_x86_ascend_training
@@ -183,7 +212,10 @@ def test_mixed_while_if():
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     x = np.array(2).astype(np.int32)
     y = np.array(14).astype(np.int32)
+    z = np.array(1).astype(np.int32)
+    c2 = Tensor([14], mstype.int32)
+    c4 = Tensor([0], mstype.int32)
     net = ControlMixedWhileIf()
-    output = net(Tensor(x), Tensor(y))
-    expect = np.array(22).astype(np.int32)
+    output = net(Tensor(x), Tensor(y), Tensor(z), c2, c4)
+    expect = np.array(3318).astype(np.int32)
     assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001)