Fix multihead op bug. (#20783)

The op should handle k=1024

test=develop

Signed-off-by: zhaoyuchen <zhaoyuchen01@baidu.com>
Author: zhaoyuchen2018 (committed by GitHub)
parent 28dd2a58df
commit 6e6eab07e8

@@ -134,7 +134,7 @@ MultiHeadMatMul Operator.
 This op is used for optimize multi head calculation in ernie model.
 Not suggest to use in other case except has same structure as ernie.
-Example of matrix multiplication with head_number of H
+Example of matrix multiplication with head_number of B
 - X: [B, M, K], Y: [B, K, N] => Out: [B, M, N]
 Both the input `Q` and `K` can carry the LoD (Level of Details) information,

@@ -331,7 +331,7 @@ void MultiHeadGPUCompute(const platform::CUDADeviceContext &dev_ctx,
   auto stream = dev_ctx.stream();
   int grid = m;
-  PADDLE_ENFORCE_LT(k, 1024,
+  PADDLE_ENFORCE_LE(k, 1024,
                     "Input head_number * size_per_head should <= 1024");
   int block = k <= 1024 ? k : 1024;
   add_QKV<T><<<grid, block, 0, stream>>>(Q, K, V, q_buf, k_buf, v_buf, bias_q,
