|
|
|
@ -66,9 +66,9 @@ class TestScatterNdAddSimpleOp(OpTest):
|
|
|
|
|
def setUp(self):
|
|
|
|
|
self.op_type = "scatter_nd_add"
|
|
|
|
|
#ref_np = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8]).astype("float32")
|
|
|
|
|
ref_np = np.random.random([100]).astype("float32")
|
|
|
|
|
ref_np = np.random.random([100]).astype("float64")
|
|
|
|
|
index_np = np.random.randint(0, 100, [100, 1]).astype("int32")
|
|
|
|
|
updates_np = np.random.random([100]).astype("float32")
|
|
|
|
|
updates_np = np.random.random([100]).astype("float64")
|
|
|
|
|
expect_np = numpy_scatter_nd_add(ref_np.copy(), index_np, updates_np)
|
|
|
|
|
#expect_np = [ 0. 23. 12. 14. 4. 17. 6. 7. 8.]
|
|
|
|
|
|
|
|
|
@ -89,10 +89,10 @@ class TestScatterNdAddWithEmptyIndex(OpTest):
|
|
|
|
|
|
|
|
|
|
def setUp(self):
|
|
|
|
|
self.op_type = "scatter_nd_add"
|
|
|
|
|
ref_np = np.array([[65, 17], [-14, -25]]).astype("float32")
|
|
|
|
|
ref_np = np.array([[65, 17], [-14, -25]]).astype("float64")
|
|
|
|
|
index_np = np.array([[], []]).astype("int32")
|
|
|
|
|
updates_np = np.array([[[-1, -2], [1, 2]],
|
|
|
|
|
[[3, 4], [-3, -4]]]).astype("float32")
|
|
|
|
|
[[3, 4], [-3, -4]]]).astype("float64")
|
|
|
|
|
|
|
|
|
|
expect_np = numpy_scatter_nd_add(ref_np.copy(), index_np, updates_np)
|
|
|
|
|
#expect_np = [[67, 19], [-16, -27]]
|
|
|
|
@ -115,12 +115,12 @@ class TestScatterNdAddWithHighRankSame(OpTest):
|
|
|
|
|
def setUp(self):
|
|
|
|
|
self.op_type = "scatter_nd_add"
|
|
|
|
|
shape = (10, 9, 8, 1, 15)
|
|
|
|
|
ref_np = np.random.rand(*shape).astype("float32")
|
|
|
|
|
ref_np = np.random.rand(*shape).astype("float64")
|
|
|
|
|
index_np = np.vstack(
|
|
|
|
|
[np.random.randint(
|
|
|
|
|
0, s, size=150) for s in shape]).T.astype("int32")
|
|
|
|
|
update_shape = judge_update_shape(ref_np, index_np)
|
|
|
|
|
updates_np = np.random.rand(*update_shape).astype("float32")
|
|
|
|
|
updates_np = np.random.rand(*update_shape).astype("float64")
|
|
|
|
|
expect_np = numpy_scatter_nd_add(ref_np.copy(), index_np, updates_np)
|
|
|
|
|
|
|
|
|
|
self.inputs = {'X': ref_np, 'Index': index_np, 'Updates': updates_np}
|
|
|
|
|