|
|
|
@ -9,7 +9,7 @@ class TestTopkOp(OpTest):
|
|
|
|
|
k = 1
|
|
|
|
|
input = np.random.random((32, 84)).astype("float32")
|
|
|
|
|
output = np.ndarray((32, k))
|
|
|
|
|
indices = np.ndarray((32, k))
|
|
|
|
|
indices = np.ndarray((32, k)).astype("int64")
|
|
|
|
|
|
|
|
|
|
self.inputs = {'X': input}
|
|
|
|
|
self.attrs = {'k': k}
|
|
|
|
@ -32,7 +32,7 @@ class TestTopkOp3d(OpTest):
|
|
|
|
|
input = np.random.random((32, 2, 84)).astype("float32")
|
|
|
|
|
input_flat_2d = input.reshape(64, 84)
|
|
|
|
|
output = np.ndarray((64, k))
|
|
|
|
|
indices = np.ndarray((64, k)).astype("int")
|
|
|
|
|
indices = np.ndarray((64, k)).astype("int64")
|
|
|
|
|
|
|
|
|
|
# FIXME: should use 'X': input for a 3d input
|
|
|
|
|
self.inputs = {'X': input_flat_2d}
|
|
|
|
|