|
|
|
@ -48,32 +48,15 @@ namespace forward {
|
|
|
|
|
namespace avx {

// Activation helpers that take and return a single __m256 value. Only the
// declarations live here; the definitions are not visible in this part of
// the file.
// NOTE(review): presumably each applies its activation lane-by-lane to the
// packed floats -- confirm against the definitions.

// Used by lstm_compute_ctht<float> on the input, forget and output gates.
__m256 Sigmoid(const __m256 a);

// Used by lstm_compute_ctht<float> on the candidate and on the new cell state.
__m256 Tanh(const __m256 a);

}  // namespace avx
|
|
|
|
|
} // namespace forward
|
|
|
|
|
} // namespace detail
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
void lstm_compute_ctht<float>(float* gates, const float* ct_1, float* ct,
|
|
|
|
|
float* ht) {
|
|
|
|
|
namespace act = detail::forward::avx;
|
|
|
|
|
// gates: W_ch, W_ih, W_fh, W_oh
|
|
|
|
|
__m256 c, i, f, o;
|
|
|
|
|
c = _mm256_loadu_ps(gates);
|
|
|
|
|
i = _mm256_loadu_ps(gates + 8);
|
|
|
|
|
f = _mm256_loadu_ps(gates + 16);
|
|
|
|
|
o = _mm256_loadu_ps(gates + 24);
|
|
|
|
|
float* ht);
|
|
|
|
|
|
|
|
|
|
/* C_t = C_t-1 * fgated + cand_gated * igated*/
|
|
|
|
|
c = _mm256_mul_ps(act::Tanh(c), act::Sigmoid(i));
|
|
|
|
|
i = _mm256_loadu_ps(ct_1);
|
|
|
|
|
f = _mm256_mul_ps(i, act::Sigmoid(f));
|
|
|
|
|
f = _mm256_add_ps(c, f);
|
|
|
|
|
_mm256_storeu_ps(ct, f);
|
|
|
|
|
|
|
|
|
|
/* H_t = act_cell(C_t) * ogated */
|
|
|
|
|
o = _mm256_mul_ps(act::Tanh(f), act::Sigmoid(o));
|
|
|
|
|
_mm256_storeu_ps(ht, o);
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
} // namespace math
|
|
|
|
|