|
|
|
@ -230,6 +230,22 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
|
|
|
|
|
auto* cell_out = ctx.Output<LoDTensor>("Cell");
|
|
|
|
|
|
|
|
|
|
std::function<void(const int, const T *, T *)> act_gate, act_cell, act_cand;
|
|
|
|
|
auto& act_gate_str = ctx.Attr<std::string>("gate_activation");
|
|
|
|
|
auto& act_cell_str = ctx.Attr<std::string>("cell_activation");
|
|
|
|
|
auto& act_cand_str = ctx.Attr<std::string>("candidate_activation");
|
|
|
|
|
if (platform::jit::MayIUse(platform::jit::avx)) {
|
|
|
|
|
math::VecActivations<T, platform::jit::avx> act_functor;
|
|
|
|
|
act_gate = act_functor(act_gate_str);
|
|
|
|
|
act_cell = act_functor(act_cell_str);
|
|
|
|
|
act_cand = act_functor(act_cand_str);
|
|
|
|
|
} else {
|
|
|
|
|
math::VecActivations<T, platform::jit::isa_any> act_functor;
|
|
|
|
|
act_gate = act_functor(act_gate_str);
|
|
|
|
|
act_cell = act_functor(act_cell_str);
|
|
|
|
|
act_cand = act_functor(act_cand_str);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto x_lod = x->lod();
|
|
|
|
|
auto x_dims = x->dims(); // T x M
|
|
|
|
|
auto wh_dims = wh->dims(); // D x 4D
|
|
|
|
@ -263,15 +279,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
|
|
|
|
|
prev_cell_data = c0_data + i * D;
|
|
|
|
|
} else {
|
|
|
|
|
// W_ch, W_ih, W_fh, W_oh
|
|
|
|
|
// actgate
|
|
|
|
|
math::vec_sigmoid<T>(D3, xx_data + D, xx_data + D);
|
|
|
|
|
// ch gate
|
|
|
|
|
math::vec_tanh<T>(D, xx_data, xx_data);
|
|
|
|
|
act_gate(D3, xx_data + D, xx_data + D);
|
|
|
|
|
act_cand(D, xx_data, xx_data);
|
|
|
|
|
// cell_out = input_gate * candidate (c~); there is no previous cell term in this branch
|
|
|
|
|
blas.VMUL(D, xx_data, xx_data + D, cell_out_data);
|
|
|
|
|
// hidden_out = cell_activation(cell_out) * output_gate
|
|
|
|
|
// act state
|
|
|
|
|
math::vec_tanh<T>(D, cell_out_data, xx_data + D2);
|
|
|
|
|
act_cell(D, cell_out_data, xx_data + D2);
|
|
|
|
|
blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data);
|
|
|
|
|
|
|
|
|
|
// prev
|
|
|
|
@ -290,10 +303,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
|
|
|
|
|
D4);
|
|
|
|
|
|
|
|
|
|
// W_ch, W_ih, W_fh, W_oh
|
|
|
|
|
// actgate
|
|
|
|
|
math::vec_sigmoid<T>(D3, xx_data + D, xx_data + D);
|
|
|
|
|
// ch gate
|
|
|
|
|
math::vec_tanh<T>(D, xx_data, xx_data);
|
|
|
|
|
act_gate(D3, xx_data + D, xx_data + D);
|
|
|
|
|
act_cand(D, xx_data, xx_data);
|
|
|
|
|
|
|
|
|
|
// a = forget * prev_cell
|
|
|
|
|
blas.VMUL(D, xx_data + D2, prev_cell_data, xx_data + D2);
|
|
|
|
@ -305,8 +316,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
|
|
|
|
|
blas.VADD(D, xx_data + D, xx_data + D2, cell_out_data);
|
|
|
|
|
|
|
|
|
|
// hidden_out = cell_activation(cell_out) * output_gate
|
|
|
|
|
// act state
|
|
|
|
|
math::vec_tanh<T>(D, cell_out_data, xx_data + D2);
|
|
|
|
|
act_cell(D, cell_out_data, xx_data + D2);
|
|
|
|
|
blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data);
|
|
|
|
|
|
|
|
|
|
// prev
|
|
|
|
|