|
|
|
|
@ -115,7 +115,18 @@ inline void TransposeQKV(const int batch, const int seq_len,
|
|
|
|
|
const half *input, half *output, cudaStream_t stream) {
|
|
|
|
|
int scratch_size = batch * head_num * seq_len * seq_len;
|
|
|
|
|
const dim3 grid(seq_len, batch, 3);
|
|
|
|
|
if (head_size % 2 == 0 && scratch_size % 2 == 0) {
|
|
|
|
|
if (head_size % 8 == 0 && scratch_size % 8 == 0) {
|
|
|
|
|
int h = head_size / 8;
|
|
|
|
|
const int4 *input4 = reinterpret_cast<const int4 *>(input);
|
|
|
|
|
int4 *output4 = reinterpret_cast<int4 *>(output);
|
|
|
|
|
dim3 block(h, head_num, 1);
|
|
|
|
|
// limit h * head_num to max block size(1024).
|
|
|
|
|
PADDLE_ENFORCE_LE(h * head_num, 1024,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"head_num (%d) * head_size (%d) should <= %d",
|
|
|
|
|
head_num, head_size, 1024 * 8));
|
|
|
|
|
TransposeQkvKernel<int4><<<grid, block, 0, stream>>>(h, input4, output4);
|
|
|
|
|
} else if (head_size % 2 == 0 && scratch_size % 2 == 0) {
|
|
|
|
|
const int h = head_size / 2;
|
|
|
|
|
const half2 *input2 = reinterpret_cast<const half2 *>(input);
|
|
|
|
|
half2 *output2 = reinterpret_cast<half2 *>(output);
|
|
|
|
|
|