@@ -58,11 +58,12 @@ static __global__ void FillIndex(T* indices, T num_rows, T num_cols) {
}
}
// Default use ascending sort
// Sort according to the descending flag: true means descending, false means
// ascending. The default is false.
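// Internally this selects between cub::DeviceSegmentedRadixSort::SortPairs
// (ascending) and SortPairsDescending (descending).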
template <typename T, typename IndType>
void ArgFullSortAscending (const platform::CUDADeviceContext& ctx,
const Tensor* input, Tensor* output, Tensor* indices,
const IndType num_rows, const IndType num_cols ) {
void ArgFullSort(const platform::CUDADeviceContext& ctx, const Tensor* input,
Tensor* output, Tensor* indices, const IndType num_rows,
const IndType num_cols, const bool descending) {
auto cu_stream = ctx.stream();
Tensor input_indices;
@@ -113,12 +114,20 @@ void ArgFullSortAscending(const platform::CUDADeviceContext& ctx,
cub::CountingInputIterator<IndType>>
segment_offsets_t(counting_iter, SegmentOffsetIter(num_cols));
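// First CUB call passes a null temp-storage pointer; it only computes the
// required temp_storage_bytes and performs no sorting.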
auto err = cub::DeviceSegmentedRadixSort::SortPairs(
nullptr, temp_storage_bytes, inp, sorted_out_ptr,
input_indices.data<IndType>(), sorted_indices_ptr, num_cols * num_rows,
num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
cu_stream);
cudaError_t err;
if (descending) {
err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
nullptr, temp_storage_bytes, inp, sorted_out_ptr,
input_indices.data<IndType>(), sorted_indices_ptr, num_cols * num_rows,
num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
cu_stream);
} else {
err = cub::DeviceSegmentedRadixSort::SortPairs(
nullptr, temp_storage_bytes, inp, sorted_out_ptr,
input_indices.data<IndType>(), sorted_indices_ptr, num_cols * num_rows,
num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
cu_stream);
}
PADDLE_ENFORCE_CUDA_SUCCESS(
err,
"ArgSortOP failed as could not launch "
@@ -129,11 +138,19 @@ void ArgFullSortAscending(const platform::CUDADeviceContext& ctx,
Tensor temp_storage;
temp_storage.mutable_data<uint8_t>(ctx.GetPlace(), temp_storage_bytes);
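// Second CUB call uses the allocated workspace to perform the actual
// segmented radix sort of (value, index) pairs within each row.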
err = cub::DeviceSegmentedRadixSort::SortPairs(
temp_storage.data<uint8_t>(), temp_storage_bytes, inp, sorted_out_ptr,
input_indices.data<IndType>(), sorted_indices_ptr, num_cols * num_rows,
num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
cu_stream);
if (descending) {
err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
temp_storage.data<uint8_t>(), temp_storage_bytes, inp, sorted_out_ptr,
input_indices.data<IndType>(), sorted_indices_ptr, num_cols * num_rows,
num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
cu_stream);
} else {
err = cub::DeviceSegmentedRadixSort::SortPairs(
temp_storage.data<uint8_t>(), temp_storage_bytes, inp, sorted_out_ptr,
input_indices.data<IndType>(), sorted_indices_ptr, num_cols * num_rows,
num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
cu_stream);
}
PADDLE_ENFORCE_CUDA_SUCCESS(
err,
@@ -151,6 +168,7 @@ class ArgsortOpCUDAKernel : public framework::OpKernel<T> {
auto* output = ctx.Output<Tensor>("Out");
auto* indices = ctx.Output<Tensor>("Indices");
int axis = ctx.Attr<int>("axis");
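// Whether to sort in descending order; read from the op's "descending" attribute.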
bool descending = ctx.Attr<bool>("descending");
auto in_dims = input->dims();
axis = (axis < 0) ? (in_dims.size() + axis) : axis;
@@ -164,14 +182,8 @@ class ArgsortOpCUDAKernel : public framework::OpKernel<T> {
framework::slice_ddim(in_dims, 0, in_dims.size() - 1));
const int64_t input_width = in_dims[in_dims.size() - 1];
const auto& dev_ctx = ctx.cuda_device_context();
if (input_width < INT_MAX && input_height < INT_MAX) {
ArgFullSortAscending<T, int>(dev_ctx, input, output, indices,
static_cast<int>(input_height),
static_cast<int>(input_width));
} else {
ArgFullSortAscending<T, int64_t>(dev_ctx, input, output, indices,
input_height, input_width);
}
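// Full sort along the last axis; int64_t indices cover both the small and
// large tensor cases that were previously split by the INT_MAX check.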
ArgFullSort<T, int64_t>(dev_ctx, input, output, indices, input_height,
input_width, descending);
} else {
// if not full sort, do transpose first
std::vector<int> trans;
@@ -205,29 +217,15 @@ class ArgsortOpCUDAKernel : public framework::OpKernel<T> {
T* out_data = output->mutable_data<T>(ctx.GetPlace());
Tensor tmp_indices;
if (input_height < INT_MAX && input_width < INT_MAX) {
// temp indices for sorting
tmp_indices.mutable_data<int>(trans_dims, ctx.GetPlace());
indices->mutable_data<int>(ctx.GetPlace());
ArgFullSortAscending<T, int>(
dev_ctx, &trans_inp, &tmp_out, &tmp_indices,
static_cast<int>(input_height), static_cast<int>(input_width));
TransCompute<platform::CUDADeviceContext, int>(
ndims, dev_ctx, tmp_indices, indices, trans);
} else {
// temp indices for sorting
tmp_indices.mutable_data<int64_t>(trans_dims, ctx.GetPlace());
indices->mutable_data<int64_t>(ctx.GetPlace());
ArgFullSortAscending<T, int64_t>(dev_ctx, &trans_inp, &tmp_out,
&tmp_indices, input_height,
input_width);
TransCompute<platform::CUDADeviceContext, int64_t>(
ndims, dev_ctx, tmp_indices, indices, trans);
}
// temp indices for sorting
tmp_indices.mutable_data<int64_t>(trans_dims, ctx.GetPlace());
indices->mutable_data<int64_t>(ctx.GetPlace());
ArgFullSort<T, int64_t>(dev_ctx, &trans_inp, &tmp_out, &tmp_indices,
input_height, input_width, descending);
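// Transpose the sorted indices back to the layout of the original input.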
TransCompute<platform::CUDADeviceContext, int64_t>(
ndims, dev_ctx, tmp_indices, indices, trans);
// transpose back
TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, tmp_out,
output, trans);