|
|
|
@ -156,17 +156,16 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
|
|
|
|
|
} \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#define INTRIAVX2_FLOAT(block) \
|
|
|
|
|
#define INTRIAVX2_FLOAT(isa, block) \
|
|
|
|
|
template <> \
|
|
|
|
|
CRFDecodeKernelImpl<float, jit::avx2, block>::CRFDecodeKernelImpl( \
|
|
|
|
|
int tag_num) \
|
|
|
|
|
CRFDecodeKernelImpl<float, isa, block>::CRFDecodeKernelImpl(int tag_num) \
|
|
|
|
|
: CRFDecodeKernel<float>() { \
|
|
|
|
|
this->num_ = tag_num; \
|
|
|
|
|
this->end_ = this->num_ / AVX2_FLOAT_BLOCK; \
|
|
|
|
|
this->rest_ = this->num_ % AVX2_FLOAT_BLOCK; \
|
|
|
|
|
} \
|
|
|
|
|
template <> \
|
|
|
|
|
void CRFDecodeKernelImpl<float, jit::avx2, block>::Compute( \
|
|
|
|
|
void CRFDecodeKernelImpl<float, isa, block>::Compute( \
|
|
|
|
|
const int seq_len, const float* x, const float* w, float* alpha, \
|
|
|
|
|
int* track) const { \
|
|
|
|
|
INIT_ALPHA(AVX2_FLOAT_BLOCK) \
|
|
|
|
@ -224,7 +223,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
|
|
|
|
|
int j_offset = 0; \
|
|
|
|
|
for (int j = 0; j <= this->end_; ++j) { \
|
|
|
|
|
/* Initialize the variables of maximum score and location.*/ \
|
|
|
|
|
__m512 max_score = _mm512_set1_ps(-std::numeric_limits<T>::max()); \
|
|
|
|
|
__m512 max_score = _mm512_set1_ps(-std::numeric_limits<float>::max()); \
|
|
|
|
|
__m512i max_j = _mm512_setzero_si512(); \
|
|
|
|
|
/* Calculate the offset of transition_weights.*/ \
|
|
|
|
|
int trans_offset = state_trans_base_idx * this->num_ + j_offset; \
|
|
|
|
@ -245,7 +244,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
|
|
|
|
|
__m512 x_content = \
|
|
|
|
|
_mm512_loadu_ps(x + seq_offset + this->num_ + j_offset); \
|
|
|
|
|
max_score = _mm512_add_ps(max_score, x_content); \
|
|
|
|
|
_mm512_storeu_ps(alpha_value + seq_offset + this->tag_num_ + j_offset, \
|
|
|
|
|
_mm512_storeu_ps(alpha + seq_offset + this->num_ + j_offset, \
|
|
|
|
|
max_score); \
|
|
|
|
|
_mm512_storeu_si512(reinterpret_cast<__m512i*>(track + seq_offset + \
|
|
|
|
|
this->num_ + j_offset), \
|
|
|
|
@ -271,14 +270,14 @@ INTRIAVX_FLOAT(kEQ16);
|
|
|
|
|
INTRIAVX_FLOAT(kGT16);
|
|
|
|
|
#endif
|
|
|
|
|
#ifdef __AVX2__
|
|
|
|
|
INTRIAVX2_FLOAT(kEQ8);
|
|
|
|
|
INTRIAVX2_FLOAT(kGT8LT16);
|
|
|
|
|
INTRIAVX2_FLOAT(kEQ16);
|
|
|
|
|
INTRIAVX2_FLOAT(kGT16);
|
|
|
|
|
INTRIAVX2_FLOAT(jit::avx2, kEQ8);
|
|
|
|
|
INTRIAVX2_FLOAT(jit::avx2, kGT8LT16);
|
|
|
|
|
INTRIAVX2_FLOAT(jit::avx2, kEQ16);
|
|
|
|
|
INTRIAVX2_FLOAT(jit::avx2, kGT16);
|
|
|
|
|
#endif
|
|
|
|
|
#ifdef __AVX512F__
|
|
|
|
|
INTRIAVX2_FLOAT(kEQ8);
|
|
|
|
|
INTRIAVX2_FLOAT(kGT8LT16);
|
|
|
|
|
INTRIAVX2_FLOAT(jit::avx512f, kEQ8);
|
|
|
|
|
INTRIAVX2_FLOAT(jit::avx512f, kGT8LT16);
|
|
|
|
|
INTRIAVX512_FLOAT(kEQ16);
|
|
|
|
|
INTRIAVX512_FLOAT(kGT16);
|
|
|
|
|
#endif
|
|
|
|
|