|
|
|
@ -63,28 +63,28 @@ class TestTopkOp(OpTest):
|
|
|
|
|
self.check_grad(set(['X']), 'Out')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestTopOp1(TestTopkOp):
|
|
|
|
|
class TestTopkOp1(TestTopkOp):
|
|
|
|
|
def init_args(self):
|
|
|
|
|
self.k = 3
|
|
|
|
|
self.axis = 0
|
|
|
|
|
self.largest = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestTopOp2(TestTopkOp):
|
|
|
|
|
class TestTopkOp2(TestTopkOp):
|
|
|
|
|
def init_args(self):
|
|
|
|
|
self.k = 3
|
|
|
|
|
self.axis = 0
|
|
|
|
|
self.largest = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestTopOp3(TestTopkOp):
|
|
|
|
|
class TestTopkOp3(TestTopkOp):
|
|
|
|
|
def init_args(self):
|
|
|
|
|
self.k = 4
|
|
|
|
|
self.axis = 0
|
|
|
|
|
self.largest = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestTopOp4(TestTopkOp):
|
|
|
|
|
class TestTopkOp4(TestTopkOp):
|
|
|
|
|
def init_args(self):
|
|
|
|
|
self.k = 4
|
|
|
|
|
self.axis = 0
|
|
|
|
@ -189,6 +189,8 @@ class TestTopKAPI(unittest.TestCase):
|
|
|
|
|
result1 = paddle.topk(input_tensor, k=2)
|
|
|
|
|
result2 = paddle.topk(input_tensor, k=2, axis=-1)
|
|
|
|
|
result3 = paddle.topk(input_tensor, k=k_tensor, axis=1)
|
|
|
|
|
self.assertEqual(result3[0].shape, (6, -1, 8))
|
|
|
|
|
self.assertEqual(result3[1].shape, (6, -1, 8))
|
|
|
|
|
result4 = paddle.topk(input_tensor, k=2, axis=1, largest=False)
|
|
|
|
|
result5 = paddle.topk(input_tensor, k=2, axis=-1, largest=False)
|
|
|
|
|
result6 = paddle.topk(large_input_tensor, k=1, axis=-1)
|
|
|
|
@ -239,6 +241,15 @@ class TestTopKAPI(unittest.TestCase):
|
|
|
|
|
self.run_dygraph(place)
|
|
|
|
|
self.run_static(place)
|
|
|
|
|
|
|
|
|
|
def test_errors(self):
|
|
|
|
|
paddle.disable_static()
|
|
|
|
|
x = paddle.to_tensor([1, 2, 3])
|
|
|
|
|
with self.assertRaises(BaseException):
|
|
|
|
|
paddle.topk(x, k=-1)
|
|
|
|
|
|
|
|
|
|
with self.assertRaises(BaseException):
|
|
|
|
|
paddle.topk(x, k=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
unittest.main()
|
|
|
|
|