@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/top_k_op.h"
#include "paddle/fluid/platform/assert.h"
#include "paddle/fluid/platform/assert.h"
namespace paddle {
namespace paddle {
@ -133,71 +134,71 @@ __device__ __forceinline__ void GetTopK(Pair<T> topk[], const T* val, int* col,
}
}
// Refills the candidate buffer `topk` (MaxLength entries) from `src`.
//
// Does nothing unless `*beam > 0`. Otherwise:
//   * On the first call (`*firstStep` set) the flag is cleared and `topk` is
//     filled through GetTopK with `length = min(*beam, beam_size)`.
//   * On later calls the first `*beam` entries of `topk` are dropped by
//     shifting the remainder to the front, the vacated tail is reset to the
//     (-INFINITY, -1) sentinel, and, unless `*is_empty` is already set, the
//     tail is refilled through the GetTopK overload that takes `*max`.
// Afterwards `*max` is set to the last entry of `topk`, `*is_empty` is
// raised when that entry matches the emptiness test below, and `*beam` is
// reset to 0.
//
// Parameters:
//   topk       candidate buffer, indexed 0 .. MaxLength - 1.
//   beam       in/out: number of leading entries to drop; zeroed on exit.
//   beam_size  upper bound on the number of entries requested from GetTopK.
//   src        input values forwarded to GetTopK.
//   firstStep  in/out: true until the first refill has happened.
//   is_empty   in/out: set once the last candidate matches the sentinel test.
//   max        in/out: last entry of `topk` after the previous refill.
//   dim, tid   forwarded unchanged to GetTopK.
template <typename T, int MaxLength, int BlockSize>
__device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[], int* beam,
                                              int beam_size, const T* src,
                                              bool* firstStep, bool* is_empty,
                                              Pair<T>* max, int dim,
                                              const int tid) {
  if (*beam > 0) {
    int length = (*beam) < beam_size ? *beam : beam_size;
    if (*firstStep) {
      *firstStep = false;
      GetTopK<T, BlockSize>(topk, src, tid, dim, length);
    } else {
      // Drop the first *beam entries and pad the tail with sentinels.
      for (int k = 0; k < MaxLength; k++) {
        if (k < MaxLength - (*beam)) {
          topk[k] = topk[k + *beam];
        } else {
          topk[k].set(-INFINITY, -1);
        }
      }
      if (!(*is_empty)) {
        // Refill only the vacated tail, using *max from the previous refill.
        GetTopK<T, BlockSize>(topk + MaxLength - *beam, src, tid, dim, *max,
                              length);
      }
    }

    *max = topk[MaxLength - 1];
    // NOTE(review): the sentinel is written as set(-INFINITY, -1) but the
    // test reads `.v`; confirm which Pair field receives the -1.
    if ((*max).v == -1) *is_empty = true;
    *beam = 0;
  }
}
// Refills the candidate buffer `topk` (MaxLength entries) from the
// `val` / `col` pair of arrays.
//
// Does nothing unless `*beam > 0`. Otherwise:
//   * On the first call (`*firstStep` set) the flag is cleared and `topk` is
//     filled through GetTopK with `length = min(*beam, beam_size)`.
//   * On later calls the first `*beam` entries of `topk` are dropped by
//     shifting the remainder to the front, the vacated tail is reset to the
//     (-INFINITY, -1) sentinel, and, unless `*is_empty` is already set, the
//     tail is refilled through the GetTopK overload that takes `*max`.
// Afterwards `*max` is set to the last entry of `topk`, `*is_empty` is
// raised when that entry matches the emptiness test below, and `*beam` is
// reset to 0.
//
// Parameters:
//   topk       candidate buffer, indexed 0 .. MaxLength - 1.
//   beam       in/out: number of leading entries to drop; zeroed on exit.
//   beam_size  upper bound on the number of entries requested from GetTopK.
//   val, col   input arrays forwarded to GetTopK.
//   firstStep  in/out: true until the first refill has happened.
//   is_empty   in/out: set once the last candidate matches the sentinel test.
//   max        in/out: last entry of `topk` after the previous refill.
//   dim, tid   forwarded unchanged to GetTopK.
template <typename T, int MaxLength, int BlockSize>
__device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[], int* beam,
                                              int beam_size, const T* val,
                                              int* col, bool* firstStep,
                                              bool* is_empty, Pair<T>* max,
                                              int dim, const int tid) {
  if (*beam > 0) {
    int length = (*beam) < beam_size ? *beam : beam_size;
    if (*firstStep) {
      *firstStep = false;
      GetTopK<T, BlockSize>(topk, val, col, tid, dim, length);
    } else {
      // Drop the first *beam entries and pad the tail with sentinels.
      for (int k = 0; k < MaxLength; k++) {
        if (k < MaxLength - (*beam)) {
          topk[k] = topk[k + *beam];
        } else {
          topk[k].set(-INFINITY, -1);
        }
      }
      if (!(*is_empty)) {
        // `max` is a pointer here; pass the pointee, as the `src` overload
        // of ThreadGetTopK does, so the call still hands GetTopK a Pair<T>.
        GetTopK<T, BlockSize>(topk + MaxLength - *beam, val, col, tid, dim,
                              *max, length);
      }
    }

    *max = topk[MaxLength - 1];
    // NOTE(review): the sentinel is written as set(-INFINITY, -1) but the
    // test reads `.v`; confirm which Pair field receives the -1.
    if ((*max).v == -1) *is_empty = true;
    *beam = 0;
  }
}
template <typename T, int MaxLength, int BlockSize>
template <typename T, int MaxLength, int BlockSize>
__device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
__device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
Pair<T> topk[], T** topVal,
Pair<T> topk[], T** topVal,
int64_t** topIds, int& beam, int& k,
int64_t** topIds, int* beam, int* k,
const int tid, const int warp) {
const int tid, const int warp) {
while (true) {
while (true) {
__syncthreads();
__syncthreads();
@ -225,17 +226,17 @@ __device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
(*topVal)++;
(*topVal)++;
(*topIds)++;
(*topIds)++;
}
}
if (tid == maxid[0]) beam++;
if (tid == maxid[0]) (* beam) ++;
if (--k == 0) break;
if (--(* k) == 0) break;
__syncthreads();
__syncthreads();
if (tid == maxid[0]) {
if (tid == maxid[0]) {
if (beam < MaxLength) {
if (* beam < MaxLength) {
sh_topk[tid] = topk[beam];
sh_topk[tid] = topk[* beam];
}
}
}
}
if (maxid[0] / 32 == warp) {
if (maxid[0] / 32 == warp) {
if (__shfl(beam, (maxid[0]) % 32, 32) == MaxLength) break;
if (__shfl(* beam, (maxid[0]) % 32, 32) == MaxLength) break;
}
}
}
}
}
}
@ -268,13 +269,13 @@ __global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices,
topk[k].set(-INFINITY, -1);
topk[k].set(-INFINITY, -1);
}
}
while (k) {
while (k) {
ThreadGetTopK<T, MaxLength, BlockSize>(topk, beam, k,
ThreadGetTopK<T, MaxLength, BlockSize>(topk, & beam, k,
src + blockIdx.x * lds, firststep,
src + blockIdx.x * lds, & firststep,
is_empty, max, dim, tid);
& is_empty, & max, dim, tid);
sh_topk[tid] = topk[0];
sh_topk[tid] = topk[0];
BlockReduce<T, MaxLength, BlockSize>(sh_topk, maxid, topk, &output,
BlockReduce<T, MaxLength, BlockSize>(sh_topk, maxid, topk, &output,
&indices, beam, k, tid, warp);
&indices, & beam, & k, tid, warp);
}
}
}
}
@ -308,9 +309,9 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
KeMatrixTopK<T, 5, 256><<<
KeMatrixTopK<T, 5, 256><<<
grid, threads, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
grid, threads, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
ctx.device_context())
.stream()>>>(output_data, output->dims()[1],
.stream()>>>(
indices_data, input_data ,
output_data, output->dims()[1], indices_data, input_data, input_width ,
input_width, input_width, int (k));
input_width, static_cast<int> (k));
}
}
};
};