/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <thrust/execution_policy.h>
#include <thrust/sort.h>
#include "cub/cub.cuh"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/argsort_op.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"

// Specialize cub's numeric traits so its radix sort can handle
// paddle::platform::float16 keys (a 16-bit floating-point type backed by
// uint16_t bits).
namespace cub {
template <>
struct NumericTraits<paddle::platform::float16>
    : BaseTraits<FLOATING_POINT, true, false, uint16_t,
                 paddle::platform::float16> {};
}  // namespace cub

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

// Functor that maps a segment (row) index to its starting offset:
// segment i starts at i * num_cols. Used to build cub's segment-offset
// iterator below.
struct SegmentOffsetIter {
  EIGEN_DEVICE_FUNC
  explicit SegmentOffsetIter(int num_cols) : num_cols_(num_cols) {}

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(int idx) const {
    return idx * num_cols_;
  }

  int num_cols_;
};

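// Kernel that initializes the index matrix: for each row j it writes
// 0, 1, ..., num_cols - 1 into indices[j * num_cols + i], striding over rows
// with the grid and over columns with the thread block.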
template <typename T>
static __global__ void FillIndex(T* indices, T num_rows, T num_cols) {
  int col_id = threadIdx.x;
  int row_id = blockIdx.x;

  for (T j = row_id; j < num_rows; j += gridDim.x) {
    for (T i = col_id; i < num_cols; i += blockDim.x) {
      indices[j * num_cols + i] = i;
    }
  }
}

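// Kernel used by the gradient: scatters each gradient element back to the
// position its value came from in the forward pass, i.e.
// dX[j][indices[j][i]] = dO[j][i] for every row j and column i.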
template <typename T, typename IndType>
static __global__ void FillGrad(const T* dO, const IndType* indices, T* dX,
                                IndType num_rows, IndType num_cols) {
  int col_id = threadIdx.x;
  int row_id = blockIdx.x;

  for (IndType j = row_id; j < num_rows; j += gridDim.x) {
    for (IndType i = col_id; i < num_cols; i += blockDim.x) {
      dX[j * num_cols + indices[j * num_cols + i]] = dO[j * num_cols + i];
    }
  }
}

// Sort each length-num_cols segment (row) of `input`, writing the sorted
// values to `output` and the original positions to `indices`. If
// `descending` is true the sort is descending, otherwise ascending
// (the default).
template <typename T, typename IndType>
void ArgFullSort(const platform::CUDADeviceContext& ctx, const Tensor* input,
                 Tensor* output, Tensor* indices, const IndType num_rows,
                 const IndType num_cols, const bool descending) {
  auto cu_stream = ctx.stream();

  Tensor input_indices;

  const std::vector<IndType> dims = {num_rows, num_cols};
  auto dim = framework::make_ddim(dims);
  input_indices.Resize(dim);
  input_indices.mutable_data<IndType>(ctx.GetPlace());

  size_t temp_storage_bytes = -1;

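  // Pick threads per block: roughly the smallest power of two that covers
  // num_cols, clamped to the range [64, 1024].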
  auto ComputeBlockSize = [](IndType col) {
    if (col > 512)
      return 1024;
    else if (col > 256 && col <= 512)
      return 512;
    else if (col > 128 && col <= 256)
      return 256;
    else if (col > 64 && col <= 128)
      return 128;
    else
      return 64;
  };

  int block_size = ComputeBlockSize(num_cols);

  int maxGridDimX = ctx.GetCUDAMaxGridDimSize().x;
  // Launch at most one block per row, capped at the device's grid-size limit.
  int grid_size = num_rows < maxGridDimX ? num_rows : maxGridDimX;
  // Initialize the index array.
  FillIndex<<<grid_size, block_size, 0, cu_stream>>>(
      input_indices.data<IndType>(), num_rows, num_cols);

  T* sorted_out_ptr;
  IndType* sorted_indices_ptr;

  const T* inp = input->data<T>();
  T* out = output->mutable_data<T>(ctx.GetPlace());
  IndType* ind = indices->mutable_data<IndType>(ctx.GetPlace());

  sorted_out_ptr = out;
  sorted_indices_ptr = ind;

  // Counting iterator 0, 1, 2, ... feeding the segment-offset transform.
  cub::CountingInputIterator<IndType> counting_iter(0);
  // segment_offsets_t yields the starting offset of each row (segment).
  cub::TransformInputIterator<IndType, SegmentOffsetIter,
                              cub::CountingInputIterator<IndType>>
      segment_offsets_t(counting_iter, SegmentOffsetIter(num_cols));

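  // cub's segmented radix sort is invoked twice: the first call, with a null
  // temporary-storage pointer, only computes temp_storage_bytes; the second
  // call (below) performs the actual key/value sort.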
  cudaError_t err;
  if (descending) {
    err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
        nullptr, temp_storage_bytes, inp, sorted_out_ptr,
        input_indices.data<IndType>(), sorted_indices_ptr, num_cols * num_rows,
        num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
        cu_stream);
  } else {
    err = cub::DeviceSegmentedRadixSort::SortPairs(
        nullptr, temp_storage_bytes, inp, sorted_out_ptr,
        input_indices.data<IndType>(), sorted_indices_ptr, num_cols * num_rows,
        num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
        cu_stream);
  }
  PADDLE_ENFORCE_CUDA_SUCCESS(err);

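  // Allocate the temporary storage reported by the size query, then run the
  // sort for real.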
  Tensor temp_storage;
  temp_storage.mutable_data<uint8_t>(ctx.GetPlace(), temp_storage_bytes);

  if (descending) {
    err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
        temp_storage.data<uint8_t>(), temp_storage_bytes, inp, sorted_out_ptr,
        input_indices.data<IndType>(), sorted_indices_ptr, num_cols * num_rows,
        num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
        cu_stream);
  } else {
    err = cub::DeviceSegmentedRadixSort::SortPairs(
        temp_storage.data<uint8_t>(), temp_storage_bytes, inp, sorted_out_ptr,
        input_indices.data<IndType>(), sorted_indices_ptr, num_cols * num_rows,
        num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
        cu_stream);
  }

  PADDLE_ENFORCE_CUDA_SUCCESS(err);
}

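// Scatter the gradient of the sorted output (dO) back into the gradient of
// the original input (dX), using the indices saved by the forward sort.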
template <typename T, typename IndType>
void ArgFullAssign(const platform::CUDADeviceContext& ctx, const Tensor* dO,
                   const Tensor* indices, Tensor* dX, const IndType num_rows,
                   const IndType num_cols) {
  auto cu_stream = ctx.stream();

  // Same block-size heuristic as in ArgFullSort.
  auto ComputeBlockSize = [](IndType col) {
    if (col > 512)
      return 1024;
    else if (col > 256 && col <= 512)
      return 512;
    else if (col > 128 && col <= 256)
      return 256;
    else if (col > 64 && col <= 128)
      return 128;
    else
      return 64;
  };

  int block_size = ComputeBlockSize(num_cols);

  int maxGridDimX = ctx.GetCUDAMaxGridDimSize().x;
  // Launch at most one block per row, capped at the device's grid-size limit.
  int grid_size = num_rows < maxGridDimX ? num_rows : maxGridDimX;
  FillGrad<<<grid_size, block_size, 0, cu_stream>>>(
      dO->data<T>(), indices->data<IndType>(), dX->data<T>(), num_rows,
      num_cols);
}

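// Forward kernel. When sorting along the last axis, the tensor is viewed as
// a 2-D matrix and each row is sorted directly; otherwise the target axis is
// first transposed to the last dimension, the rows are sorted, and the
// results are transposed back.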
template <typename T>
class ArgsortOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<Tensor>("X");
    auto* output = ctx.Output<Tensor>("Out");
    auto* indices = ctx.Output<Tensor>("Indices");
    int axis = ctx.Attr<int>("axis");
    bool descending = ctx.Attr<bool>("descending");

    auto in_dims = input->dims();
    axis = (axis < 0) ? (in_dims.size() + axis) : axis;

    int64_t numel = input->numel();
    int64_t groups = numel / in_dims[axis];

    // Special case for full sort, speedup ~190x.
    if (axis == -1 || axis + 1 == in_dims.size()) {
      const int64_t input_height = framework::product(
          framework::slice_ddim(in_dims, 0, in_dims.size() - 1));
      const int64_t input_width = in_dims[in_dims.size() - 1];
      const auto& dev_ctx = ctx.cuda_device_context();
      ArgFullSort<T, int64_t>(dev_ctx, input, output, indices, input_height,
                              input_width, descending);
    } else {
      // Not sorting along the last axis: build a permutation that swaps
      // `axis` with the last dimension, transpose, sort, then transpose back.
      std::vector<int> trans;
      for (int i = 0; i < axis; i++) {
        trans.push_back(i);
      }
      trans.push_back(in_dims.size() - 1);
      for (int i = axis + 1; i < in_dims.size() - 1; i++) {
        trans.push_back(i);
      }
      trans.push_back(axis);
      framework::DDim trans_dims(in_dims);
      for (int i = 0; i < trans.size(); i++) {
        trans_dims[i] = in_dims[trans[i]];
      }

      Tensor trans_inp;
      T* trans_inp_data = trans_inp.mutable_data<T>(trans_dims, ctx.GetPlace());
      int ndims = trans.size();
      const auto& dev_ctx = ctx.cuda_device_context();
      // Do transpose
      TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, *input,
                                                   &trans_inp, trans);

      const int64_t input_height = framework::product(
          framework::slice_ddim(trans_dims, 0, trans_dims.size() - 1));
      const int64_t input_width = trans_dims[trans_dims.size() - 1];

      Tensor tmp_out;
      tmp_out.mutable_data<T>(trans_dims, ctx.GetPlace());
      T* out_data = output->mutable_data<T>(ctx.GetPlace());

      Tensor tmp_indices;
      // temporary indices for sorting
      tmp_indices.mutable_data<int64_t>(trans_dims, ctx.GetPlace());
      indices->mutable_data<int64_t>(ctx.GetPlace());

      ArgFullSort<T, int64_t>(dev_ctx, &trans_inp, &tmp_out, &tmp_indices,
                              input_height, input_width, descending);

      TransCompute<platform::CUDADeviceContext, int64_t>(
          ndims, dev_ctx, tmp_indices, indices, trans);
      // transpose back
      TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, tmp_out,
                                                   output, trans);
      return;
    }
  }
};

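// Gradient kernel. dX is zero-initialized and then filled by scattering dOut
// through the saved Indices; the non-last-axis case mirrors the forward
// kernel's transpose handling.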
template <typename T>
class ArgsortGradOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* indices = ctx.Input<Tensor>("Indices");
    auto* dX = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
    int axis = ctx.Attr<int>("axis");

    dX->mutable_data<T>(ctx.GetPlace());
    auto dxt = framework::EigenVector<T>::Flatten(*dX);
    auto& place = *ctx.template device_context<platform::CUDADeviceContext>()
                       .eigen_device();
    dxt.device(place) = dxt.constant(static_cast<T>(0));
    if (dO->numel() == 0) return;

    auto in_dims = indices->dims();
    axis = (axis < 0) ? (in_dims.size() + axis) : axis;

    int64_t numel = indices->numel();

    // Special case for full sort, speedup ~190x.
    if (axis == -1 || axis + 1 == in_dims.size()) {
      const int64_t input_height = framework::product(
          framework::slice_ddim(in_dims, 0, in_dims.size() - 1));
      const int64_t input_width = in_dims[in_dims.size() - 1];
      const auto& dev_ctx = ctx.cuda_device_context();
      ArgFullAssign<T, int64_t>(dev_ctx, dO, indices, dX, input_height,
                                input_width);
    } else {
      // Not sorting along the last axis: build a permutation that swaps
      // `axis` with the last dimension, transpose, scatter, then transpose
      // back.
      std::vector<int> trans;
      for (int i = 0; i < axis; i++) {
        trans.push_back(i);
      }
      trans.push_back(in_dims.size() - 1);
      for (int i = axis + 1; i < in_dims.size() - 1; i++) {
        trans.push_back(i);
      }
      trans.push_back(axis);
      framework::DDim trans_dims(in_dims);
      for (int i = 0; i < trans.size(); i++) {
        trans_dims[i] = in_dims[trans[i]];
      }

      Tensor trans_dO;
      trans_dO.mutable_data<T>(trans_dims, ctx.GetPlace());
      Tensor trans_ind;
      trans_ind.mutable_data<int64_t>(trans_dims, ctx.GetPlace());
      int ndims = trans.size();
      const auto& dev_ctx = ctx.cuda_device_context();
      // Do transpose
      TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, *dO,
                                                   &trans_dO, trans);
      TransCompute<platform::CUDADeviceContext, int64_t>(
          ndims, dev_ctx, *indices, &trans_ind, trans);

      const int64_t input_height = framework::product(
          framework::slice_ddim(trans_dims, 0, trans_dims.size() - 1));
      const int64_t input_width = trans_dims[trans_dims.size() - 1];

      Tensor tmp_out;
      tmp_out.mutable_data<T>(trans_dims, ctx.GetPlace());

      ArgFullAssign<T, int64_t>(dev_ctx, &trans_dO, &trans_ind, &tmp_out,
                                input_height, input_width);

      // transpose back
      TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, tmp_out, dX,
                                                   trans);
      return;
    }
  }
};

}  // namespace operators
}  // namespace paddle

REGISTER_OP_CUDA_KERNEL(
    argsort, paddle::operators::ArgsortOpCUDAKernel<float>,
    paddle::operators::ArgsortOpCUDAKernel<double>,
    paddle::operators::ArgsortOpCUDAKernel<int>,
    paddle::operators::ArgsortOpCUDAKernel<int64_t>,
    paddle::operators::ArgsortOpCUDAKernel<paddle::platform::float16>);
REGISTER_OP_CUDA_KERNEL(
    argsort_grad, paddle::operators::ArgsortGradOpCUDAKernel<float>,
    paddle::operators::ArgsortGradOpCUDAKernel<double>,
    paddle::operators::ArgsortGradOpCUDAKernel<int>,
    paddle::operators::ArgsortGradOpCUDAKernel<int64_t>,
    paddle::operators::ArgsortGradOpCUDAKernel<paddle::platform::float16>);