diff --git a/tests/st/ops/gpu/test_scatter_nd.py b/tests/st/ops/gpu/test_scatter_nd.py index b201c7be2c..061fb697ae 100644 --- a/tests/st/ops/gpu/test_scatter_nd.py +++ b/tests/st/ops/gpu/test_scatter_nd.py @@ -19,7 +19,6 @@ import mindspore.nn as nn from mindspore import Tensor from mindspore.ops import operations as P -context.set_context(mode=context.GRAPH_MODE, device_target="GPU") class Net(nn.Cell): def __init__(self, _shape): @@ -30,6 +29,7 @@ class Net(nn.Cell): def construct(self, indices, update): return self.scatternd(indices, update, self.shape) + def scatternd_net(indices, update, _shape, expect): scatternd = Net(_shape) output = scatternd(Tensor(indices), Tensor(update)) @@ -38,13 +38,49 @@ def scatternd_net(indices, update, _shape, expect): assert np.all(diff < error) assert np.all(-diff < error) +def scatternd_positive(nptype): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + + arr_indices = np.array([[0, 1], [1, 1], [0, 1], [0, 1], [0, 1]]).astype(np.int32) + arr_update = np.array([3.2, 1.1, 5.3, -2.2, -1.0]).astype(nptype) + shape = (2, 2) + expect = np.array([[0., 5.3], + [0., 1.1]]).astype(nptype) + scatternd_net(arr_indices, arr_update, shape, expect) + +def scatternd_negative(nptype): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + + arr_indices = np.array([[1, 0], [1, 1], [1, 0], [1, 0], [1, 0]]).astype(np.int32) + arr_update = np.array([-13.4, -3.1, 5.1, -12.1, -1.0]).astype(nptype) + shape = (2, 2) + expect = np.array([[0., 0.], + [-21.4, -3.1]]).astype(nptype) + scatternd_net(arr_indices, arr_update, shape, expect) + @pytest.mark.level0 @pytest.mark.platform_x86_gpu_traning @pytest.mark.env_onecard -def test_scatternd(): - arr_indices = np.array([[0, 1], [1, 1]]).astype(np.int32) - arr_update = np.array([3.2, 1.1]).astype(np.float32) - shape = (2, 2) - expect = np.array([[0., 3.2], - [0., 1.1]]) - scatternd_net(arr_indices, arr_update, shape, expect) +def test_scatternd_float32(): + scatternd_positive(np.float32) + scatternd_negative(np.float32) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_traning +@pytest.mark.env_onecard +def test_scatternd_float16(): + scatternd_positive(np.float16) + scatternd_negative(np.float16) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_traning +@pytest.mark.env_onecard +def test_scatternd_int16(): + scatternd_positive(np.int16) + scatternd_negative(np.int16) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_traning +@pytest.mark.env_onecard +def test_scatternd_uint8(): + scatternd_positive(np.uint8)