|
|
|
@ -16,6 +16,9 @@ limitations under the License. */
|
|
|
|
|
#include <string>
|
|
|
|
|
#include "paddle/fluid/operators/math/cpu_vec.h"
|
|
|
|
|
#include "paddle/fluid/platform/cpu_info.h"
|
|
|
|
|
#ifdef __AVX__
|
|
|
|
|
#include <immintrin.h>
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
@ -35,13 +38,47 @@ void lstm_compute_ctht(T* gates, const T* ct_1, T* ct, T* ht) {
|
|
|
|
|
ct[d] = ct_1[d] * f[d] + gates[d] * i[d];
|
|
|
|
|
// H_t = act_cell(C_t) * ogated
|
|
|
|
|
T tmp = ct[d] * 2;
|
|
|
|
|
tmp = static_cast<T>(0) - (tmp < min) ? min : ((tmp > max) ? max : tmp);
|
|
|
|
|
tmp = static_cast<T>(0) - ((tmp < min) ? min : ((tmp > max) ? max : tmp));
|
|
|
|
|
vec_exp<T>(1, &tmp, &tmp);
|
|
|
|
|
tmp = static_cast<T>(2) / (static_cast<T>(1) + tmp) - static_cast<T>(1);
|
|
|
|
|
ht[d] = tmp * o[d];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#ifdef __AVX__
|
|
|
|
|
namespace detail {
|
|
|
|
|
namespace forward {
|
|
|
|
|
namespace avx {
|
|
|
|
|
__m256 Sigmoid(const __m256 a);
|
|
|
|
|
__m256 Tanh(const __m256 a);
|
|
|
|
|
} // namespace avx
|
|
|
|
|
} // namespace forward
|
|
|
|
|
} // namespace detail
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
void lstm_compute_ctht<float>(float* gates, const float* ct_1, float* ct,
|
|
|
|
|
float* ht) {
|
|
|
|
|
namespace act = detail::forward::avx;
|
|
|
|
|
// gates: W_ch, W_ih, W_fh, W_oh
|
|
|
|
|
__m256 c, i, f, o;
|
|
|
|
|
c = _mm256_loadu_ps(gates);
|
|
|
|
|
i = _mm256_loadu_ps(gates + 8);
|
|
|
|
|
f = _mm256_loadu_ps(gates + 16);
|
|
|
|
|
o = _mm256_loadu_ps(gates + 24);
|
|
|
|
|
|
|
|
|
|
/* C_t = C_t-1 * fgated + cand_gated * igated*/
|
|
|
|
|
c = _mm256_mul_ps(act::Tanh(c), act::Sigmoid(i));
|
|
|
|
|
i = _mm256_loadu_ps(ct_1);
|
|
|
|
|
f = _mm256_mul_ps(i, act::Sigmoid(f));
|
|
|
|
|
f = _mm256_add_ps(c, f);
|
|
|
|
|
_mm256_storeu_ps(ct, f);
|
|
|
|
|
|
|
|
|
|
/* H_t = act_cell(C_t) * ogated */
|
|
|
|
|
o = _mm256_mul_ps(act::Tanh(f), act::Sigmoid(o));
|
|
|
|
|
_mm256_storeu_ps(ht, o);
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
} // namespace math
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|