fix crf decode avx512

fix_recordio_link
tensor-tang 6 years ago
parent 21487d78bf
commit 64d5b4385e

@ -156,17 +156,16 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
} \
}
#define INTRIAVX2_FLOAT(block) \
#define INTRIAVX2_FLOAT(isa, block) \
template <> \
CRFDecodeKernelImpl<float, jit::avx2, block>::CRFDecodeKernelImpl( \
int tag_num) \
CRFDecodeKernelImpl<float, isa, block>::CRFDecodeKernelImpl(int tag_num) \
: CRFDecodeKernel<float>() { \
this->num_ = tag_num; \
this->end_ = this->num_ / AVX2_FLOAT_BLOCK; \
this->rest_ = this->num_ % AVX2_FLOAT_BLOCK; \
} \
template <> \
void CRFDecodeKernelImpl<float, jit::avx2, block>::Compute( \
void CRFDecodeKernelImpl<float, isa, block>::Compute( \
const int seq_len, const float* x, const float* w, float* alpha, \
int* track) const { \
INIT_ALPHA(AVX2_FLOAT_BLOCK) \
@ -224,7 +223,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
int j_offset = 0; \
for (int j = 0; j <= this->end_; ++j) { \
/* Initialize the variables of maximum score and location.*/ \
__m512 max_score = _mm512_set1_ps(-std::numeric_limits<T>::max()); \
__m512 max_score = _mm512_set1_ps(-std::numeric_limits<float>::max()); \
__m512i max_j = _mm512_setzero_si512(); \
/* Calculate the offset of transition_weights.*/ \
int trans_offset = state_trans_base_idx * this->num_ + j_offset; \
@ -245,7 +244,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
__m512 x_content = \
_mm512_loadu_ps(x + seq_offset + this->num_ + j_offset); \
max_score = _mm512_add_ps(max_score, x_content); \
_mm512_storeu_ps(alpha_value + seq_offset + this->tag_num_ + j_offset, \
_mm512_storeu_ps(alpha + seq_offset + this->num_ + j_offset, \
max_score); \
_mm512_storeu_si512(reinterpret_cast<__m512i*>(track + seq_offset + \
this->num_ + j_offset), \
@ -271,14 +270,14 @@ INTRIAVX_FLOAT(kEQ16);
INTRIAVX_FLOAT(kGT16);
#endif
#ifdef __AVX2__
INTRIAVX2_FLOAT(kEQ8);
INTRIAVX2_FLOAT(kGT8LT16);
INTRIAVX2_FLOAT(kEQ16);
INTRIAVX2_FLOAT(kGT16);
INTRIAVX2_FLOAT(jit::avx2, kEQ8);
INTRIAVX2_FLOAT(jit::avx2, kGT8LT16);
INTRIAVX2_FLOAT(jit::avx2, kEQ16);
INTRIAVX2_FLOAT(jit::avx2, kGT16);
#endif
#ifdef __AVX512F__
INTRIAVX2_FLOAT(kEQ8);
INTRIAVX2_FLOAT(kGT8LT16);
INTRIAVX2_FLOAT(jit::avx512f, kEQ8);
INTRIAVX2_FLOAT(jit::avx512f, kGT8LT16);
INTRIAVX512_FLOAT(kEQ16);
INTRIAVX512_FLOAT(kGT16);
#endif

Loading…
Cancel
Save