|
|
|
@ -114,9 +114,9 @@ class TestGatherNdOpWithHighRankSame(OpTest):
|
|
|
|
|
|
|
|
|
|
def setUp(self):
|
|
|
|
|
self.op_type = "gather_nd"
|
|
|
|
|
shape = (20, 9, 8, 1, 31)
|
|
|
|
|
shape = (5, 2, 3, 1, 10)
|
|
|
|
|
xnp = np.random.rand(*shape).astype("float64")
|
|
|
|
|
index = np.vstack([np.random.randint(0, s, size=150) for s in shape]).T
|
|
|
|
|
index = np.vstack([np.random.randint(0, s, size=2) for s in shape]).T
|
|
|
|
|
|
|
|
|
|
self.inputs = {'X': xnp, 'Index': index.astype("int32")}
|
|
|
|
|
self.outputs = {'Out': xnp[tuple(index.T)]}
|
|
|
|
@ -133,13 +133,13 @@ class TestGatherNdOpWithHighRankDiff(OpTest):
|
|
|
|
|
|
|
|
|
|
def setUp(self):
|
|
|
|
|
self.op_type = "gather_nd"
|
|
|
|
|
shape = (20, 9, 8, 1, 31)
|
|
|
|
|
shape = (2, 3, 4, 1, 10)
|
|
|
|
|
xnp = np.random.rand(*shape).astype("float64")
|
|
|
|
|
index = np.vstack([np.random.randint(0, s, size=1000) for s in shape]).T
|
|
|
|
|
index_re = index.reshape([10, 5, 20, 5])
|
|
|
|
|
index = np.vstack([np.random.randint(0, s, size=200) for s in shape]).T
|
|
|
|
|
index_re = index.reshape([20, 5, 2, 5])
|
|
|
|
|
|
|
|
|
|
self.inputs = {'X': xnp, 'Index': index_re.astype("int32")}
|
|
|
|
|
self.outputs = {'Out': xnp[tuple(index.T)].reshape([10, 5, 20])}
|
|
|
|
|
self.outputs = {'Out': xnp[tuple(index.T)].reshape([20, 5, 2])}
|
|
|
|
|
|
|
|
|
|
def test_check_output(self):
|
|
|
|
|
self.check_output()
|
|
|
|
|