Optimize argsort Op performance on GPU

* Accelerate the argsort op on GPU when the input size is equal to the length of the 'axis' dimension (a standalone sketch of the fast path is shown below, before the diff).
Authored by LutaoChu, 4 years ago; committed via GitHub
Parent: 1d3b27cae8
Commit: f11a53ee76
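When the fast path is taken, the whole tensor is sorted in one shot with Thrust: an index sequence is generated, the input is copied into the output buffer, thrust::sort_by_key permutes the indices along with the values, and a descending sort is produced by reversing the ascending result. The following is a minimal standalone sketch of that call sequence, using toy data and thrust::device_vector instead of Paddle tensors (names and sample values are made up for illustration; this is not the commit's code):

// Minimal standalone sketch of the Thrust-based fast path (toy data, default
// stream). Compile with nvcc.
#include <thrust/copy.h>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/host_vector.h>
#include <thrust/reverse.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t size = 8;
  const float host_in[] = {3.f, 1.f, 4.f, 1.5f, 5.f, 9.f, 2.f, 6.f};
  thrust::device_vector<float> in(host_in, host_in + size);
  thrust::device_vector<float> out(size);    // sorted values
  thrust::device_vector<int64_t> ids(size);  // argsort indices
  const bool descending = true;

  // ids = 0, 1, ..., size - 1
  thrust::sequence(thrust::device, ids.begin(), ids.end());
  // Sort a copy so the input itself stays untouched.
  thrust::copy(thrust::device, in.begin(), in.end(), out.begin());
  // Ascending sort of the values; the index array is permuted alongside.
  thrust::sort_by_key(thrust::device, out.begin(), out.end(), ids.begin());
  if (descending) {
    // Descending order = reversed ascending order, for values and indices.
    thrust::reverse(thrust::device, out.begin(), out.end());
    thrust::reverse(thrust::device, ids.begin(), ids.end());
  }

  thrust::host_vector<int64_t> h_ids = ids;
  for (int64_t i = 0; i < size; ++i) {
    std::printf("%lld ", static_cast<long long>(h_ids[i]));
  }
  std::printf("\n");  // prints: 5 7 4 2 0 6 3 1
  return 0;
}

The in-kernel version in the diff below applies exactly these calls to the op's output and index buffers when size == in_dims[axis], i.e. when there is only a single group to sort.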

@@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

+#include <thrust/copy.h>
 #include <thrust/execution_policy.h>
+#include <thrust/sequence.h>
 #include <thrust/sort.h>
 #include "cub/cub.cuh"
 #include "paddle/fluid/framework/op_registry.h"
@@ -58,6 +60,16 @@ static __global__ void FillIndex(T* indices, T num_rows, T num_cols) {
   }
 }

+template <typename T, typename IndType>
+static __global__ void FillFlattenGrad(const T* dO, const IndType* indices,
+                                       int64_t size, T* dX) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  int stride = blockDim.x * gridDim.x;
+  for (int i = index; i < size; i += stride) {
+    dX[indices[i]] = dO[i];
+  }
+}
+
 template <typename T, typename IndType>
 static __global__ void FillGrad(const T* dO, const IndType* indices, T* dX,
                                 IndType num_rows, IndType num_cols) {
@@ -193,6 +205,23 @@ void ArgFullAssign(const platform::CUDADeviceContext& ctx, const Tensor* dO,
   }
 }

 template <typename T>
+void ArgFlattenAssign(const platform::CUDADeviceContext& ctx, const Tensor* dO,
+                      const Tensor* indices, int64_t size, Tensor* dX) {
+  auto cu_stream = ctx.stream();
+
+  const int64_t block_size =
+      std::min(size, static_cast<int64_t>(ctx.GetMaxThreadsPerBlock()));
+  int64_t max_threads = ctx.GetMaxPhysicalThreadCount();
+  const int64_t max_blocks =
+      std::max(((max_threads - 1) / block_size + 1), static_cast<int64_t>(1));
+  const int64_t grid_size =
+      std::min(max_blocks, (size + block_size - 1) / block_size);
+
+  FillFlattenGrad<<<grid_size, block_size, 0, cu_stream>>>(
+      dO->data<T>(), indices->data<int64_t>(), size, dX->data<T>());
+}
+
+template <typename DeviceContext, typename T>
 class ArgsortOpCUDAKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext& ctx) const override {
@@ -205,8 +234,25 @@ class ArgsortOpCUDAKernel : public framework::OpKernel<T> {
     auto in_dims = input->dims();
     axis = (axis < 0) ? (in_dims.size() + axis) : axis;

-    int64_t numel = input->numel();
-    int64_t groups = numel / in_dims[axis];
+    const T* in_data = input->data<T>();
+    auto size = input->numel();
+    T* out_data = output->mutable_data<T>(ctx.GetPlace());
+    int64_t* ids_data = indices->mutable_data<int64_t>(ctx.GetPlace());
+
+    // Use thrust for parallel acceleration when the input size is equal to the
+    // length of the 'axis' dimension.
+    // Compared to the following 'Special case for full sort', ascending sort is
+    // 34 times faster and descending sort is 31 times faster.
+    if (size == in_dims[axis]) {
+      thrust::sequence(thrust::device, ids_data, ids_data + size);
+      thrust::copy(thrust::device, in_data, in_data + size, out_data);
+      thrust::sort_by_key(thrust::device, out_data, out_data + size, ids_data);
+      if (descending) {
+        thrust::reverse(thrust::device, out_data, out_data + size);
+        thrust::reverse(thrust::device, ids_data, ids_data + size);
+      }
+      return;
+    }

     // Special case for full sort, speedup ~190x.
     if (axis == -1 || axis + 1 == in_dims.size()) {
@@ -276,23 +322,28 @@ class ArgsortGradOpCUDAKernel : public framework::OpKernel<T> {
     int axis = ctx.Attr<int>("axis");

     dX->mutable_data<T>(ctx.GetPlace());
-    auto dxt = framework::EigenVector<T>::Flatten(*dX);
-    auto& place = *ctx.template device_context<platform::CUDADeviceContext>()
-                       .eigen_device();
-    dxt.device(place) = dxt.constant(static_cast<T>(0));
     if (dO->numel() == 0) return;

-    auto in_dims = indices->dims();
+    auto in_dims = dX->dims();
     axis = (axis < 0) ? (in_dims.size() + axis) : axis;

-    int64_t numel = indices->numel();
+    int64_t size = dX->numel();
+    const auto& dev_ctx = ctx.cuda_device_context();

+    // Parallel acceleration when the input size is equal to the length of the
+    // axis dimension.
+    // Compared to 'special case for full sort' below, the gradient calculation
+    // is 10 times faster.
+    if (size == in_dims[axis]) {
+      ArgFlattenAssign<T>(dev_ctx, dO, indices, size, dX);
+      return;
+    }
+
     // Special case for full sort, speedup ~190x.
     if (axis == -1 || axis + 1 == in_dims.size()) {
       const int64_t input_height = framework::product(
           framework::slice_ddim(in_dims, 0, in_dims.size() - 1));
       const int64_t input_width = in_dims[in_dims.size() - 1];
-      const auto& dev_ctx = ctx.cuda_device_context();
       ArgFullAssign<T, int64_t>(dev_ctx, dO, indices, dX, input_height,
                                 input_width);
     } else {
@@ -316,7 +367,6 @@ class ArgsortGradOpCUDAKernel : public framework::OpKernel<T> {
       Tensor trans_ind;
       trans_ind.mutable_data<int64_t>(trans_dims, ctx.GetPlace());
       int ndims = trans.size();
-      const auto& dev_ctx = ctx.cuda_device_context();
       // Do transpose
       TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, *dO,
                                                    &trans_dO, trans);
@@ -345,11 +395,17 @@ class ArgsortGradOpCUDAKernel : public framework::OpKernel<T> {
 }  // namespace paddle

 REGISTER_OP_CUDA_KERNEL(
-    argsort, paddle::operators::ArgsortOpCUDAKernel<float>,
-    paddle::operators::ArgsortOpCUDAKernel<double>,
-    paddle::operators::ArgsortOpCUDAKernel<int>,
-    paddle::operators::ArgsortOpCUDAKernel<int64_t>,
-    paddle::operators::ArgsortOpCUDAKernel<paddle::platform::float16>);
+    argsort,
+    paddle::operators::ArgsortOpCUDAKernel<paddle::platform::CUDADeviceContext,
+                                           float>,
+    paddle::operators::ArgsortOpCUDAKernel<paddle::platform::CUDADeviceContext,
+                                           double>,
+    paddle::operators::ArgsortOpCUDAKernel<paddle::platform::CUDADeviceContext,
+                                           int>,
+    paddle::operators::ArgsortOpCUDAKernel<paddle::platform::CUDADeviceContext,
+                                           int64_t>,
+    paddle::operators::ArgsortOpCUDAKernel<paddle::platform::CUDADeviceContext,
+                                           paddle::platform::float16>);
 REGISTER_OP_CUDA_KERNEL(
     argsort_grad, paddle::operators::ArgsortGradOpCUDAKernel<float>,
     paddle::operators::ArgsortGradOpCUDAKernel<double>,
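For the backward pass, the new ArgFlattenAssign path above is a plain scatter: each upstream gradient element dO[i] is written to dX[indices[i]]. The grid-stride FillFlattenGrad kernel in the diff implements this directly; the standalone sketch below expresses the same data movement with thrust::scatter on toy data (not part of the commit) to make the semantics explicit:

// Standalone sketch: the flattened argsort gradient is a scatter,
// dX[indices[i]] = dO[i]. Equivalent in effect to FillFlattenGrad above.
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/host_vector.h>
#include <thrust/scatter.h>

#include <cstdint>
#include <cstdio>

int main() {
  // Ascending argsort of x = {3, 1, 4, 2} yields indices = {1, 3, 0, 2}.
  const int64_t ids_host[] = {1, 3, 0, 2};
  const float dout_host[] = {0.1f, 0.2f, 0.3f, 0.4f};  // grad w.r.t. sorted output

  thrust::device_vector<int64_t> ids(ids_host, ids_host + 4);
  thrust::device_vector<float> dout(dout_host, dout_host + 4);
  thrust::device_vector<float> dx(4);

  // dx[ids[i]] = dout[i]  ->  dx = {0.3, 0.1, 0.4, 0.2}
  thrust::scatter(thrust::device, dout.begin(), dout.end(), ids.begin(),
                  dx.begin());

  thrust::host_vector<float> h_dx = dx;
  for (int i = 0; i < 4; ++i) std::printf("%.1f ", h_dx[i]);
  std::printf("\n");  // prints: 0.3 0.1 0.4 0.2
  return 0;
}

The commit uses a hand-written grid-stride kernel rather than thrust::scatter, presumably to keep the launch configuration under the op's control; the sketch only shows the semantics, not the commit's implementation.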

@@ -348,57 +348,99 @@ class TestArgsortErrorOnGPU(TestArgsortErrorOnCPU):

 class TestArgsort(unittest.TestCase):
+    def init(self):
+        self.input_shape = [10000, ]
+        self.axis = 0
+
     def setUp(self):
+        self.init()
         if core.is_compiled_with_cuda():
             self.place = core.CUDAPlace(0)
         else:
             self.place = core.CPUPlace()
-        self.data = np.random.rand(2, 3, 4).astype("float32")
+        self.data = np.random.rand(*self.input_shape)

-    def test_api_0(self):
+    def test_api(self):
         with fluid.program_guard(fluid.Program()):
-            input = fluid.data(name="input", shape=[2, 3, 4], dtype="float32")
-            output = paddle.argsort(x=input)
-            exe = fluid.Executor(self.place)
-            result, = exe.run(feed={'input': self.data}, fetch_list=[output])
-            np_result = np.argsort(self.data)
-            self.assertEqual((result == np_result).all(), True)
+            input = fluid.data(
+                name="input", shape=self.input_shape, dtype="float64")

-    def test_api_1(self):
-        with fluid.program_guard(fluid.Program()):
-            input = fluid.data(name="input", shape=[2, 3, 4], dtype="float32")
-            output = paddle.argsort(x=input, axis=1)
+            output = paddle.argsort(input, axis=self.axis)
+            output2 = paddle.argsort(input, axis=self.axis, descending=True)
+
             exe = fluid.Executor(self.place)
-            result, = exe.run(feed={'input': self.data}, fetch_list=[output])
-            np_result = np.argsort(self.data, axis=1)
+            result, result2 = exe.run(feed={'input': self.data},
+                                      fetch_list=[output, output2])
+
+            np_result = np.argsort(self.data, axis=self.axis)
             self.assertEqual((result == np_result).all(), True)

+            np_result2 = np.argsort(-self.data, axis=self.axis)
+            self.assertEqual((result2 == np_result2).all(), True)
+
+
+class TestArgsort2(TestArgsort):
+    def init(self):
+        self.input_shape = [10000, 1]
+        self.axis = 0
+
+
+class TestArgsort3(TestArgsort):
+    def init(self):
+        self.input_shape = [1, 10000]
+        self.axis = 1
+
+
+class TestArgsort4(TestArgsort):
+    def init(self):
+        self.input_shape = [2, 3, 4]
+        self.axis = 1

-class TestArgsortDygraph(unittest.TestCase):
+
+class TestArgsortImperative(unittest.TestCase):
+    def init(self):
+        self.input_shape = [10000, ]
+        self.axis = 0
+
     def setUp(self):
-        self.input_data = np.random.rand(10, 10)
+        self.init()
+        self.input_data = np.random.rand(*self.input_shape)
         if core.is_compiled_with_cuda():
             self.place = core.CUDAPlace(0)
         else:
             self.place = core.CPUPlace()

-    def test_api_0(self):
+    def test_api(self):
         paddle.disable_static(self.place)
-        var_x = paddle.to_variable(self.input_data)
-        out = paddle.argsort(var_x)
-        self.assertEqual((np.argsort(self.input_data) == out.numpy()).all(),
-                         True)
-        paddle.enable_static()
+        var_x = paddle.to_tensor(self.input_data)
+        out = paddle.argsort(var_x, axis=self.axis)
+        expect = np.argsort(self.input_data, axis=self.axis)
+        self.assertEqual((expect == out.numpy()).all(), True)

-    def test_api_1(self):
-        paddle.disable_static(self.place)
-        var_x = paddle.to_variable(self.input_data)
-        out = paddle.argsort(var_x, axis=-1)
-        self.assertEqual(
-            (np.argsort(
-                self.input_data, axis=-1) == out.numpy()).all(), True)
+        out2 = paddle.argsort(var_x, axis=self.axis, descending=True)
+        expect2 = np.argsort(-self.input_data, axis=self.axis)
+        self.assertEqual((expect2 == out2.numpy()).all(), True)
+
         paddle.enable_static()

+
+class TestArgsortImperative2(TestArgsortImperative):
+    def init(self):
+        self.input_shape = [10000, 1]
+        self.axis = 0
+
+
+class TestArgsortImperative3(TestArgsortImperative):
+    def init(self):
+        self.input_shape = [1, 10000]
+        self.axis = 1
+
+
+class TestArgsortImperative4(TestArgsortImperative):
+    def init(self):
+        self.input_shape = [2, 3, 4]
+        self.axis = 1
+
+
 if __name__ == "__main__":
     unittest.main()
