From 87d7de017b04069052f96a244f29910a8be3962b Mon Sep 17 00:00:00 2001 From: liangzelang Date: Mon, 29 Mar 2021 09:30:42 +0800 Subject: [PATCH] update ci testcases --- tests/st/control/test_cont_grad.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/tests/st/control/test_cont_grad.py b/tests/st/control/test_cont_grad.py index 00bf7534c7..ae87f6e317 100644 --- a/tests/st/control/test_cont_grad.py +++ b/tests/st/control/test_cont_grad.py @@ -69,6 +69,10 @@ def test_while_grad(): assert np.allclose(graph_output[1].asnumpy(), pynative_output[1].asnumpy(), 0.0001, 0.0001) assert np.allclose(graph_output[2].asnumpy(), pynative_output[2].asnumpy(), 0.0001, 0.0001) +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard def test_while_with_const_param_grad(): class MyWhileNet(nn.Cell): def __init__(self): @@ -131,6 +135,10 @@ def test_while_with_variable_grad(): assert np.allclose(graph_output[0].asnumpy(), expect_one, 0.0001, 0.0001) assert np.allclose(graph_output[1].asnumpy(), expect_two, 0.0001, 0.0001) +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard def test_while_with_param_forward(): class MyWhileNet(nn.Cell): def __init__(self): @@ -155,10 +163,8 @@ def test_while_with_param_forward(): end = Tensor(np.array(2), dtype=ms.int32) x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32) graph_output = net(idx, end, x) - # pynative mode - context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") - pynative_output = net(idx, end, x) - assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001) + expect = np.array([[[6, 8], [10, 12]], [[19, 22], [25, 28]]], dtype=np.int32) + assert np.allclose(graph_output.asnumpy(), expect, 0.0001, 0.0001) def test_while_endless_case(): @@ -189,6 +195,10 @@ def test_while_endless_case(): pynative_output = net(idx, end, x) assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001) +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard def test_while_with_param_grad(): class MyWhileNet(nn.Cell): def __init__(self):