fix gather nd for untest (#30037)

revert-31562-mean
ShenLiang 5 years ago committed by GitHub
parent a253a78a85
commit b6fd262951
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -114,9 +114,9 @@ class TestGatherNdOpWithHighRankSame(OpTest):
def setUp(self):
self.op_type = "gather_nd"
shape = (20, 9, 8, 1, 31)
shape = (5, 2, 3, 1, 10)
xnp = np.random.rand(*shape).astype("float64")
index = np.vstack([np.random.randint(0, s, size=150) for s in shape]).T
index = np.vstack([np.random.randint(0, s, size=2) for s in shape]).T
self.inputs = {'X': xnp, 'Index': index.astype("int32")}
self.outputs = {'Out': xnp[tuple(index.T)]}
@ -133,13 +133,13 @@ class TestGatherNdOpWithHighRankDiff(OpTest):
def setUp(self):
self.op_type = "gather_nd"
shape = (20, 9, 8, 1, 31)
shape = (2, 3, 4, 1, 10)
xnp = np.random.rand(*shape).astype("float64")
index = np.vstack([np.random.randint(0, s, size=1000) for s in shape]).T
index_re = index.reshape([10, 5, 20, 5])
index = np.vstack([np.random.randint(0, s, size=200) for s in shape]).T
index_re = index.reshape([20, 5, 2, 5])
self.inputs = {'X': xnp, 'Index': index_re.astype("int32")}
self.outputs = {'Out': xnp[tuple(index.T)].reshape([10, 5, 20])}
self.outputs = {'Out': xnp[tuple(index.T)].reshape([20, 5, 2])}
def test_check_output(self):
self.check_output()

Loading…
Cancel
Save