@@ -136,6 +136,21 @@ static std::shared_ptr<const VActKernel<T>> GetActKernel(
  return nullptr;
}

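/* Return the AVX activation implementation matching the given type name. */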
template <jit::cpu_isa_t isa>
static std::unique_ptr<AVXAct> GetAVXAct(const std::string& type) {
  if (type == "sigmoid") {
    return std::unique_ptr<AVXAct>(new AVXActImpl<kSigmoid, isa>());
  } else if (type == "relu") {
    return std::unique_ptr<AVXAct>(new AVXActImpl<kRelu, isa>());
  } else if (type == "tanh") {
    return std::unique_ptr<AVXAct>(new AVXActImpl<kTanh, isa>());
  } else if (type == "identity" || type == "") {
    return std::unique_ptr<AVXAct>(new AVXActImpl<kIdentity, isa>());
  }
  PADDLE_THROW("Not support type: %s", type);
  return nullptr;
}

/* LSTM JitKernel */
template <typename T, jit::cpu_isa_t isa, jit_block>
class LSTMKernelImpl : public LSTMKernel<T> {
@@ -192,61 +207,49 @@ class LSTMKernelImpl : public LSTMKernel<T> {
#endif
};

#define INTRI8_FLOAT(isa) \
  template <> \
  LSTMKernelImpl<float, isa, kEQ8>::LSTMKernelImpl( \
      const std::string& act_gate, const std::string& act_cand, \
      const std::string& act_cell, int d) \
      : LSTMKernel<float>() { \
    auto GetAVXAct = [&](const std::string& type) -> std::unique_ptr<AVXAct> { \
      if (type == "sigmoid") { \
        return std::unique_ptr<AVXAct>(new AVXActImpl<kSigmoid, isa>()); \
      } else if (type == "relu") { \
        return std::unique_ptr<AVXAct>(new AVXActImpl<kRelu, isa>()); \
      } else if (type == "tanh") { \
        return std::unique_ptr<AVXAct>(new AVXActImpl<kTanh, isa>()); \
      } else if (type == "identity" || type == "") { \
        return std::unique_ptr<AVXAct>(new AVXActImpl<kIdentity, isa>()); \
      } \
      PADDLE_THROW("Not support type: %s", type); \
    }; \
    avx_act_gate_ = GetAVXAct(act_gate); \
    avx_act_cand_ = GetAVXAct(act_cand); \
    avx_act_cell_ = GetAVXAct(act_cell); \
  } \
  template <> \
  void LSTMKernelImpl<float, isa, kEQ8>::ComputeCtHt( \
      float* gates, const float* ct_1, float* ct, float* ht, \
      const float* wp_data, float* checked) const { \
    /* gates: W_ch, W_ih, W_fh, W_oh */ \
    __m256 c, i, f, o; \
    c = _mm256_loadu_ps(gates); \
    i = _mm256_loadu_ps(gates + 8); \
    f = _mm256_loadu_ps(gates + 16); \
    o = _mm256_loadu_ps(gates + 24); \
    /* C_t = C_t-1 * fgated + cand_gated * igated */ \
    c = _mm256_mul_ps(avx_act_cand_->Compute(c), avx_act_gate_->Compute(i)); \
    i = _mm256_loadu_ps(ct_1); \
    f = _mm256_mul_ps(i, avx_act_gate_->Compute(f)); \
    f = _mm256_add_ps(c, f); \
    _mm256_storeu_ps(ct, f); \
    /* H_t = act_cell(C_t) * ogated */ \
    o = _mm256_mul_ps(avx_act_cell_->Compute(f), avx_act_gate_->Compute(o)); \
    _mm256_storeu_ps(ht, o); \
  } \
  template <> \
  void LSTMKernelImpl<float, isa, kEQ8>::ComputeC1H1( \
      float* gates, float* ct, float* ht, const float* wp_data) const { \
    __m256 c, i, o; \
    c = _mm256_loadu_ps(gates); \
    i = _mm256_loadu_ps(gates + 8); \
    o = _mm256_loadu_ps(gates + 24); \
    /* C_t = igated * cgated */ \
    c = _mm256_mul_ps(avx_act_gate_->Compute(i), avx_act_cand_->Compute(c)); \
    _mm256_storeu_ps(ct, c); \
    /* H_t = act_cell(C_t) * ogated */ \
    o = _mm256_mul_ps(avx_act_cell_->Compute(c), avx_act_gate_->Compute(o)); \
    _mm256_storeu_ps(ht, o); \
#define INTRI8_FLOAT(isa) \
  template <> \
  LSTMKernelImpl<float, isa, kEQ8>::LSTMKernelImpl( \
      const std::string& act_gate, const std::string& act_cand, \
      const std::string& act_cell, int d) \
      : LSTMKernel<float>() { \
    avx_act_gate_ = GetAVXAct<isa>(act_gate); \
    avx_act_cand_ = GetAVXAct<isa>(act_cand); \
    avx_act_cell_ = GetAVXAct<isa>(act_cell); \
  } \
  template <> \
  void LSTMKernelImpl<float, isa, kEQ8>::ComputeCtHt( \
      float* gates, const float* ct_1, float* ct, float* ht, \
      const float* wp_data, float* checked) const { \
    /* gates: W_ch, W_ih, W_fh, W_oh */ \
    __m256 c, i, f, o; \
    c = _mm256_loadu_ps(gates); \
    i = _mm256_loadu_ps(gates + 8); \
    f = _mm256_loadu_ps(gates + 16); \
    o = _mm256_loadu_ps(gates + 24); \
    /* C_t = C_t-1 * fgated + cand_gated * igated */ \
    c = _mm256_mul_ps(avx_act_cand_->Compute(c), avx_act_gate_->Compute(i)); \
    i = _mm256_loadu_ps(ct_1); \
    f = _mm256_mul_ps(i, avx_act_gate_->Compute(f)); \
    f = _mm256_add_ps(c, f); \
    _mm256_storeu_ps(ct, f); \
    /* H_t = act_cell(C_t) * ogated */ \
    o = _mm256_mul_ps(avx_act_cell_->Compute(f), avx_act_gate_->Compute(o)); \
    _mm256_storeu_ps(ht, o); \
  } \
  template <> \
  void LSTMKernelImpl<float, isa, kEQ8>::ComputeC1H1( \
      float* gates, float* ct, float* ht, const float* wp_data) const { \
    __m256 c, i, o; \
    c = _mm256_loadu_ps(gates); \
    i = _mm256_loadu_ps(gates + 8); \
    o = _mm256_loadu_ps(gates + 24); \
    /* C_t = igated * cgated */ \
    c = _mm256_mul_ps(avx_act_gate_->Compute(i), avx_act_cand_->Compute(c)); \
    _mm256_storeu_ps(ct, c); \
    /* H_t = act_cell(C_t) * ogated */ \
    o = _mm256_mul_ps(avx_act_cell_->Compute(c), avx_act_gate_->Compute(o)); \
    _mm256_storeu_ps(ht, o); \
  }

// TODO(TJ): optimize keq16
@@ -375,6 +378,7 @@ class GRUKernelImpl : public GRUKernel<T> {
    act_state_d_->Compute(gates + d2_, gates + d2_);
    vmul_d_->Compute(gates, gates + d2_, ht);
  }

  void ComputeHtPart1(T* gates, const T* ht_1, T* ht) const override {
    // W: {W_update, W_reset; W_state}
    act_gate_d2_->Compute(gates, gates);
@@ -394,8 +398,65 @@ class GRUKernelImpl : public GRUKernel<T> {
  int d_, d2_;
  std::shared_ptr<const VActKernel<T>> act_gate_d2_, act_gate_d_, act_state_d_;
  std::shared_ptr<const VMulKernel<T>> vmul_d_;
#ifdef __AVX__
  std::unique_ptr<const AVXAct> avx_act_gate_, avx_act_state_;
#endif
};

#define INTRI8_FLOAT(isa) \
  template <> \
  GRUKernelImpl<float, isa, kEQ8>::GRUKernelImpl( \
      const std::string& act_gate, const std::string& act_state, int d) \
      : GRUKernel<float>() { \
    avx_act_gate_ = GetAVXAct<isa>(act_gate); \
    avx_act_state_ = GetAVXAct<isa>(act_state); \
  } \
  template <> \
  void GRUKernelImpl<float, isa, kEQ8>::ComputeH1(float* gates, float* ht) \
      const { \
    __m256 u, s; \
    /* W: {W_update, W_reset; W_state} */ \
    u = _mm256_loadu_ps(gates); \
    s = _mm256_loadu_ps(gates + 16); \
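    /* H_t = act_gate(u) * act_state(s) */ \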
    s = _mm256_mul_ps(avx_act_gate_->Compute(u), avx_act_state_->Compute(s)); \
    _mm256_storeu_ps(ht, s); \
  } \
  template <> \
  void GRUKernelImpl<float, isa, kEQ8>::ComputeHtPart1( \
      float* gates, const float* ht_1, float* ht) const { \
    /* not exactly equal the any implementation */ \
    __m256 r, ht0; \
    r = _mm256_loadu_ps(gates + 8); \
    ht0 = _mm256_loadu_ps(ht_1); \
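    /* ht = act_gate(r) * H_t-1, the reset-gated previous state */ \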
    r = _mm256_mul_ps(avx_act_gate_->Compute(r), ht0); \
    _mm256_storeu_ps(ht, r); \
  } \
  template <> \
  void GRUKernelImpl<float, isa, kEQ8>::ComputeHtPart2( \
      float* gates, const float* ht_1, float* ht) const { \
    /* not exactly equal the any implementation */ \
    __m256 u, s, ht0; \
    u = _mm256_loadu_ps(gates); \
    s = _mm256_loadu_ps(gates + 16); \
    ht0 = _mm256_loadu_ps(ht_1); \
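    /* H_t = act_gate(u) * act_state(s) + (1 - act_gate(u)) * H_t-1 */ \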
    u = avx_act_gate_->Compute(u); \
    s = _mm256_mul_ps(u, avx_act_state_->Compute(s)); \
    u = _mm256_sub_ps(_mm256_set1_ps(1.f), u); \
    u = _mm256_mul_ps(u, ht0); \
    u = _mm256_add_ps(s, u); \
    _mm256_storeu_ps(ht, u); \
  }
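
/* Instantiate the kEQ8 float specializations for each AVX ISA enabled at build time. */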
#ifdef __AVX__
INTRI8_FLOAT(jit::avx);
#endif
#ifdef __AVX2__
INTRI8_FLOAT(jit::avx2);
#endif
#ifdef __AVX512F__
INTRI8_FLOAT(jit::avx512f);
#endif

#define JITKERNEL_DECLARE_GRU(ker_class, ker_dtype) \
  template <> \
  std::shared_ptr<const GRUKernel<ker_dtype>> KernelPool::Get< \
@@ -412,6 +473,7 @@ class GRUKernelImpl : public GRUKernel<T> {
REGISTER_JITKERNEL_ARGS(gru, GRUKernel, JITKERNEL_DECLARE_GRU,
                        JITKERNEL_KEY_GRU, JITKERNEL_NEW_GRU_IMPL);

#undef INTRI8_FLOAT
#undef JITKERNEL_NEW_GRU_IMPL
#undef JITKERNEL_KEY_GRU
#undef JITKERNEL_DECLARE_GRU