@ -262,31 +262,31 @@ __global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices,
const T* src, int lds, int dim, int k,
int grid_dim, int num) {
__shared__ Pair<T> sh_topk[BlockSize];
__shared__ int maxid[BlockSize / 2];
const int tid = threadIdx.x;
const int warp = threadIdx.x / 32;
const int bid = blockIdx.x;
for (int i = bid; i < num; i += grid_dim) {
output += i * output_stride;
indices += i * k;
int top_num = k;
__shared__ int maxid[BlockSize / 2];
T* out = output + i * output_stride;
int64_t* inds = indices + i * k;
Pair<T> topk[MaxLength];
int beam = MaxLength;
Pair<T> max;
bool is_empty = false;
bool firststep = true;
for (int k = 0; k < MaxLength; k ++) {
topk[k ].set(-INFINITY, -1);
for (int j = 0; j < MaxLength; j ++) {
topk[j ].set(-INFINITY, -1);
}
while (k ) {
while (top_num ) {
ThreadGetTopK<T, MaxLength, BlockSize>(
topk, &beam, k, src + i * lds, &firststep, &is_empty, &max, dim, tid);
sh_topk[tid] = topk[0];
BlockReduce<T, MaxLength, BlockSize>(sh_topk, maxid, topk, &output ,
&indices, &beam, &k , tid, warp);
BlockReduce<T, MaxLength, BlockSize>(sh_topk, maxid, topk, &out, &inds ,
&beam, &top_num , tid, warp);
}
}
}
@ -327,13 +327,15 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
size_t k = static_cast<int>(ctx.Attr<int>("k"));
const T* input_data = input->data<T>();
T* output_data = output->mutable_data<T>(ctx.GetPlace());
// FIXME(typhoonzero): data is always converted to type T?
int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace());
size_t input_height = input->dims()[0];
size_t input_width = input->dims()[1];
framework::DDim inputdims = input->dims();
const size_t input_height = framework::product(
framework::slice_ddim(inputdims, 0, inputdims.size() - 1));
const size_t input_width = inputdims[inputdims.size() - 1];
if (k > input_width) k = input_width;
// NOTE: pass lds and dim same to input width.
@ -342,14 +344,12 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
const int kMaxHeight = 2048;
int gridx = input_height < kMaxHeight ? input_height : kMaxHeight;
auto& dev_ctx = ctx.cuda_device_context();
switch (GetDesiredBlockDim(input_width)) {
FIXED_BLOCK_DIM(
KeMatrixTopK<T, 5,
kBlockDim><<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(
output_data, output->dims()[1], indices_data, input_data,
input_width, input_width, static_cast<int>(k), gridx,
input_height));
output_data, k, indices_data, input_data, input_width,
input_width, static_cast<int>(k), gridx, input_height));
default:
PADDLE_THROW("Error");
}