@@ -25,13 +25,18 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 namespace math {
-#ifdef __AVX__
+namespace jitkernel {
 namespace detail {
-__m256 Exp(__m256 a);
-}  // namespace detail
+#ifdef __AVX__
+__m256 ExpAVX(__m256 x);
 #endif
 
-namespace jitkernel {
+#ifdef __AVX2__
+__m256 ExpAVX2(__m256 x);
+#endif
+
+}  // namespace detail
+
 namespace jit = platform::jit;
 
 #ifdef __AVX__
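The hunk above swaps the single detail::Exp declaration for per-ISA variants (ExpAVX under __AVX__, ExpAVX2 under __AVX2__). A minimal sanity-check sketch, assuming -mavx and linkage against the translation unit that defines detail::ExpAVX (the harness itself is not part of the patch), could compare the AVX kernel against std::exp lane by lane:

// Sketch only: checks detail::ExpAVX against std::exp for one 8-lane vector.
// Assumes -mavx and linkage with the object file that defines ExpAVX.
#include <immintrin.h>

#include <cmath>
#include <cstdio>

namespace paddle {
namespace operators {
namespace math {
namespace jitkernel {
namespace detail {
__m256 ExpAVX(__m256 x);  // re-declared here so the sketch is self-contained
}  // namespace detail
}  // namespace jitkernel
}  // namespace math
}  // namespace operators
}  // namespace paddle

int main() {
  const float in[8] = {-2.f, -1.f, -0.5f, 0.f, 0.5f, 1.f, 2.f, 4.f};
  float out[8];
  __m256 x = _mm256_loadu_ps(in);
  _mm256_storeu_ps(out, paddle::operators::math::jitkernel::detail::ExpAVX(x));
  for (int i = 0; i < 8; ++i) {
    std::printf("lane %d: ExpAVX=%f std::exp=%f\n", i, out[i], std::exp(in[i]));
  }
  return 0;
}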
@@ -43,43 +48,72 @@ class AVXAct {
   virtual __m256 Compute(__m256 x) const = 0;
 };
 
-template <act_type type>
+template <act_type type, jit::cpu_isa_t isa>
 class AVXActImpl : public AVXAct {
  public:
   __m256 Compute(__m256 x) const override { PADDLE_THROW("Unkown type!"); }
 };
 
-template <>
-__m256 AVXActImpl<kSigmoid>::Compute(__m256 x) const {
-  __m256 ones = _mm256_set1_ps(1.0f);
-  x = _mm256_max_ps(x, _mm256_set1_ps(SIGMOID_THRESHOLD_MIN));
-  x = _mm256_min_ps(x, _mm256_set1_ps(SIGMOID_THRESHOLD_MAX));
-  x = _mm256_sub_ps(_mm256_set1_ps(0.0f), x);
-  x = detail::Exp(x);
-  x = _mm256_add_ps(ones, x);
-  return _mm256_div_ps(ones, x);
-}
+#define AVX_SIGMOID(isa, expisa) \
+  template <> \
+  __m256 AVXActImpl<kSigmoid, isa>::Compute(__m256 x) const { \
+    __m256 ones = _mm256_set1_ps(1.0f); \
+    x = _mm256_max_ps(x, _mm256_set1_ps(SIGMOID_THRESHOLD_MIN)); \
+    x = _mm256_min_ps(x, _mm256_set1_ps(SIGMOID_THRESHOLD_MAX)); \
+    x = _mm256_sub_ps(_mm256_set1_ps(0.0f), x); \
+    x = expisa(x); \
+    x = _mm256_add_ps(ones, x); \
+    return _mm256_div_ps(ones, x); \
+  }
 
-template <>
-__m256 AVXActImpl<kTanh>::Compute(__m256 x) const {
-  __m256 ones = _mm256_set1_ps(1.0f);
-  x = _mm256_mul_ps(_mm256_set1_ps(-2.0f), x);
-  x = _mm256_min_ps(x, _mm256_set1_ps(EXP_MAX_INPUT));
-  x = detail::Exp(x);
-  x = _mm256_add_ps(ones, x);
-  x = _mm256_div_ps(_mm256_set1_ps(2.0f), x);
-  return _mm256_sub_ps(x, ones);
-}
+#define AVX_TANH(isa, expisa) \
+  template <> \
+  __m256 AVXActImpl<kTanh, isa>::Compute(__m256 x) const { \
+    __m256 ones = _mm256_set1_ps(1.0f); \
+    x = _mm256_mul_ps(_mm256_set1_ps(-2.0f), x); \
+    x = _mm256_min_ps(x, _mm256_set1_ps(EXP_MAX_INPUT)); \
+    x = expisa(x); \
+    x = _mm256_add_ps(ones, x); \
+    x = _mm256_div_ps(_mm256_set1_ps(2.0f), x); \
+    return _mm256_sub_ps(x, ones); \
+  }
 
-template <>
-__m256 AVXActImpl<kRelu>::Compute(__m256 x) const {
-  return _mm256_max_ps(x, _mm256_setzero_ps());
-}
+#define AVX_RELU(isa) \
+  template <> \
+  __m256 AVXActImpl<kRelu, isa>::Compute(__m256 x) const { \
+    return _mm256_max_ps(x, _mm256_setzero_ps()); \
+  }
 
-template <>
-__m256 AVXActImpl<kIdentity>::Compute(__m256 x) const {
-  return x;
-}
+#define AVX_IDENTITY(isa) \
+  template <> \
+  __m256 AVXActImpl<kIdentity, isa>::Compute(__m256 x) const { \
+    return x; \
+  }
+
+#define FOR_EACH_AVX_ISA(macro_) \
+  macro_(jit::avx); \
+  macro_(jit::avx2); \
+  macro_(jit::avx512f)
+
+FOR_EACH_AVX_ISA(AVX_RELU);
+FOR_EACH_AVX_ISA(AVX_IDENTITY);
+
+AVX_SIGMOID(jit::avx, detail::ExpAVX);
+AVX_TANH(jit::avx, detail::ExpAVX);
+
+#ifdef __AVX2__
+AVX_SIGMOID(jit::avx2, detail::ExpAVX2);
+AVX_SIGMOID(jit::avx512f, detail::ExpAVX2);
+AVX_TANH(jit::avx2, detail::ExpAVX2);
+AVX_TANH(jit::avx512f, detail::ExpAVX2);
+#endif
+
+#undef FOR_EACH_AVX_ISA
+#undef AVX_IDENTITY
+#undef AVX_RELU
+#undef AVX_TANH
+#undef AVX_SIGMOID
+
 #endif
 
 template <typename T>
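To see what the new macros generate, AVX_SIGMOID(jit::avx, detail::ExpAVX) from the hunk above expands to roughly the specialization below (illustration only; the real tokens come from the preprocessor, and FOR_EACH_AVX_ISA(AVX_RELU) likewise expands AVX_RELU once per ISA in the list):

// Approximate expansion of AVX_SIGMOID(jit::avx, detail::ExpAVX);
// whitespace added for readability.
template <>
__m256 AVXActImpl<kSigmoid, jit::avx>::Compute(__m256 x) const {
  __m256 ones = _mm256_set1_ps(1.0f);
  x = _mm256_max_ps(x, _mm256_set1_ps(SIGMOID_THRESHOLD_MIN));
  x = _mm256_min_ps(x, _mm256_set1_ps(SIGMOID_THRESHOLD_MAX));
  x = _mm256_sub_ps(_mm256_set1_ps(0.0f), x);  // -x
  x = detail::ExpAVX(x);                       // exp(-x), the expisa argument
  x = _mm256_add_ps(ones, x);                  // 1 + exp(-x)
  return _mm256_div_ps(ones, x);               // 1 / (1 + exp(-x))
}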
@@ -119,23 +153,6 @@ class LSTMKernelImpl : public LSTMKernel<T> {
     act_cell_d_ = GetActKernel<T>(act_cell, d);
     vmul_d_ = KernelPool::Instance().template Get<VMulKernel<T>>(d);
     vadd_d_ = KernelPool::Instance().template Get<VAddKernel<T>>(d);
-#ifdef __AVX__
-    auto GetAVXAct = [&](const std::string& type) -> std::unique_ptr<AVXAct> {
-      if (type == "sigmoid") {
-        return std::unique_ptr<AVXAct>(new AVXActImpl<kSigmoid>());
-      } else if (type == "relu") {
-        return std::unique_ptr<AVXAct>(new AVXActImpl<kRelu>());
-      } else if (type == "tanh") {
-        return std::unique_ptr<AVXAct>(new AVXActImpl<kTanh>());
-      } else if (type == "identity" || type == "") {
-        return std::unique_ptr<AVXAct>(new AVXActImpl<kIdentity>());
-      }
-      PADDLE_THROW("Not support type: %s", type);
-    };
-    avx_act_gate_ = GetAVXAct(act_gate);
-    avx_act_cand_ = GetAVXAct(act_cand);
-    avx_act_cell_ = GetAVXAct(act_cell);
-#endif
   }
 
   void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht, const T* wp_data,
@@ -176,6 +193,27 @@ class LSTMKernelImpl : public LSTMKernel<T> {
 };
 
 #define INTRI8_FLOAT(isa) \
+  template <> \
+  LSTMKernelImpl<float, isa, kEQ8>::LSTMKernelImpl( \
+      const std::string& act_gate, const std::string& act_cand, \
+      const std::string& act_cell, int d) \
+      : LSTMKernel<float>() { \
+    auto GetAVXAct = [&](const std::string& type) -> std::unique_ptr<AVXAct> { \
+      if (type == "sigmoid") { \
+        return std::unique_ptr<AVXAct>(new AVXActImpl<kSigmoid, isa>()); \
+      } else if (type == "relu") { \
+        return std::unique_ptr<AVXAct>(new AVXActImpl<kRelu, isa>()); \
+      } else if (type == "tanh") { \
+        return std::unique_ptr<AVXAct>(new AVXActImpl<kTanh, isa>()); \
+      } else if (type == "identity" || type == "") { \
+        return std::unique_ptr<AVXAct>(new AVXActImpl<kIdentity, isa>()); \
+      } \
+      PADDLE_THROW("Not support type: %s", type); \
+    }; \
+    avx_act_gate_ = GetAVXAct(act_gate); \
+    avx_act_cand_ = GetAVXAct(act_cand); \
+    avx_act_cell_ = GetAVXAct(act_cell); \
+  } \
   template <> \
   void LSTMKernelImpl<float, isa, kEQ8>::ComputeCtHt( \
       float* gates, const float* ct_1, float* ct, float* ht, \
@@ -195,6 +233,20 @@ class LSTMKernelImpl : public LSTMKernel<T> {
     /* H_t = act_cell(C_t) * ogated */ \
     o = _mm256_mul_ps(avx_act_cell_->Compute(f), avx_act_gate_->Compute(o)); \
     _mm256_storeu_ps(ht, o); \
+  } \
+  template <> \
+  void LSTMKernelImpl<float, isa, kEQ8>::ComputeC1H1( \
+      float* gates, float* ct, float* ht, const float* wp_data) const { \
+    __m256 c, i, o; \
+    c = _mm256_loadu_ps(gates); \
+    i = _mm256_loadu_ps(gates + 8); \
+    o = _mm256_loadu_ps(gates + 24); \
+    /* C_t = igated * cgated*/ \
+    c = _mm256_mul_ps(avx_act_gate_->Compute(i), avx_act_cand_->Compute(c)); \
+    _mm256_storeu_ps(ct, c); \
+    /* H_t = act_cell(C_t) * ogated */ \
+    o = _mm256_mul_ps(avx_act_cell_->Compute(c), avx_act_gate_->Compute(o)); \
+    _mm256_storeu_ps(ht, o); \
   }
 
 // TODO(TJ): optimize keq16
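For reference, a scalar sketch of what the kEQ8 ComputeC1H1 specialization above computes (not part of the patch; the gate layout follows the loads above, c at offset 0, i at offset d, o at offset 3*d with d == 8, and wp_data is unused in this specialization):

// Scalar sketch of the ComputeC1H1 math shown above, for one block of d
// elements. act_gate/act_cand/act_cell stand for the selected activations.
void ReferenceC1H1(const float* gates, float* ct, float* ht, int d,
                   float (*act_gate)(float), float (*act_cand)(float),
                   float (*act_cell)(float)) {
  const float* c = gates;          // candidate, loaded from gates
  const float* i = gates + d;      // input gate, loaded from gates + 8
  const float* o = gates + 3 * d;  // output gate, loaded from gates + 24
  for (int k = 0; k < d; ++k) {
    ct[k] = act_gate(i[k]) * act_cand(c[k]);   // C_t = igated * cgated
    ht[k] = act_gate(o[k]) * act_cell(ct[k]);  // H_t = act_cell(C_t) * ogated
  }
}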