|
|
|
@ -69,6 +69,10 @@ def test_while_grad():
|
|
|
|
|
assert np.allclose(graph_output[1].asnumpy(), pynative_output[1].asnumpy(), 0.0001, 0.0001)
|
|
|
|
|
assert np.allclose(graph_output[2].asnumpy(), pynative_output[2].asnumpy(), 0.0001, 0.0001)
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0
|
|
|
|
|
@pytest.mark.platform_arm_ascend_training
|
|
|
|
|
@pytest.mark.platform_x86_ascend_training
|
|
|
|
|
@pytest.mark.env_onecard
|
|
|
|
|
def test_while_with_const_param_grad():
|
|
|
|
|
class MyWhileNet(nn.Cell):
|
|
|
|
|
def __init__(self):
|
|
|
|
@ -131,6 +135,10 @@ def test_while_with_variable_grad():
|
|
|
|
|
assert np.allclose(graph_output[0].asnumpy(), expect_one, 0.0001, 0.0001)
|
|
|
|
|
assert np.allclose(graph_output[1].asnumpy(), expect_two, 0.0001, 0.0001)
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0
|
|
|
|
|
@pytest.mark.platform_arm_ascend_training
|
|
|
|
|
@pytest.mark.platform_x86_ascend_training
|
|
|
|
|
@pytest.mark.env_onecard
|
|
|
|
|
def test_while_with_param_forward():
|
|
|
|
|
class MyWhileNet(nn.Cell):
|
|
|
|
|
def __init__(self):
|
|
|
|
@ -155,10 +163,8 @@ def test_while_with_param_forward():
|
|
|
|
|
end = Tensor(np.array(2), dtype=ms.int32)
|
|
|
|
|
x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
|
|
|
|
|
graph_output = net(idx, end, x)
|
|
|
|
|
# pynative mode
|
|
|
|
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
|
|
|
|
pynative_output = net(idx, end, x)
|
|
|
|
|
assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)
|
|
|
|
|
expect = np.array([[[6, 8], [10, 12]], [[19, 22], [25, 28]]], dtype=np.int32)
|
|
|
|
|
assert np.allclose(graph_output.asnumpy(), expect, 0.0001, 0.0001)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_while_endless_case():
|
|
|
|
@ -189,6 +195,10 @@ def test_while_endless_case():
|
|
|
|
|
pynative_output = net(idx, end, x)
|
|
|
|
|
assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0
|
|
|
|
|
@pytest.mark.platform_arm_ascend_training
|
|
|
|
|
@pytest.mark.platform_x86_ascend_training
|
|
|
|
|
@pytest.mark.env_onecard
|
|
|
|
|
def test_while_with_param_grad():
|
|
|
|
|
class MyWhileNet(nn.Cell):
|
|
|
|
|
def __init__(self):
|
|
|
|
|