|
|
@ -23,9 +23,9 @@ using Tensor = framework::Tensor;
|
|
|
|
template <typename T>
|
|
|
|
template <typename T>
|
|
|
|
struct Pair {
|
|
|
|
struct Pair {
|
|
|
|
// Default constructor: empty body and no initializer list, so neither `v` nor
// `id` is explicitly initialized here.
__device__ __forceinline__ Pair() {}
|
|
|
|
__device__ __forceinline__ Pair() {}
|
|
|
|
// Constructs a pair holding `value` and `id`. In the initializer list,
// `id(id)` initializes the member `id` from the same-named parameter (the name
// outside the parentheses is looked up as a member, the one inside as the
// parameter), so the shadowing is harmless here.
__device__ __forceinline__ Pair(T value, int id) : v(value), id(id) {}
|
|
|
|
__device__ __forceinline__ Pair(T value, int64_t id) : v(value), id(id) {}
|
|
|
|
|
|
|
|
|
|
|
|
// Overwrites both fields of this pair with `value` and `id`.
//
// The parameter `id` has the same name as the member `id` and therefore hides
// it inside the function body. A plain `id = id;` assigns the parameter to
// itself and leaves the member untouched, so the member must be qualified
// with `this->`.
__device__ __forceinline__ void set(T value, int id) {
|
|
|
|
__device__ __forceinline__ void set(T value, int64_t id) {
|
|
|
|
v = value;
|
|
|
|
v = value;
|
|
|
|
this->id = id;  // qualify the member: bare `id` names the parameter
|
|
|
|
this->id = id;  // qualify the member: bare `id` names the parameter
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -48,7 +48,7 @@ struct Pair {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
T v;
|
|
|
|
T v;
|
|
|
|
int id;
|
|
|
|
int64_t id;
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
template <typename T>
|
|
|
@ -197,7 +197,7 @@ __device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[], int& beam,
|
|
|
|
template <typename T, int MaxLength, int BlockSize>
|
|
|
|
template <typename T, int MaxLength, int BlockSize>
|
|
|
|
__device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
|
|
|
|
__device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
|
|
|
|
Pair<T> topk[], T** topVal,
|
|
|
|
Pair<T> topk[], T** topVal,
|
|
|
|
int** topIds, int& beam, int& k,
|
|
|
|
int64_t** topIds, int& beam, int& k,
|
|
|
|
const int tid, const int warp) {
|
|
|
|
const int tid, const int warp) {
|
|
|
|
while (true) {
|
|
|
|
while (true) {
|
|
|
|
__syncthreads();
|
|
|
|
__syncthreads();
|
|
|
@ -249,7 +249,7 @@ __device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
|
|
|
|
* 4. go back to the first step, until all topk values are obtained.
|
|
|
|
* 4. go back to the first step, until all topk values are obtained.
|
|
|
|
*/
|
|
|
|
*/
|
|
|
|
template <typename T, int MaxLength, int BlockSize>
|
|
|
|
template <typename T, int MaxLength, int BlockSize>
|
|
|
|
__global__ void KeMatrixTopK(T* output, int output_stride, int* indices,
|
|
|
|
__global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices,
|
|
|
|
const T* src, int lds, int dim, int k) {
|
|
|
|
const T* src, int lds, int dim, int k) {
|
|
|
|
__shared__ Pair<T> sh_topk[BlockSize];
|
|
|
|
__shared__ Pair<T> sh_topk[BlockSize];
|
|
|
|
__shared__ int maxid[BlockSize / 2];
|
|
|
|
__shared__ int maxid[BlockSize / 2];
|
|
|
@ -293,7 +293,7 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
|
|
T* output_data = output->mutable_data<T>(ctx.GetPlace());
|
|
|
|
T* output_data = output->mutable_data<T>(ctx.GetPlace());
|
|
|
|
// FIXME(typhoonzero): data is always converted to type T?
|
|
|
|
// FIXME(typhoonzero): data is always converted to type T?
|
|
|
|
int* indices_data = indices->mutable_data<int>(ctx.GetPlace());
|
|
|
|
int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
|
|
size_t input_height = input->dims()[0];
|
|
|
|
size_t input_height = input->dims()[0];
|
|
|
|
size_t input_width = input->dims()[1];
|
|
|
|
size_t input_width = input->dims()[1];
|
|
|
|