Improve argsort performance. (#21267)

* Improve argsort performance.

- Running argsort on 200000 input values on a V100 GPU
is about 190x faster:
before optimization: 0.53s
after optimization: 0.0027s

- Add fp16 support

* Refine error message
* Refine code

test=develop

Signed-off-by: zhaoyuchen <zhaoyuchen01@baidu.com>
revert-21172-masked_select_api
zhaoyuchen2018 5 years ago committed by GitHub
parent 7fcaa39b36
commit 08c19c585d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

File diff suppressed because it is too large Load Diff

@ -17,12 +17,14 @@ from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
class TestArgsortOp(OpTest):
def setUp(self):
self.init_axis()
x = np.random.random((2, 3, 4, 5, 10)).astype("float32")
self.init_datatype()
x = np.random.random((2, 3, 4, 5, 10)).astype(self.dtype)
if self.axis < 0:
self.axis = self.axis + len(x.shape)
self.indices = np.argsort(x, kind='quicksort', axis=self.axis)
@ -35,6 +37,9 @@ class TestArgsortOp(OpTest):
def init_axis(self):
    """Select the axis to sort along; the default is -1."""
    self.axis = -1
def init_datatype(self):
    """Select the dtype of the generated input; the default is "float32"."""
    self.dtype = "float32"
def test_check_output(self):
    """Run the inherited output check for this test case."""
    self.check_output()
@ -49,10 +54,54 @@ class TestArgsortOpAxis1(TestArgsortOp):
self.axis = 1
class TestArgsortOpAxis2(TestArgsortOp):
    """Argsort op test that sorts along axis 2."""

    def init_axis(self):
        self.axis = 2
class TestArgsortOpAxisNeg1(TestArgsortOp):
    """Argsort op test that sorts along axis -1."""

    def init_axis(self):
        self.axis = -1
class TestArgsortOpAxisNeg2(TestArgsortOp):
    """Argsort op test that sorts along axis -2."""

    def init_axis(self):
        self.axis = -2
@unittest.skipIf(not core.is_compiled_with_cuda(),
                 "float16 argsort is only checked on CUDA builds")
class TestArgsortOpFP16(TestArgsortOp):
    """Argsort op test with float16 input, checked on CUDA device 0.

    The whole class is skipped when Paddle is not compiled with CUDA.
    Previously the tests silently passed in that configuration without
    checking anything, and ``init_datatype`` left ``self.dtype`` unassigned,
    so ``setUp`` depended on whatever ``dtype`` the base classes provided.
    Subclasses inherit the skip condition.
    """

    def init_datatype(self):
        """Use float16 input data."""
        self.dtype = 'float16'

    def test_check_output(self):
        """Disable the inherited output check for float16."""
        # NOTE(review): presumably the inherited check would also run on
        # places without float16 support -- confirm against OpTest.
        pass

    def test_check_output_with_place(self):
        """Check the op output on CUDA device 0 with atol=1e-5."""
        place = core.CUDAPlace(0)
        self.check_output_with_place(place, atol=1e-5)
class TestArgsortOpFP16Axis0(TestArgsortOpFP16):
    """Float16 argsort op test that sorts along axis 0."""

    def init_axis(self):
        self.axis = 0
class TestArgsortOpFP16Axis2(TestArgsortOpFP16):
    """Float16 argsort op test that sorts along axis 2."""

    def init_axis(self):
        self.axis = 2
class TestArgsortOpFP16AxisNeg2(TestArgsortOpFP16):
    """Float16 argsort op test that sorts along axis -2."""

    def init_axis(self):
        self.axis = -2
class TestArgsortOpFP16Axis4Neg4(TestArgsortOpFP16):
    """Float16 argsort op test that sorts along axis -4."""

    def init_axis(self):
        self.axis = -4
# Run every test case in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()

Loading…
Cancel
Save