diff --git a/mindspore/ccsrc/backend/session/ascend_auto_monad.cc b/mindspore/ccsrc/backend/session/ascend_auto_monad.cc index f46e2cbe07..901ae00968 100644 --- a/mindspore/ccsrc/backend/session/ascend_auto_monad.cc +++ b/mindspore/ccsrc/backend/session/ascend_auto_monad.cc @@ -1185,8 +1185,19 @@ class ExecuteOrderGenerator { MS_EXCEPTION_IF_NULL(target); auto para = param_write_times.find(target); if (para != param_write_times.end() && para->second == 1) { - // If target only write once, replace target with source and erase assign node. + // Check source of the Assign. auto &source = node->inputs().at(kAssignSourceIndex); + MS_EXCEPTION_IF_NULL(source); + if (source->isa()) { + auto it = param_write_times.find(source); + if (it != param_write_times.end() && it->second > 0) { + // Skip if Assign source is a parameter and be written in other place. + ++iter; + continue; + } + } + // If target only write once, and source not be written, + // replace target with source and erase the Assign node. auto kg = target->func_graph()->cast(); MS_EXCEPTION_IF_NULL(kg); kg->ReplaceNode(NOT_NULL(target), NOT_NULL(source)); diff --git a/tests/st/auto_monad/test_auto_monad.py b/tests/st/auto_monad/test_auto_monad.py index 2622b84f96..82ebece1bd 100644 --- a/tests/st/auto_monad/test_auto_monad.py +++ b/tests/st/auto_monad/test_auto_monad.py @@ -1429,6 +1429,33 @@ def test_if_cast(): np.testing.assert_array_equal(r1.asnumpy(), expect.asnumpy()) +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_while_forward(): + class MyWhileNet(nn.Cell): + def __init__(self): + super().__init__() + self.max = P.ReduceMax() + + def construct(self, idx, end, x): + while idx < end: + part = x[idx, :, :] + max_num = self.max(part) + x[idx, :, 0:2] = max_num + idx = idx + 1 + return x + + net = MyWhileNet() + idx = Tensor(np.array(0), dtype=ms.int32) + end = Tensor(np.array(2), dtype=ms.int32) + x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32) + output = net(idx, end, x) + expect = np.array([[[3, 3], [3, 3]], [[7, 7], [7, 7]]], dtype=np.int32) + assert np.allclose(output.asnumpy(), expect, 0.0001, 0.0001) + + @pytest.mark.skip(reason="not supported yet") def test_multi_add_assign(): class Net(Cell):