@@ -28,10 +28,10 @@ namespace operators {
 #define WARP_SIZE 32
 
 template <typename T>
-__inline__ __device__ T warpReduceSum(T val) {
+__inline__ __device__ T warpReduceSum(T val, unsigned lane_mask) {
   for (int mask = HALF_WARP; mask > 0; mask >>= 1)
 #if __CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000
-    val += __shfl_xor_sync(FINAL_MASK, val, mask, warpSize);
+    val += __shfl_xor_sync(lane_mask, val, mask, warpSize);
 #else
     val += __shfl_xor(val, mask, warpSize);
 #endif
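Note: this hunk replaces the hardwired FINAL_MASK in __shfl_xor_sync with a caller-supplied lane_mask, so blocks narrower than a warp can name exactly the lanes that participate. For context, the butterfly (XOR) reduction being parameterized looks like this in isolation; a minimal sketch, assuming FINAL_MASK is 0xffffffffu and HALF_WARP is 16 as defined elsewhere in this file (the demo kernel and launch are hypothetical, not part of the patch):

#include <cstdio>

#define FINAL_MASK 0xffffffffu
#define HALF_WARP 16

// Each lane starts with its lane id; after log2(32) = 5 XOR exchanges,
// every lane of the full warp holds the total 0 + 1 + ... + 31 = 496.
__global__ void warp_sum_demo() {
  int val = threadIdx.x;
  for (int mask = HALF_WARP; mask > 0; mask >>= 1)
    val += __shfl_xor_sync(FINAL_MASK, val, mask, warpSize);
  if (threadIdx.x == 0) printf("warp sum = %d\n", val);
}

int main() {
  warp_sum_demo<<<1, 32>>>();
  cudaDeviceSynchronize();
  return 0;
}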
@@ -40,28 +40,30 @@ __inline__ __device__ T warpReduceSum(T val) {
 
 /* Calculate the sum of all elements in a block */
 template <typename T>
-__inline__ __device__ T blockReduceSum(T val) {
+__inline__ __device__ T blockReduceSum(T val, unsigned mask) {
   static __shared__ T shared[WARP_SIZE];
   int lane = threadIdx.x & 0x1f;
   int wid = threadIdx.x >> 5;
 
-  val = warpReduceSum<T>(val);
+  val = warpReduceSum<T>(val, mask);
 
   if (lane == 0) shared[wid] = val;
 
   __syncthreads();
 
-  val = (threadIdx.x < (blockDim.x >> 5)) ? shared[lane] : (T)(0.0f);
-  val = warpReduceSum<T>(val);
+  // align block_span to warpSize
+  int block_span = (blockDim.x + warpSize - 1) >> 5;
+  val = (threadIdx.x < block_span) ? shared[lane] : (T)(0.0f);
+  val = warpReduceSum<T>(val, mask);
 
   return val;
 }
 
 template <typename T>
-__inline__ __device__ T warpReduceMax(T val) {
+__inline__ __device__ T warpReduceMax(T val, unsigned lane_mask) {
   for (int mask = HALF_WARP; mask > 0; mask >>= 1)
 #if __CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000
-    val = max(val, __shfl_xor_sync(FINAL_MASK, val, mask, warpSize));
+    val = max(val, __shfl_xor_sync(lane_mask, val, mask, warpSize));
 #else
     val = max(val, __shfl_xor(val, mask, warpSize));
 #endif
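Two fixes land here. First, blockReduceSum now threads the shuffle mask through to warpReduceSum. Second, and independently, the guard for the second reduction stage previously truncated (blockDim.x >> 5), so a block whose size is not a multiple of 32 dropped its trailing partial warp's sum from shared[]; block_span rounds up instead. A quick host-side check of the arithmetic (hypothetical values, warpSize fixed at 32):

#include <cstdio>

int main() {
  const int warpSize = 32;
  const int blockDimX = 48;  // one full warp plus a 16-thread partial warp
  int old_span = blockDimX >> 5;                   // 1: partial warp ignored
  int new_span = (blockDimX + warpSize - 1) >> 5;  // 2: partial warp counted
  printf("old=%d new=%d\n", old_span, new_span);
  return 0;
}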
@@ -70,19 +72,21 @@ __inline__ __device__ T warpReduceMax(T val) {
 
 /* Calculate the maximum of all elements in a block */
 template <typename T>
-__inline__ __device__ T blockReduceMax(T val) {
+__inline__ __device__ T blockReduceMax(T val, unsigned mask) {
   static __shared__ T shared[WARP_SIZE];
   int lane = threadIdx.x & 0x1f;
   int wid = threadIdx.x >> 5;
 
-  val = warpReduceMax(val);
+  val = warpReduceMax(val, mask);
 
   if (lane == 0) shared[wid] = val;
 
   __syncthreads();
 
-  val = (threadIdx.x < (blockDim.x >> 5)) ? shared[lane] : -1e10f;
-  val = warpReduceMax(val);
+  // align block_span to warpSize
+  int block_span = (blockDim.x + warpSize - 1) >> 5;
+  val = (threadIdx.x < block_span) ? shared[lane] : -1e10f;
+  val = warpReduceMax(val, mask);
 
   return val;
 }
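This hunk is the max-reduction twin of the blockReduceSum change above: the same mask plumbing and the same round-up of block_span, with -1e10f rather than zero as the padding value so that lanes past block_span can never win the max.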
@@ -190,7 +194,8 @@ template <typename T>
 __global__ void softmax_kernel_with_eltadd(T *qk_buf_, const T *bias_qk_,
                                            const int batch_size,
                                            const int head_num,
-                                           const int seq_len) {
+                                           const int seq_len,
+                                           const unsigned mask) {
   int seq_id = blockIdx.x % seq_len;
   int qk_offset = blockIdx.x * seq_len;
   int bias_offset = blockIdx.x % (head_num * seq_len) * seq_len;
@@ -202,13 +207,15 @@ __global__ void softmax_kernel_with_eltadd(T *qk_buf_, const T *bias_qk_,
                                        bias_qk_[threadIdx.x + bias_offset]))
                  : 0.0f;
   float tmp = threadIdx.x < seq_len ? static_cast<float>(qk) : -1e20f;
-  float max_val = blockReduceMax<float>(tmp);
+
+  float max_val = blockReduceMax<float>(tmp, mask);
+
   if (threadIdx.x == 0) s_max = max_val;
   __syncthreads();
 
   float qk_tmp =
       threadIdx.x < seq_len ? __expf(static_cast<float>(tmp - s_max)) : 0.0f;
-  float sum_val = blockReduceSum<float>(qk_tmp);
+  float sum_val = blockReduceSum<float>(qk_tmp, mask);
 
   if (threadIdx.x == 0) {
     s_sum = sum_val + 1e-6f;
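Context for readers: each block owns one seq_len-long row of the QK product, and the two block reductions implement the max-shifted softmax exp(x_i - max) / (sum_j exp(x_j - max) + eps), which keeps __expf's argument non-positive and so avoids overflow; the 1e-6f guards against a zero denominator. A minimal single-warp sketch of the same row-softmax shape (hypothetical demo, fixed row length 32, full-warp mask, not the kernel in this patch):

#include <cstdio>

#define FULL_MASK 0xffffffffu

// Softmax over one 32-element row held by one warp.
__global__ void row_softmax_demo(const float *in, float *out) {
  float x = in[threadIdx.x];

  // 1) warp-wide max (XOR butterfly), so __expf never sees a large argument
  float m = x;
  for (int off = 16; off > 0; off >>= 1)
    m = fmaxf(m, __shfl_xor_sync(FULL_MASK, m, off, 32));

  // 2) warp-wide sum of the shifted exponentials
  float e = __expf(x - m);
  float s = e;
  for (int off = 16; off > 0; off >>= 1)
    s += __shfl_xor_sync(FULL_MASK, s, off, 32);

  out[threadIdx.x] = e / (s + 1e-6f);  // same epsilon guard as the patch
}

int main() {
  float h_in[32], h_out[32], *d_in, *d_out;
  for (int i = 0; i < 32; ++i) h_in[i] = static_cast<float>(i);
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(h_out));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  row_softmax_demo<<<1, 32>>>(d_in, d_out);
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
  printf("out[31] = %f\n", h_out[31]);
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}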
@@ -258,8 +265,9 @@ void MatMulWithHeadQK(const platform::CUDADeviceContext &context, int head_num,
   int grid = m;
   int block = k;
 
+  unsigned mask = block < 32 ? (((unsigned)1 << block) - 1) : FINAL_MASK;
   softmax_kernel_with_eltadd<T><<<grid, block, 0, stream>>>(
-      qk_buf_, bias_qk, batch_size, head_num, seq_len);
+      qk_buf_, bias_qk, batch_size, head_num, seq_len, mask);
 }
 
 template <typename T>
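The launch-site change is what makes the new mask parameter meaningful: __shfl_xor_sync requires the mask to name exactly the lanes that will arrive at the shuffle, so when the block is narrower than a warp (block = k < 32) only the low block bits are set, and otherwise the full-warp FINAL_MASK is used. A quick check of the expression, assuming FINAL_MASK is 0xffffffffu:

#include <cstdio>

#define FINAL_MASK 0xffffffffu

int main() {
  int blocks[] = {10, 31, 32, 128};  // hypothetical launch widths
  for (int block : blocks) {
    unsigned mask = block < 32 ? (((unsigned)1 << block) - 1) : FINAL_MASK;
    printf("block=%3d -> mask=0x%08x\n", block, mask);
  }
  // block=10 -> 0x000003ff (lanes 0..9 live); block>=32 -> 0xffffffff
  return 0;
}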