|
|
|
@ -26,8 +26,7 @@ namespace detail {
|
|
|
|
|
|
|
|
|
|
template <class T, class Op>
|
|
|
|
|
void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
|
|
|
|
|
int frame_size,
|
|
|
|
|
ActivationType active_node,
|
|
|
|
|
int frame_size, ActivationType active_node,
|
|
|
|
|
ActivationType active_gate,
|
|
|
|
|
ActivationType active_state) {
|
|
|
|
|
T r_value_in;
|
|
|
|
@ -149,8 +148,7 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
|
|
|
|
|
|
|
|
|
|
template <class T, class Op>
|
|
|
|
|
void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
|
|
|
|
|
int frame_size,
|
|
|
|
|
ActivationType active_node,
|
|
|
|
|
int frame_size, ActivationType active_node,
|
|
|
|
|
ActivationType active_gate,
|
|
|
|
|
ActivationType active_state) {
|
|
|
|
|
#ifdef __AVX__
|
|
|
|
@ -281,8 +279,7 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
|
|
|
|
|
|
|
|
|
|
template <class T, class Op>
|
|
|
|
|
void cpu_lstm_forward(Op op, LstmMetaValue<T> value, int frame_size,
|
|
|
|
|
ActivationType active_node,
|
|
|
|
|
ActivationType active_gate,
|
|
|
|
|
ActivationType active_node, ActivationType active_gate,
|
|
|
|
|
ActivationType active_state) {
|
|
|
|
|
if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same<T, float>::value)) {
|
|
|
|
|
avx_lstm_forward_one_sequence<T>(op, value, frame_size, active_node,
|
|
|
|
|