|
|
|
@ -88,30 +88,33 @@ class MultiHeadMatMulOp : public framework::OperatorWithKernel {
|
|
|
|
|
PADDLE_ENFORCE_GT(dim_bias_qk.size(), 3,
|
|
|
|
|
"Multihead input bias qk should be at least 4-D tensor.");
|
|
|
|
|
|
|
|
|
|
int b_size = dim_bias_q.size() - 1;
|
|
|
|
|
int size = dim_q.size() - 1;
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(dim_bias_q[b_size], dim_q[size],
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"bias_q's last dim size should equal to"
|
|
|
|
|
" q last dim size, but bias_q's size is:%d q is:%d",
|
|
|
|
|
dim_bias_q[b_size], dim_q[size]));
|
|
|
|
|
PADDLE_ENFORCE_EQ(dim_bias_k[b_size], dim_k[size],
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"bias_k's last dim size should equal to"
|
|
|
|
|
" k last dim size, but bias_k's size is:%d k is:%d",
|
|
|
|
|
dim_bias_k[b_size], dim_k[size]));
|
|
|
|
|
PADDLE_ENFORCE_EQ(dim_bias_v[b_size], dim_v[size],
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"bias_v's last dim size should equal to"
|
|
|
|
|
" v last dim size, but bias_v's size is:%d v is:%d",
|
|
|
|
|
dim_bias_v[b_size], dim_v[size]));
|
|
|
|
|
int b_indx = dim_bias_q.size() - 1;
|
|
|
|
|
int indx = dim_q.size() - 1;
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
dim_bias_q[b_indx], dim_q[indx],
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"bias_q's last dim size should equal to"
|
|
|
|
|
" q last dim size, but received bias_q's size is:%d q is:%d",
|
|
|
|
|
dim_bias_q[b_indx], dim_q[indx]));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
dim_bias_k[b_indx], dim_k[indx],
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"bias_k's last dim size should equal to"
|
|
|
|
|
" k last dim size, but received bias_k's size is:%d k is:%d",
|
|
|
|
|
dim_bias_k[b_indx], dim_k[indx]));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
dim_bias_v[b_indx], dim_v[indx],
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"bias_v's last dim size should equal to"
|
|
|
|
|
" v last dim size, but received bias_v's size is:%d v is:%d",
|
|
|
|
|
dim_bias_v[b_indx], dim_v[indx]));
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(dim_q[0], dim_bias_qk[0],
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"q should have same batch size"
|
|
|
|
|
"with bias_qk, but q's batch size:%d not equal to "
|
|
|
|
|
"bias_qk's batch size:%d",
|
|
|
|
|
"with bias_qk, but received q's batch size is:%d "
|
|
|
|
|
"bias_qk's batch size is:%d",
|
|
|
|
|
dim_q[0], dim_bias_qk[0]));
|
|
|
|
|
|
|
|
|
|
int head_number = context->Attrs().Get<int>("head_number");
|
|
|
|
|