|
|
|
@ -143,30 +143,42 @@ __global__ void SoftmaxKernelWithEltadd(T *qk_buf_, const T *bias_qk_,
|
|
|
|
|
int qk_offset = blockIdx.x * seq_len;
|
|
|
|
|
assert(blockDim.x % 32 == 0);
|
|
|
|
|
|
|
|
|
|
__shared__ float s_sum, s_max;
|
|
|
|
|
|
|
|
|
|
float qk = threadIdx.x < seq_len
|
|
|
|
|
? static_cast<float>((qk_buf_[threadIdx.x + qk_offset] +
|
|
|
|
|
bias_qk_[threadIdx.x + qk_offset]))
|
|
|
|
|
: 0.0f;
|
|
|
|
|
float tmp = threadIdx.x < seq_len ? static_cast<float>(qk) : -1e20f;
|
|
|
|
|
|
|
|
|
|
float tmp = threadIdx.x < seq_len
|
|
|
|
|
? static_cast<float>(qk_buf_[threadIdx.x + qk_offset] +
|
|
|
|
|
bias_qk_[threadIdx.x + qk_offset])
|
|
|
|
|
: -1e20f;
|
|
|
|
|
float max_val = blockReduceMax<float>(tmp, mask);
|
|
|
|
|
|
|
|
|
|
if (threadIdx.x == 0) s_max = max_val;
|
|
|
|
|
__syncthreads();
|
|
|
|
|
|
|
|
|
|
float qk_tmp =
|
|
|
|
|
threadIdx.x < seq_len ? __expf(static_cast<float>(tmp - s_max)) : 0.0f;
|
|
|
|
|
float qk_tmp = threadIdx.x < seq_len ? __expf(tmp - max_val) : 0.0f;
|
|
|
|
|
float sum_val = blockReduceSum<float>(qk_tmp, mask);
|
|
|
|
|
|
|
|
|
|
if (threadIdx.x == 0) {
|
|
|
|
|
s_sum = sum_val + 1e-6f;
|
|
|
|
|
}
|
|
|
|
|
__syncthreads();
|
|
|
|
|
|
|
|
|
|
if (threadIdx.x < seq_len)
|
|
|
|
|
qk_buf_[threadIdx.x + qk_offset] = (T)(qk_tmp / s_sum);
|
|
|
|
|
qk_buf_[threadIdx.x + qk_offset] = (T)(qk_tmp / sum_val);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Softmax with elementwise bias add, two scalar lanes packed per thread.
// T is a two-lane vector type (e.g. float2 / __half2, judging by the
// ToFloat2/FloatsToPair conversions). Each block normalizes one row of
// `seq_len` packed pairs starting at blockIdx.x * seq_len, in place.
// batch_size and head_num are unused in this body (presumably encoded in
// the grid configuration by the caller — confirm at the launch site).
// Requires blockDim.x to be a multiple of the warp size.
template <typename T>
__global__ void SoftmaxKernelWithEltadd2(T *qk_buf_, const T *bias_qk_,
                                         const int batch_size,
                                         const int head_num, const int seq_len,
                                         const unsigned mask) {
  const int tid = threadIdx.x;
  const int row_base = blockIdx.x * seq_len;
  assert(blockDim.x % 32 == 0);

  // Load (qk + bias) as a float pair. Lanes past the end of the row carry a
  // value low enough that it can never win the max reduction below.
  float2 val = make_float2(-1e20f, -1e20f);
  if (tid < seq_len) {
    val = ToFloat2<T>(qk_buf_[tid + row_base] + bias_qk_[tid + row_base]);
  }

  // Row maximum for numerical stability. Kept outside any divergent guard so
  // every thread reaches the block-wide reduction.
  const float row_max = blockReduceMax<float>(max(val.x, val.y), mask);

  // Exponentiate relative to the row max; out-of-row lanes contribute zero
  // to the sum.
  float2 exp_val = make_float2(0.f, 0.f);
  if (tid < seq_len) {
    exp_val = make_float2(__expf(val.x - row_max), __expf(val.y - row_max));
  }

  // Row sum plus a small epsilon to guard the division; again all threads
  // participate in the reduction.
  const float denom =
      blockReduceSum<float>(exp_val.x + exp_val.y, mask) + 1e-6f;

  if (tid < seq_len) {
    qk_buf_[tid + row_base] =
        FloatsToPair<T>(exp_val.x / denom, exp_val.y / denom);
  }
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
@ -198,21 +210,28 @@ inline void MatMulWithHeadQK(const platform::CUDADeviceContext &context,
|
|
|
|
|
"seq_len should <= 1024, "
|
|
|
|
|
"but received seq_len is:%d",
|
|
|
|
|
seq_len));
|
|
|
|
|
if (seq_len <= 32)
|
|
|
|
|
block = 32;
|
|
|
|
|
else if (seq_len > 32 && seq_len <= 64)
|
|
|
|
|
block = 64;
|
|
|
|
|
else if (seq_len > 64 && seq_len <= 128)
|
|
|
|
|
block = 128;
|
|
|
|
|
else if (seq_len > 128 && seq_len <= 256)
|
|
|
|
|
block = 256;
|
|
|
|
|
else if (seq_len > 256 && seq_len <= 512)
|
|
|
|
|
block = 512;
|
|
|
|
|
else
|
|
|
|
|
block = 1024;
|
|
|
|
|
|
|
|
|
|
SoftmaxKernelWithEltadd<T><<<grid, block, 0, stream>>>(
|
|
|
|
|
qk_buf_, bias_qk, batch_size, head_num, seq_len, FINAL_MASK);
|
|
|
|
|
if (seq_len % 2 == 0) {
|
|
|
|
|
block = (seq_len <= 64) ? 32 : ((seq_len + 63) / 64) * 32;
|
|
|
|
|
#ifdef SUPPORTS_CUDA_FP16
|
|
|
|
|
if (std::is_same<T, float>::value) {
|
|
|
|
|
#endif
|
|
|
|
|
SoftmaxKernelWithEltadd2<float2><<<grid, block, 0, stream>>>(
|
|
|
|
|
reinterpret_cast<float2 *>(qk_buf_),
|
|
|
|
|
reinterpret_cast<const float2 *>(bias_qk), batch_size, head_num,
|
|
|
|
|
seq_len / 2, FINAL_MASK);
|
|
|
|
|
#ifdef SUPPORTS_CUDA_FP16
|
|
|
|
|
} else {
|
|
|
|
|
SoftmaxKernelWithEltadd2<__half2><<<grid, block, 0, stream>>>(
|
|
|
|
|
reinterpret_cast<__half2 *>(qk_buf_),
|
|
|
|
|
reinterpret_cast<const __half2 *>(bias_qk), batch_size, head_num,
|
|
|
|
|
seq_len / 2, FINAL_MASK);
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
} else {
|
|
|
|
|
block = (seq_len <= 32) ? 32 : ((seq_len + 31) / 32) * 32;
|
|
|
|
|
SoftmaxKernelWithEltadd<T><<<grid, block, 0, stream>>>(
|
|
|
|
|
qk_buf_, bias_qk, batch_size, head_num, seq_len, FINAL_MASK);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|