make runtime avx act

createGenDocLib
tensor-tang 7 years ago
parent 3462c29940
commit ba943d38e3

@ -299,11 +299,21 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(c0->dims()[0], N, "C0 dims should be %d x %d.", N, D);
fc_out->Resize({max_seq_len, 1});
math::VecActivations<T> act_functor;
std::function<void(const int, const T *, T *)> act_gate, act_cell, act_cand;
act_gate = act_functor(ctx.Attr<std::string>("gate_activation"));
act_cell = act_functor(ctx.Attr<std::string>("cell_activation"));
act_cand = act_functor(ctx.Attr<std::string>("candidate_activation"));
auto& act_gate_str = ctx.Attr<std::string>("gate_activation");
auto& act_cell_str = ctx.Attr<std::string>("cell_activation");
auto& act_cand_str = ctx.Attr<std::string>("candidate_activation");
if (platform::jit::MayIUse(platform::jit::avx)) {
math::VecActivations<T, platform::jit::avx> act_functor;
act_gate = act_functor(act_gate_str);
act_cell = act_functor(act_cell_str);
act_cand = act_functor(act_cand_str);
} else {
math::VecActivations<T, platform::jit::isa_any> act_functor;
act_gate = act_functor(act_gate_str);
act_cell = act_functor(act_cell_str);
act_cand = act_functor(act_cand_str);
}
const T* x_data = x->data<T>();
const T* h0_data = h0 ? h0->data<T>() : NULL;

Loading…
Cancel
Save