@@ -378,11 +378,102 @@ class VIdentityKernelImpl : public VIdentityKernel<T> {
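  // Intentionally a no-op: VIdentity leaves the output buffer untouched.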
  void Compute(const T* x, T* y) const override {}
};

/* VAddRelu JitKernel */
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
class VAddReluKernelImpl : public VAddReluKernel<T> {
 public:
  explicit VAddReluKernelImpl(int d) : VAddReluKernel<T>() { this->num_ = d; }
  void Compute(const T* x, const T* y, T* z) const override {
    for (int i = 0; i < this->num_; ++i) {
      z[i] = x[i] + y[i];
      z[i] = z[i] > 0 ? z[i] : 0;
    }
  }
};
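
// The specializations below replace the scalar loop with AVX intrinsics,
// fusing the add and the ReLU: load 8 floats into a __m256 register, add,
// clamp at zero with _mm256_max_ps, then store. kEQ8 covers exactly 8
// elements and kEQ16 exactly 16; other sizes use the common variant further
// down.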
#define INTRI8_FLOAT(isa)                               \
  template <>                                           \
  void VAddReluKernelImpl<float, isa, kEQ8>::Compute(   \
      const float* x, const float* y, float* z) const { \
    __m256 tmpx = _mm256_loadu_ps(x);                   \
    __m256 tmpy = _mm256_loadu_ps(y);                   \
    tmpy = _mm256_add_ps(tmpx, tmpy);                   \
    tmpy = _mm256_max_ps(tmpy, _mm256_setzero_ps());    \
    _mm256_storeu_ps(z, tmpy);                          \
  }
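
// kEQ16 applies the same add-then-clamp pattern twice, once per 8-float half.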
#define INTRI16_FLOAT(isa)                              \
  template <>                                           \
  void VAddReluKernelImpl<float, isa, kEQ16>::Compute(  \
      const float* x, const float* y, float* z) const { \
    __m256 zeros = _mm256_setzero_ps();                 \
    __m256 tmp0 = _mm256_loadu_ps(x);                   \
    __m256 tmp1 = _mm256_loadu_ps(y);                   \
    tmp0 = _mm256_add_ps(tmp0, tmp1);                   \
    tmp0 = _mm256_max_ps(tmp0, zeros);                  \
    tmp1 = _mm256_loadu_ps(x + 8);                      \
    __m256 tmp2 = _mm256_loadu_ps(y + 8);               \
    tmp1 = _mm256_add_ps(tmp1, tmp2);                   \
    tmp1 = _mm256_max_ps(tmp1, zeros);                  \
    _mm256_storeu_ps(z, tmp0);                          \
    _mm256_storeu_ps(z + 8, tmp1);                      \
  }
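
// For lengths that are neither exactly 8 nor 16, the constructor splits the
// size into a vectorized body (end_, rounded down to a multiple of
// AVX_FLOAT_BLOCK) and a scalar tail (rest_); Compute runs the intrinsics
// over the body and falls back to the scalar add + ReLU for the remainder.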
#define INTRI_COMMON_FLOAT(isa, block)                             \
  template <>                                                      \
  VAddReluKernelImpl<float, isa, block>::VAddReluKernelImpl(int d) \
      : VAddReluKernel<float>() {                                  \
    this->num_ = d;                                                \
    this->end_ = d - d % AVX_FLOAT_BLOCK;                          \
    this->rest_ = d - this->end_;                                  \
  }                                                                \
  template <>                                                      \
  void VAddReluKernelImpl<float, isa, block>::Compute(             \
      const float* x, const float* y, float* z) const {            \
    __m256 zeros = _mm256_setzero_ps();                            \
    for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) {        \
      __m256 tmpx = _mm256_loadu_ps(x + i);                        \
      __m256 tmpy = _mm256_loadu_ps(y + i);                        \
      tmpy = _mm256_add_ps(tmpx, tmpy);                            \
      tmpy = _mm256_max_ps(tmpy, zeros);                           \
      _mm256_storeu_ps(z + i, tmpy);                               \
    }                                                              \
    for (int i = this->end_; i < this->num_; ++i) {                \
      z[i] = x[i] + y[i];                                          \
      z[i] = z[i] > 0 ? z[i] : 0;                                  \
    }                                                              \
  }

#ifdef __AVX__
INTRI8_FLOAT(jit::avx);
INTRI16_FLOAT(jit::avx);
INTRI_COMMON_FLOAT(jit::avx, kGT8LT16);
INTRI_COMMON_FLOAT(jit::avx, kGT16);
#endif
#ifdef __AVX2__
INTRI8_FLOAT(jit::avx2);
INTRI16_FLOAT(jit::avx2);
INTRI_COMMON_FLOAT(jit::avx2, kGT8LT16);
INTRI_COMMON_FLOAT(jit::avx2, kGT16);
#endif
#ifdef __AVX512F__
// TODO(TJ): refine avx512
INTRI8_FLOAT(jit::avx512f);
INTRI16_FLOAT(jit::avx512f);
INTRI_COMMON_FLOAT(jit::avx512f, kGT8LT16);
INTRI_COMMON_FLOAT(jit::avx512f, kGT16);
#endif
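
// The helper macros are file-local: once every enabled ISA (__AVX__,
// __AVX2__, __AVX512F__) has stamped out its specializations, drop them.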
#undef INTRI8_FLOAT
#undef INTRI16_FLOAT
#undef INTRI_COMMON_FLOAT

REGISTER_JITKERNEL(vmul, VMulKernel);
REGISTER_JITKERNEL(vadd, VAddKernel);
REGISTER_JITKERNEL(vscal, VScalKernel);
REGISTER_JITKERNEL(vaddb, VAddBiasKernel);
REGISTER_JITKERNEL(vrelu, VReluKernel);
REGISTER_JITKERNEL(vaddrelu, VAddReluKernel);
REGISTER_JITKERNEL(videntity, VIdentityKernel);
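
// A minimal usage sketch (an assumption: the KernelPool::Instance().Get<>
// interface is taken from jit_kernel.h and is not part of this diff):
//
//   int n = 256;  // vector length
//   const auto& ker =
//       jitkernel::KernelPool::Instance().Get<jitkernel::VAddReluKernel<float>>(n);
//   ker->Compute(x, y, z);  // z[i] = max(x[i] + y[i], 0) for i in [0, n)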

}  // namespace jitkernel