|
|
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
|
|
|
#include "paddle/fluid/operators/top_k_op.h"
|
|
|
|
|
#include "paddle/fluid/platform/assert.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
@ -133,71 +134,71 @@ __device__ __forceinline__ void GetTopK(Pair<T> topk[], const T* val, int* col,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T, int MaxLength, int BlockSize>
|
|
|
|
|
__device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[], int& beam,
|
|
|
|
|
__device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[], int* beam,
|
|
|
|
|
int beam_size, const T* src,
|
|
|
|
|
bool& firstStep, bool& is_empty,
|
|
|
|
|
Pair<T>& max, int dim,
|
|
|
|
|
bool* firstStep, bool* is_empty,
|
|
|
|
|
Pair<T>* max, int dim,
|
|
|
|
|
const int tid) {
|
|
|
|
|
if (beam > 0) {
|
|
|
|
|
int length = beam < beam_size ? beam : beam_size;
|
|
|
|
|
if (firstStep) {
|
|
|
|
|
firstStep = false;
|
|
|
|
|
if (*beam > 0) {
|
|
|
|
|
int length = (*beam) < beam_size ? *beam : beam_size;
|
|
|
|
|
if (*firstStep) {
|
|
|
|
|
*firstStep = false;
|
|
|
|
|
GetTopK<T, BlockSize>(topk, src, tid, dim, length);
|
|
|
|
|
} else {
|
|
|
|
|
for (int k = 0; k < MaxLength; k++) {
|
|
|
|
|
if (k < MaxLength - beam) {
|
|
|
|
|
topk[k] = topk[k + beam];
|
|
|
|
|
if (k < MaxLength - (*beam)) {
|
|
|
|
|
topk[k] = topk[k + *beam];
|
|
|
|
|
} else {
|
|
|
|
|
topk[k].set(-INFINITY, -1);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (!is_empty) {
|
|
|
|
|
GetTopK<T, BlockSize>(topk + MaxLength - beam, src, tid, dim, max,
|
|
|
|
|
if (!(*is_empty)) {
|
|
|
|
|
GetTopK<T, BlockSize>(topk + MaxLength - *beam, src, tid, dim, *max,
|
|
|
|
|
length);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
max = topk[MaxLength - 1];
|
|
|
|
|
if (max.v == -1) is_empty = true;
|
|
|
|
|
beam = 0;
|
|
|
|
|
*max = topk[MaxLength - 1];
|
|
|
|
|
if ((*max).v == -1) *is_empty = true;
|
|
|
|
|
*beam = 0;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T, int MaxLength, int BlockSize>
|
|
|
|
|
__device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[], int& beam,
|
|
|
|
|
__device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[], int* beam,
|
|
|
|
|
int beam_size, const T* val,
|
|
|
|
|
int* col, bool& firstStep,
|
|
|
|
|
bool& is_empty, Pair<T>& max,
|
|
|
|
|
int* col, bool* firstStep,
|
|
|
|
|
bool* is_empty, Pair<T>* max,
|
|
|
|
|
int dim, const int tid) {
|
|
|
|
|
if (beam > 0) {
|
|
|
|
|
int length = beam < beam_size ? beam : beam_size;
|
|
|
|
|
if (firstStep) {
|
|
|
|
|
firstStep = false;
|
|
|
|
|
if (*beam > 0) {
|
|
|
|
|
int length = (*beam) < beam_size ? *beam : beam_size;
|
|
|
|
|
if (*firstStep) {
|
|
|
|
|
*firstStep = false;
|
|
|
|
|
GetTopK<T, BlockSize>(topk, val, col, tid, dim, length);
|
|
|
|
|
} else {
|
|
|
|
|
for (int k = 0; k < MaxLength; k++) {
|
|
|
|
|
if (k < MaxLength - beam) {
|
|
|
|
|
topk[k] = topk[k + beam];
|
|
|
|
|
if (k < MaxLength - *beam) {
|
|
|
|
|
topk[k] = topk[k + *beam];
|
|
|
|
|
} else {
|
|
|
|
|
topk[k].set(-INFINITY, -1);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (!is_empty) {
|
|
|
|
|
GetTopK<T, BlockSize>(topk + MaxLength - beam, val, col, tid, dim, max,
|
|
|
|
|
if (!(*is_empty)) {
|
|
|
|
|
GetTopK<T, BlockSize>(topk + MaxLength - *beam, val, col, tid, dim, max,
|
|
|
|
|
length);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
max = topk[MaxLength - 1];
|
|
|
|
|
if (max.v == -1) is_empty = true;
|
|
|
|
|
beam = 0;
|
|
|
|
|
*max = topk[MaxLength - 1];
|
|
|
|
|
if ((*max).v == -1) *is_empty = true;
|
|
|
|
|
*beam = 0;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T, int MaxLength, int BlockSize>
|
|
|
|
|
__device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
|
|
|
|
|
Pair<T> topk[], T** topVal,
|
|
|
|
|
int64_t** topIds, int& beam, int& k,
|
|
|
|
|
int64_t** topIds, int* beam, int* k,
|
|
|
|
|
const int tid, const int warp) {
|
|
|
|
|
while (true) {
|
|
|
|
|
__syncthreads();
|
|
|
|
@ -225,17 +226,17 @@ __device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
|
|
|
|
|
(*topVal)++;
|
|
|
|
|
(*topIds)++;
|
|
|
|
|
}
|
|
|
|
|
if (tid == maxid[0]) beam++;
|
|
|
|
|
if (--k == 0) break;
|
|
|
|
|
if (tid == maxid[0]) (*beam)++;
|
|
|
|
|
if (--(*k) == 0) break;
|
|
|
|
|
__syncthreads();
|
|
|
|
|
|
|
|
|
|
if (tid == maxid[0]) {
|
|
|
|
|
if (beam < MaxLength) {
|
|
|
|
|
sh_topk[tid] = topk[beam];
|
|
|
|
|
if (*beam < MaxLength) {
|
|
|
|
|
sh_topk[tid] = topk[*beam];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (maxid[0] / 32 == warp) {
|
|
|
|
|
if (__shfl(beam, (maxid[0]) % 32, 32) == MaxLength) break;
|
|
|
|
|
if (__shfl(*beam, (maxid[0]) % 32, 32) == MaxLength) break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -268,13 +269,13 @@ __global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices,
|
|
|
|
|
topk[k].set(-INFINITY, -1);
|
|
|
|
|
}
|
|
|
|
|
while (k) {
|
|
|
|
|
ThreadGetTopK<T, MaxLength, BlockSize>(topk, beam, k,
|
|
|
|
|
src + blockIdx.x * lds, firststep,
|
|
|
|
|
is_empty, max, dim, tid);
|
|
|
|
|
ThreadGetTopK<T, MaxLength, BlockSize>(topk, &beam, k,
|
|
|
|
|
src + blockIdx.x * lds, &firststep,
|
|
|
|
|
&is_empty, &max, dim, tid);
|
|
|
|
|
|
|
|
|
|
sh_topk[tid] = topk[0];
|
|
|
|
|
BlockReduce<T, MaxLength, BlockSize>(sh_topk, maxid, topk, &output,
|
|
|
|
|
&indices, beam, k, tid, warp);
|
|
|
|
|
&indices, &beam, &k, tid, warp);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -308,9 +309,9 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
KeMatrixTopK<T, 5, 256><<<
|
|
|
|
|
grid, threads, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
|
|
|
|
|
ctx.device_context())
|
|
|
|
|
.stream()>>>(output_data, output->dims()[1],
|
|
|
|
|
indices_data, input_data,
|
|
|
|
|
input_width, input_width, int(k));
|
|
|
|
|
.stream()>>>(
|
|
|
|
|
output_data, output->dims()[1], indices_data, input_data, input_width,
|
|
|
|
|
input_width, static_cast<int>(k));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|