|
|
|
@ -440,13 +440,11 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
auto &bias_qk = detail::Ref(context.Input<framework::Tensor>("BiasQK"),
|
|
|
|
|
"Cannot find QK");
|
|
|
|
|
auto *out = context.Output<framework::Tensor>("Out");
|
|
|
|
|
|
|
|
|
|
auto *input_d = input->data<T>();
|
|
|
|
|
auto *w_d = w->data<T>();
|
|
|
|
|
auto *bias_d = bias->data<T>();
|
|
|
|
|
auto *bias_qk_d = bias_qk.data<T>();
|
|
|
|
|
auto *output_d = out->mutable_data<T>(context.GetPlace());
|
|
|
|
|
T scale = static_cast<T>(context.Attr<float>("alpha"));
|
|
|
|
|
|
|
|
|
|
int head_number = context.Attr<int>("head_number");
|
|
|
|
@ -463,6 +461,10 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
|
|
|
|
|
int all_head_size = w_dims[2];
|
|
|
|
|
int head_size = all_head_size / head_number;
|
|
|
|
|
|
|
|
|
|
auto *out = context.Output<framework::Tensor>("Out");
|
|
|
|
|
out->Resize({batch, seq_len, all_head_size});
|
|
|
|
|
auto *output_d = out->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
|
// (B*S, hidden)
|
|
|
|
|
const Tensor input_matrix =
|
|
|
|
|
framework::ReshapeToMatrix(*input, 2 /*x_num_col_dims */);
|
|
|
|
|