|
|
|
@ -299,11 +299,21 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
|
|
|
|
|
PADDLE_ENFORCE_EQ(c0->dims()[0], N, "C0 dims should be %d x %d.", N, D);
|
|
|
|
|
fc_out->Resize({max_seq_len, 1});
|
|
|
|
|
|
|
|
|
|
math::VecActivations<T> act_functor;
|
|
|
|
|
std::function<void(const int, const T *, T *)> act_gate, act_cell, act_cand;
|
|
|
|
|
act_gate = act_functor(ctx.Attr<std::string>("gate_activation"));
|
|
|
|
|
act_cell = act_functor(ctx.Attr<std::string>("cell_activation"));
|
|
|
|
|
act_cand = act_functor(ctx.Attr<std::string>("candidate_activation"));
|
|
|
|
|
auto& act_gate_str = ctx.Attr<std::string>("gate_activation");
|
|
|
|
|
auto& act_cell_str = ctx.Attr<std::string>("cell_activation");
|
|
|
|
|
auto& act_cand_str = ctx.Attr<std::string>("candidate_activation");
|
|
|
|
|
if (platform::jit::MayIUse(platform::jit::avx)) {
|
|
|
|
|
math::VecActivations<T, platform::jit::avx> act_functor;
|
|
|
|
|
act_gate = act_functor(act_gate_str);
|
|
|
|
|
act_cell = act_functor(act_cell_str);
|
|
|
|
|
act_cand = act_functor(act_cand_str);
|
|
|
|
|
} else {
|
|
|
|
|
math::VecActivations<T, platform::jit::isa_any> act_functor;
|
|
|
|
|
act_gate = act_functor(act_gate_str);
|
|
|
|
|
act_cell = act_functor(act_cell_str);
|
|
|
|
|
act_cand = act_functor(act_cand_str);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const T* x_data = x->data<T>();
|
|
|
|
|
const T* h0_data = h0 ? h0->data<T>() : NULL;
|
|
|
|
|