@@ -12,11 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "cub/cub.cuh"
|
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
|
|
|
#include "paddle/fluid/operators/top_k_op.h"
|
|
|
|
|
#include "paddle/fluid/platform/cuda_device_function.h"
|
|
|
|
|
#include "paddle/fluid/platform/float16.h"
|
|
|
|
|

// Set cub's NumericTraits so that its radix sort can handle float16.
namespace cub {
template <>
struct NumericTraits<paddle::platform::float16>
    : BaseTraits<FLOATING_POINT, true, false, uint16_t,
                 paddle::platform::float16> {};
}  // namespace cub
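
// A note on the traits above: in the CUB version vendored here, BaseTraits
// takes the key category (FLOATING_POINT), whether the type is primitive,
// whether it is the null type, and an unsigned integer type of matching
// width (uint16_t) that the radix sort uses to reinterpret and reorder key
// bits. CUB only defines these traits for built-in types, so
// paddle::platform::float16 has to register its own to be sortable.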

namespace paddle {
namespace operators {

@@ -303,6 +312,157 @@ inline static int GetDesiredBlockDim(int dim) {
  }
}

// Iterator that maps a row index to that row's start offset.
struct SegmentOffsetIter {
  EIGEN_DEVICE_FUNC
  explicit SegmentOffsetIter(int num_cols) : num_cols_(num_cols) {}

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(int idx) const {
    return idx * num_cols_;
  }

  int num_cols_;
};
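
// Composed with a cub::CountingInputIterator (see SortTopk below), this
// yields the row start offsets 0, num_cols, 2 * num_cols, ... on the fly,
// without storing an offsets array in device memory.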

// Iterator that maps a flat element index to its column index.
struct ColumnIndexIter {
  explicit ColumnIndexIter(int num_cols) : num_cols_(num_cols) {}

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(
      const Eigen::array<int, 1>& ix) const {
    return ix[0] % num_cols_;
  }

  int num_cols_;
};

__global__ void InitIndex(int64_t* indices, int num_rows, int num_cols) {
  int col_id = threadIdx.x;
  int row_id = blockIdx.x;

  for (int j = row_id; j < num_rows; j += gridDim.x) {
    for (int i = col_id; i < num_cols; i += blockDim.x) {
      indices[j * num_cols + i] = i;
    }
  }
}
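
// A note on the kernel above: it uses grid-stride loops in both dimensions.
// Each block walks rows with stride gridDim.x and each thread walks columns
// with stride blockDim.x, so any launch configuration covers the whole
// num_rows x num_cols index matrix.
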
template <typename T>
bool SortTopk(const platform::CUDADeviceContext& ctx,
              const framework::Tensor* input_tensor, const size_t num_cols,
              const size_t num_rows, size_t k, framework::Tensor* out_tensor,
              framework::Tensor* indices_tensor) {
  auto cu_stream = ctx.stream();

  Tensor input_indices;
  const std::vector<int64_t> dims = {static_cast<int64_t>(num_rows),
                                     static_cast<int64_t>(num_cols)};
  auto dim = framework::make_ddim(dims);
  input_indices.Resize(dim);
  input_indices.mutable_data<int64_t>(ctx.GetPlace());
  // -1 wraps around to SIZE_MAX; overwritten by the size query below.
  size_t temp_storage_bytes = -1;

  auto ComputeBlockSize = [](int col) {
    if (col > 512)
      return 1024;
    else if (col > 256 && col <= 512)
      return 512;
    else if (col > 128 && col <= 256)
      return 256;
    else if (col > 64 && col <= 128)
      return 128;
    else
      return 64;
  };

  int block_size = ComputeBlockSize(num_cols);
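  // ComputeBlockSize picks the smallest power of two in [64, 1024] that
  // covers num_cols (clamping to 1024 for wider rows), so a single block
  // can index one row whenever possible.
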
  int maxGridDimX = ctx.GetCUDAMaxGridDimSize().x;
  // In practice num_rows is smaller than the max grid size; clamp anyway.
  int grid_size = num_rows < maxGridDimX ? num_rows : maxGridDimX;
  // Initialize the index array: each row gets 0, 1, ..., num_cols - 1.
  InitIndex<<<grid_size, block_size, 0, cu_stream>>>(
      input_indices.data<int64_t>(), num_rows, num_cols);

  // Counting iterator that enumerates segment (row) indices.
  cub::CountingInputIterator<int> counting_iter(0);
  // segment_offsets_t maps segment index i to row offset i * num_cols.
  cub::TransformInputIterator<int, SegmentOffsetIter,
                              cub::CountingInputIterator<int>>
      segment_offsets_t(counting_iter, SegmentOffsetIter(num_cols));
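  // Segment i then spans [segment_offsets_t[i], segment_offsets_t[i + 1]),
  // i.e. [i * num_cols, (i + 1) * num_cols), so the segmented sort below
  // sorts every row independently.
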
  T* sorted_values_ptr;
  int64_t* sorted_indices_ptr;

  Tensor temp_values;
  Tensor temp_indices;

  const T* input = input_tensor->data<T>();
  T* values = out_tensor->data<T>();
  int64_t* indices = indices_tensor->mutable_data<int64_t>(ctx.GetPlace());

  if (k == num_cols) {
    // Doing a full sort: write directly into the output tensors.
    sorted_values_ptr = values;
    sorted_indices_ptr = indices;
  } else {
    // Sort into temporaries first, then slice out the top k columns.
    temp_values.Resize(dim);
    temp_indices.Resize(dim);
    sorted_values_ptr = temp_values.mutable_data<T>(ctx.GetPlace());
    sorted_indices_ptr = temp_indices.mutable_data<int64_t>(ctx.GetPlace());
  }

  // Query the required amount of temp storage. A fixed buffer could be
  // allocated once to save this extra call.
  auto err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
      nullptr, temp_storage_bytes, input, sorted_values_ptr,
      input_indices.data<int64_t>(), sorted_indices_ptr, num_cols * num_rows,
      num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
      cu_stream);
  if (err != cudaSuccess) {
    LOG(ERROR)
        << "TopKOP failed: could not launch "
           "cub::DeviceSegmentedRadixSort::SortPairsDescending to calculate "
           "temp_storage_bytes, status: "
        << cudaGetErrorString(err);
    return false;
  }
  Tensor temp_storage;
  temp_storage.mutable_data<uint8_t>(ctx.GetPlace(), temp_storage_bytes);
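
  // CUB convention: the first call above, with a null temp-storage pointer,
  // only computes temp_storage_bytes; this second call, with the buffer
  // allocated, performs the actual segmented sort.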
  err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
      temp_storage.data<uint8_t>(), temp_storage_bytes, input,
      sorted_values_ptr, input_indices.data<int64_t>(), sorted_indices_ptr,
      num_cols * num_rows, num_rows, segment_offsets_t, segment_offsets_t + 1,
      0, sizeof(T) * 8, cu_stream);
  if (err != cudaSuccess) {
    LOG(ERROR)
        << "TopKOP failed: could not launch "
           "cub::DeviceSegmentedRadixSort::SortPairsDescending to sort input, "
           "temp_storage_bytes: "
        << temp_storage_bytes << ", status: " << cudaGetErrorString(err);
    return false;
  }
auto& dev = *ctx.eigen_device();
|
|
|
|
|
if (k < num_cols) {
|
|
|
|
|
// copy sliced data to output.
|
|
|
|
|
const Eigen::DSizes<Eigen::DenseIndex, 2> slice_indices{0, 0};
|
|
|
|
|
const Eigen::DSizes<Eigen::DenseIndex, 2> slice_sizes{num_rows, k};
|
|
|
|
|
auto e_indices = EigenMatrix<int64_t>::From(*indices_tensor, dim);
|
|
|
|
|
auto e_tmp_indices = EigenMatrix<int64_t>::From(temp_indices);
|
|
|
|
|
|
|
|
|
|
std::vector<int> odims = {static_cast<int>(num_rows), static_cast<int>(k)};
|
|
|
|
|
auto dim = framework::make_ddim(odims);
|
|
|
|
|
auto e_values = EigenMatrix<T>::From(*out_tensor, dim);
|
|
|
|
|
auto e_tmp_values = EigenMatrix<T>::From(temp_values);
|
|
|
|
|
|
|
|
|
|
e_indices.device(dev) = e_tmp_indices.slice(slice_indices, slice_sizes);
|
|
|
|
|
e_values.device(dev) = e_tmp_values.slice(slice_indices, slice_sizes);
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|

#define FIXED_BLOCK_DIM_BASE(dim, ...) \
  case (dim): {                        \
    constexpr auto kBlockDim = (dim);  \
@@ -340,13 +500,24 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
    const T* input_data = input->data<T>();
    T* output_data = output->mutable_data<T>(ctx.GetPlace());
    // FIXME(typhoonzero): data is always converted to type T?

    framework::DDim inputdims = input->dims();
    const size_t input_height = framework::product(
        framework::slice_ddim(inputdims, 0, inputdims.size() - 1));
    const size_t input_width = inputdims[inputdims.size() - 1];
    const auto& dev_ctx = ctx.cuda_device_context();
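
    // Heuristic: try the cub-based path for narrow inputs, large k, or a
    // full sort, where the segmented radix sort is expected to win; fall
    // back to the hand-written kernel below otherwise.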
    if ((input_width <= 1024 || k >= 128 || k == input_width)) {
      if (SortTopk<T>(dev_ctx, input, input_width, input_height, k, output,
                      indices)) {
        // Succeeded, return.
        return;
      } else {
        LOG(INFO) << "TopKOP: errors occurred when using cub sorting, "
                     "falling back to the default topk kernel.";
      }
    }
    int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace());
    if (k > input_width) k = input_width;

    // NOTE: pass lds and dim the same as the input width.
@@ -354,7 +525,6 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
    // TODO(typhoonzero): refine this kernel.
    const int kMaxHeight = 2048;
    int gridx = input_height < kMaxHeight ? input_height : kMaxHeight;
    switch (GetDesiredBlockDim(input_width)) {
      FIXED_BLOCK_DIM(
          KeMatrixTopK<T, 5,