|
|
|
@ -64,10 +64,10 @@ def test_op1():
|
|
|
|
|
update = Tensor(np.array([1.0, 2.2]), mstype.float32)
|
|
|
|
|
|
|
|
|
|
scatter_nd_update = ScatterNdUpdate1()
|
|
|
|
|
output = scatter_nd_update(indices, update)
|
|
|
|
|
print("output:\n", output)
|
|
|
|
|
scatter_nd_update(indices, update)
|
|
|
|
|
print("x:\n", scatter_nd_update.x.default_input)
|
|
|
|
|
expect = [[1.0, 0.3, 3.6], [0.4, 2.2, -3.2]]
|
|
|
|
|
assert np.allclose(output.asnumpy(), np.array(expect, np.float))
|
|
|
|
|
assert np.allclose(scatter_nd_update.x.default_input.asnumpy(), np.array(expect, np.float))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0
|
|
|
|
@ -78,10 +78,10 @@ def test_op2():
|
|
|
|
|
update = Tensor(np.array([9, 10, 11, 12]), mstype.float32)
|
|
|
|
|
|
|
|
|
|
scatter_nd_update = ScatterNdUpdate2()
|
|
|
|
|
output = scatter_nd_update(indices, update)
|
|
|
|
|
print("output:\n", output)
|
|
|
|
|
scatter_nd_update(indices, update)
|
|
|
|
|
print("x:\n", scatter_nd_update.x.default_input)
|
|
|
|
|
expect = [1, 11, 3, 10, 9, 6, 7, 12]
|
|
|
|
|
assert np.allclose(output.asnumpy(), np.array(expect, dtype=float))
|
|
|
|
|
assert np.allclose(scatter_nd_update.x.default_input.asnumpy(), np.array(expect, dtype=float))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0
|
|
|
|
@ -95,10 +95,10 @@ def test_op3():
|
|
|
|
|
[7, 7, 7, 7], [8, 8, 8, 8]]]), mstype.float32)
|
|
|
|
|
|
|
|
|
|
scatter_nd_update = ScatterNdUpdate3()
|
|
|
|
|
output = scatter_nd_update(indices, update)
|
|
|
|
|
print("output:\n", output)
|
|
|
|
|
scatter_nd_update(indices, update)
|
|
|
|
|
print("x:\n", scatter_nd_update.x.default_input)
|
|
|
|
|
expect = [[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
|
|
|
|
|
[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
|
|
|
|
|
[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
|
|
|
|
|
[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]]
|
|
|
|
|
assert np.allclose(output.asnumpy(), np.array(expect, dtype=float))
|
|
|
|
|
assert np.allclose(scatter_nd_update.x.default_input.asnumpy(), np.array(expect, dtype=float))
|
|
|
|
|