|
|
|
@ -370,8 +370,10 @@ void TransQKVWithBias(const int batch, const int seq_len, const int head_size,
|
|
|
|
|
const int head_num, const float *input, const float *bias,
|
|
|
|
|
float *output, cudaStream_t stream) {
|
|
|
|
|
// BxSx3xNxH + 3xNxH -> 3xBxNxSxH
|
|
|
|
|
int scratch_size = batch * head_num * seq_len * seq_len;
|
|
|
|
|
const dim3 grid(seq_len, batch, 3);
|
|
|
|
|
if (head_size % 4 == 0) {
|
|
|
|
|
// scratch % 4 == 0 to ensure the alignment
|
|
|
|
|
if (head_size % 4 == 0 && scratch_size % 4 == 0) {
|
|
|
|
|
const int h = head_size / 4;
|
|
|
|
|
const float4 *input4 = reinterpret_cast<const float4 *>(input);
|
|
|
|
|
const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
|
|
|
|
@ -385,7 +387,7 @@ void TransQKVWithBias(const int batch, const int seq_len, const int head_size,
|
|
|
|
|
head_num, head_size, 1024 * 4));
|
|
|
|
|
transpose_qkv_kernel<float4><<<grid, block, 0, stream>>>(h, input4, bias4,
|
|
|
|
|
output4);
|
|
|
|
|
} else if (head_size % 2 == 0) {
|
|
|
|
|
} else if (head_size % 2 == 0 && scratch_size % 2 == 0) {
|
|
|
|
|
const int h = head_size / 2;
|
|
|
|
|
const float2 *input2 = reinterpret_cast<const float2 *>(input);
|
|
|
|
|
const float2 *bias2 = reinterpret_cast<const float2 *>(bias);
|
|
|
|
|