|
|
|
@ -266,25 +266,24 @@ class FusionGRUKernel : public framework::OpKernel<T> {
|
|
|
|
|
batched_input_data, D3);
|
|
|
|
|
|
|
|
|
|
T* cur_batched_data = batched_input_data;
|
|
|
|
|
T* cur_out_data = batched_out_data;
|
|
|
|
|
T* cur_prev_hidden_data = prev_hidden_data;
|
|
|
|
|
for (int i = 0; i < cur_bs; ++i) {
|
|
|
|
|
act_gate(D2, cur_batched_data, cur_batched_data);
|
|
|
|
|
// rt = rt*ht_1 inplace result
|
|
|
|
|
// TODO(TJ): try to save to cur out data
|
|
|
|
|
// maybe get benefits avoiding cache miss in next gemm
|
|
|
|
|
blas.VMUL(D, cur_prev_hidden_data, cur_batched_data + D,
|
|
|
|
|
cur_batched_data + D);
|
|
|
|
|
blas.VMUL(D, cur_prev_hidden_data, cur_batched_data + D, cur_out_data);
|
|
|
|
|
|
|
|
|
|
cur_batched_data += D3;
|
|
|
|
|
cur_prev_hidden_data += D;
|
|
|
|
|
cur_out_data += D;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
cur_batched_data = batched_input_data;
|
|
|
|
|
cur_out_data = batched_out_data;
|
|
|
|
|
blas.GEMM(CblasNoTrans, CblasNoTrans, cur_bs, D, D, static_cast<T>(1),
|
|
|
|
|
cur_batched_data + D, D3, wh_state_data, D, static_cast<T>(1),
|
|
|
|
|
cur_out_data, D, wh_state_data, D, static_cast<T>(1),
|
|
|
|
|
cur_batched_data + D2, D3);
|
|
|
|
|
|
|
|
|
|
T* cur_out_data = batched_out_data;
|
|
|
|
|
cur_prev_hidden_data = prev_hidden_data;
|
|
|
|
|
for (int i = 0; i < cur_bs; ++i) {
|
|
|
|
|
// ht~ = act_state(...)
|
|
|
|
|