@@ -256,36 +256,65 @@ __device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
  * 3. go to the second step, until one thread's topk value is null;
  * 4. go to the first step, until we get the topk value.
  */
 template <typename T, int MaxLength, int BlockSize>
 __global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices,
-                             const T* src, int lds, int dim, int k) {
+                             const T* src, int lds, int dim, int k,
+                             int grid_dim, int num) {
   __shared__ Pair<T> sh_topk[BlockSize];
-  __shared__ int maxid[BlockSize / 2];
   const int tid = threadIdx.x;
   const int warp = threadIdx.x / 32;
 
-  output += blockIdx.x * output_stride;
-  indices += blockIdx.x * k;
-  Pair<T> topk[MaxLength];
-  int beam = MaxLength;
-  Pair<T> max;
-  bool is_empty = false;
-  bool firststep = true;
+  const int bid = blockIdx.x;
+  for (int i = bid; i < num; i += grid_dim) {
+    int top_num = k;
+    __shared__ int maxid[BlockSize / 2];
+    T* out = output + i * output_stride;
+    int64_t* inds = indices + i * k;
+    Pair<T> topk[MaxLength];
+    int beam = MaxLength;
+    Pair<T> max;
+    bool is_empty = false;
+    bool firststep = true;
 
-  for (int k = 0; k < MaxLength; k++) {
-    topk[k].set(-INFINITY, -1);
-  }
-  while (k) {
-    ThreadGetTopK<T, MaxLength, BlockSize>(topk, &beam, k,
-                                           src + blockIdx.x * lds, &firststep,
-                                           &is_empty, &max, dim, tid);
+    for (int j = 0; j < MaxLength; j++) {
+      topk[j].set(-INFINITY, -1);
+    }
+    while (top_num) {
+      ThreadGetTopK<T, MaxLength, BlockSize>(
+          topk, &beam, k, src + i * lds, &firststep, &is_empty, &max, dim, tid);
 
-    sh_topk[tid] = topk[0];
-    BlockReduce<T, MaxLength, BlockSize>(sh_topk, maxid, topk, &output,
-                                         &indices, &beam, &k, tid, warp);
-  }
+      sh_topk[tid] = topk[0];
+      BlockReduce<T, MaxLength, BlockSize>(sh_topk, maxid, topk, &out, &inds,
+                                           &beam, &top_num, tid, warp);
+    }
+  }
 }
 
+inline static int GetDesiredBlockDim(int dim) {
+  if (dim > 128) {
+    return 256;
+  } else if (dim > 64) {
+    return 128;
+  } else if (dim > 32) {
+    return 64;
+  } else {
+    return 32;
+  }
+}
+
+#define FIXED_BLOCK_DIM_BASE(dim, ...) \
+  case (dim): {                        \
+    constexpr auto kBlockDim = (dim);  \
+    __VA_ARGS__;                       \
+  } break
+
+#define FIXED_BLOCK_DIM(...)                \
+  FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__); \
+  FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__); \
+  FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__);  \
+  FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)
+
 template <typename T>
 class TopkOpCUDAKernel : public framework::OpKernel<T> {
  public:
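The heart of this hunk is the switch from one block per row (`output += blockIdx.x * output_stride`) to a grid-stride loop, so a grid capped at `grid_dim` blocks can cover any number of rows (`num`); per-row state (`topk`, `beam`, `max`, `firststep`, ...) accordingly moves inside the loop and is re-initialized each iteration. A minimal sketch of the pattern in isolation, with `RowKernel` and its body as illustrative stand-ins rather than code from this patch:

```cuda
// Grid-stride sketch (illustrative): `grid_dim` is the launched grid size,
// `num` the total number of rows; block `bid` processes rows
// bid, bid + grid_dim, bid + 2 * grid_dim, ... until it runs past `num`.
__global__ void RowKernel(float* out, const float* in, int row_len,
                          int grid_dim, int num) {
  const int bid = blockIdx.x;
  for (int i = bid; i < num; i += grid_dim) {
    const float* row = in + i * row_len;  // this block's current row
    // Per-row work, parallelized over threadIdx.x, goes here; any per-row
    // state must be reset each iteration, as the patched kernel does.
    if (threadIdx.x == 0) out[i] = row[0];  // placeholder per-row result
  }
}
```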
@@ -298,30 +327,38 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
     size_t k = static_cast<int>(ctx.Attr<int>("k"));
 
     const T* input_data = input->data<T>();
 
     T* output_data = output->mutable_data<T>(ctx.GetPlace());
     // FIXME(typhoonzero): data is always converted to type T?
     int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace());
 
-    size_t input_height = input->dims()[0];
-    size_t input_width = input->dims()[1];
+    framework::DDim inputdims = input->dims();
+    const size_t input_height = framework::product(
+        framework::slice_ddim(inputdims, 0, inputdims.size() - 1));
+    const size_t input_width = inputdims[inputdims.size() - 1];
+
     if (k > input_width) k = input_width;
 
     // NOTE: pass lds and dim the same as input width.
     // NOTE: the old matrix implementation's stride differs from eigen's.
     // TODO(typhoonzero): refine this kernel.
-    dim3 threads(256, 1);
-    dim3 grid(input_height, 1);
-
-    KeMatrixTopK<T, 5, 256><<<
-        grid, threads, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
-                              ctx.device_context())
-                              .stream()>>>(
-        output_data, output->dims()[1], indices_data, input_data, input_width,
-        input_width, static_cast<int>(k));
+    const int kMaxHeight = 2048;
+    int gridx = input_height < kMaxHeight ? input_height : kMaxHeight;
+    auto& dev_ctx = ctx.cuda_device_context();
+    switch (GetDesiredBlockDim(input_width)) {
+      FIXED_BLOCK_DIM(
+          KeMatrixTopK<T, 5,
+                       kBlockDim><<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(
+              output_data, k, indices_data, input_data, input_width,
+              input_width, static_cast<int>(k), gridx, input_height));
+      default:
+        PADDLE_THROW("Error");
+    }
   }
 };
 
+#undef FIXED_BLOCK_DIM_BASE
+#undef FIXED_BLOCK_DIM
+
 }  // namespace operators
 }  // namespace paddle
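The `GetDesiredBlockDim` / `FIXED_BLOCK_DIM` pair introduced above bridges a runtime row width to the compile-time `BlockSize` template parameter: the switch carries one case per supported power of two, and inside each case `kBlockDim` is `constexpr`, so it can instantiate the kernel template. The launch itself is then capped at `kMaxHeight = 2048` blocks, with the kernel's grid-stride loop covering the remaining rows. A self-contained, host-only sketch of the dispatch idiom, where `Launch` is a hypothetical stand-in for the templated kernel launch:

```cpp
#include <cstdio>

template <int BlockSize>
void Launch(int width) {  // stand-in for KeMatrixTopK<T, 5, BlockSize><<<...>>>
  std::printf("BlockSize=%d chosen for width=%d\n", BlockSize, width);
}

inline static int GetDesiredBlockDim(int dim) {
  return dim > 128 ? 256 : dim > 64 ? 128 : dim > 32 ? 64 : 32;
}

// Each case fixes a constexpr kBlockDim, making it usable as a template arg.
#define FIXED_BLOCK_DIM_BASE(dim, ...) \
  case (dim): {                        \
    constexpr int kBlockDim = (dim);   \
    __VA_ARGS__;                       \
  } break

#define FIXED_BLOCK_DIM(...)                \
  FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)

int main() {
  const int width = 100;  // e.g. input_width; 100 > 64, so the 128 case fires
  switch (GetDesiredBlockDim(width)) {
    FIXED_BLOCK_DIM(Launch<kBlockDim>(width));
    default:
      return 1;  // unreachable for the four supported sizes
  }
  return 0;
}
```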