Fix top k op GPU code (#5221)

* Fix Type error

* Fix error

* Fix top_k_op GPU code data type
fix-typo
fengjiayi 7 years ago committed by GitHub
parent b9056bb022
commit d3cc7ac304
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -23,9 +23,9 @@ using Tensor = framework::Tensor;
template <typename T>
struct Pair {
__device__ __forceinline__ Pair() {}
__device__ __forceinline__ Pair(T value, int id) : v(value), id(id) {}
__device__ __forceinline__ Pair(T value, int64_t id) : v(value), id(id) {}
__device__ __forceinline__ void set(T value, int id) {
__device__ __forceinline__ void set(T value, int64_t id) {
v = value;
id = id;
}
@ -48,7 +48,7 @@ struct Pair {
}
T v;
int id;
int64_t id;
};
template <typename T>
@ -197,7 +197,7 @@ __device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[], int& beam,
template <typename T, int MaxLength, int BlockSize>
__device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
Pair<T> topk[], T** topVal,
int** topIds, int& beam, int& k,
int64_t** topIds, int& beam, int& k,
const int tid, const int warp) {
while (true) {
__syncthreads();
@ -249,7 +249,7 @@ __device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
* 4. go to the first setp, until get the topk value.
*/
template <typename T, int MaxLength, int BlockSize>
__global__ void KeMatrixTopK(T* output, int output_stride, int* indices,
__global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices,
const T* src, int lds, int dim, int k) {
__shared__ Pair<T> sh_topk[BlockSize];
__shared__ int maxid[BlockSize / 2];
@ -293,7 +293,7 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
T* output_data = output->mutable_data<T>(ctx.GetPlace());
// FIXME(typhoonzero): data is always converted to type T?
int* indices_data = indices->mutable_data<int>(ctx.GetPlace());
int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace());
size_t input_height = input->dims()[0];
size_t input_width = input->dims()[1];

Loading…
Cancel
Save