@@ -20,10 +20,17 @@
 int Fp32Relu(const float *src, int length, float *dst) {
   int i = 0;
-#ifdef ENABLE_ARM
-  float32x4_t zero_4 = vdupq_n_f32(0.0f);
+#if defined(ENABLE_AVX)
+  MS_FLOAT32X8 zero_8 = MS_MOV256_F32(0.0f);
+  for (; i < length - 8; i += 8) {
+    MS_ST256_F32(dst + i, MS_MAX256_F32(MS_LD256_F32(src + i), zero_8));
+  }
+#endif
+
+#if defined(ENABLE_SSE) || defined(ENABLE_ARM)
+  MS_FLOAT32X4 zero = MS_MOVQ_F32(0.0f);
   for (; i < length - 4; i += 4) {
-    vst1q_f32(dst + i, vmaxq_f32(vld1q_f32(src + i), zero_4));
+    MS_STQ_F32(dst + i, MS_MAXQ_F32(MS_LDQ_F32(src + i), zero));
   }
 #endif
   for (; i < length; ++i) {
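
Note: each function now has three tiers — an 8-lane AVX pass, a 4-lane SSE/NEON pass, and a scalar tail. The strict bound `i < length - 8` means that when length is an exact multiple of 8 the final full block falls through to the narrower loops; that is correct, just slightly slower than the inclusive `<=` bound used in the Swish hunk below. For reference, a scalar sketch of what every tier must compute — not part of the patch, just the plain ReLU the tail loop implements:

static void relu_ref(const float *src, int length, float *dst) {
  /* ReLU: pass positives through, clamp negatives to zero. */
  for (int i = 0; i < length; ++i) {
    dst[i] = src[i] > 0 ? src[i] : 0;
  }
}
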
@@ -34,13 +41,24 @@ int Fp32Relu(const float *src, int length, float *dst) {
 int Fp32Relu6(const float *src, int length, float *dst) {
   int i = 0;
-#ifdef ENABLE_ARM
-  float32x4_t zero_4 = vdupq_n_f32(0.0f);
-  float32x4_t six_4 = vdupq_n_f32(6.0f);
+#if defined(ENABLE_AVX)
+  MS_FLOAT32X8 zero_8 = MS_MOV256_F32(0.0f);
+  MS_FLOAT32X8 six_8 = MS_MOV256_F32(6.0f);
+  for (; i < length - 8; i += 8) {
+    MS_FLOAT32X8 dst_tmp = MS_MAX256_F32(MS_LD256_F32(src + i), zero_8);
+    dst_tmp = MS_MIN256_F32(dst_tmp, six_8);
+    MS_ST256_F32(dst + i, dst_tmp);
+  }
+#endif
+
+#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
+  MS_FLOAT32X4 zero = MS_MOVQ_F32(0.0f);
+  MS_FLOAT32X4 six = MS_MOVQ_F32(6.0f);
   for (; i < length - 4; i += 4) {
-    float32x4_t dst_4 = vmaxq_f32(vld1q_f32(src + i), zero_4);
-    dst_4 = vminq_f32(dst_4, six_4);
-    vst1q_f32(dst + i, dst_4);
+    MS_FLOAT32X4 dst_tmp = MS_MAXQ_F32(MS_LDQ_F32(src + i), zero);
+    dst_tmp = MS_MINQ_F32(dst_tmp, six);
+    MS_STQ_F32(dst + i, dst_tmp);
   }
 #endif
   for (; i < length; ++i) {
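
Note: same tiering for Relu6, which additionally clamps the top at 6 (the variant used by MobileNet-style models). A scalar sketch of the operation — a hypothetical reference, not in the patch, assuming <math.h> for fminf/fmaxf:

#include <math.h>

/* Relu6: clamp x into the range [0, 6]. */
static float relu6_ref(float x) {
  return fminf(fmaxf(x, 0.0f), 6.0f);
}
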
@@ -55,14 +73,21 @@ int Fp32Relu6(const float *src, int length, float *dst) {
 int LRelu(const float *src, int length, float *dst, float alpha) {
   int i = 0;
-#ifdef ENABLE_ARM64
-  float32x4_t alpha_4 = vdupq_n_f32(alpha);
+#if defined(ENABLE_AVX)
+  for (; i < length - 8; i += 8) {
+    MS_FLOAT32X8 src_tmp = MS_LD256_F32(src + i);
+    MS_FLOAT32X8 mul_tmp = MS_MUL256_N_F32(src_tmp, alpha);
+    MS_FLOAT32X8 mask = MS_CMP256_PS(src_tmp, MS_MOV256_F32(0.0f), 30);
+    MS_ST256_F32(dst + i, MS_BLEND256_PS(mul_tmp, src_tmp, mask));
+  }
+#endif
+
+#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
   for (; i < length - 4; i += 4) {
-    float32x4_t src_4 = vld1q_f32(src + i);
-    float32x4_t mul_4 = vmulq_f32(src_4, alpha_4);
-    uint32x4_t flag = vclezq_f32(src_4);
-    float32x4_t dst_4 = vbslq_f32(flag, mul_4, src_4);
-    vst1q_f32(dst + i, dst_4);
+    MS_FLOAT32X4 src_tmp = MS_LDQ_F32(src + i);
+    MS_FLOAT32X4 mul_tmp = MS_MULQ_N_F32(src_tmp, alpha);
+    MS_FLOAT32X4 mask = MS_CMPGTQ_PS(src_tmp, MS_MOVQ_F32(0.0f));
+    MS_STQ_F32(dst + i, MS_BLENDQ_PS(mul_tmp, src_tmp, mask));
   }
 #endif
   for (; i < length; ++i) {
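
Note: on x86 the compare immediate 30 is _CMP_GT_OQ (ordered greater-than), so the mask is "src > 0", matching MS_CMPGTQ_PS in the 4-lane path; the blend then picks src_tmp where the input is positive and alpha * src_tmp elsewhere. The old NEON code expressed the same select with the condition inverted (vclezq_f32 + vbslq_f32). A scalar sketch of the select both paths implement:

/* Leaky ReLU: identity for positive inputs, scaled by alpha otherwise. */
static float lrelu_ref(float x, float alpha) {
  return x > 0 ? x : alpha * x;
}
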
@@ -73,11 +98,18 @@ int LRelu(const float *src, int length, float *dst, float alpha) {
 int Sigmoid(const float *src, int length, float *dst) {
   int i = 0;
-#ifdef ENABLE_ARM64
-  int count = (length / C4NUM) * C4NUM;
-  for (; i < count; i += C4NUM) {
-    simd_exp(vnegq_f32(vld1q_f32(src + i)), dst + i);
-    vst1q_f32(dst + i, vdivq_f32(vdupq_n_f32(1.0f), vaddq_f32(vdupq_n_f32(1.0f), vld1q_f32(dst + i))));
+#if defined(ENABLE_AVX)
+  for (; i < length - 8; i += 8) {
+    simd_exp_avx(-(MS_LD256_F32(src + i)), dst + i);
+    MS_ST256_F32(dst + i,
+                 MS_DIV256_F32(MS_MOV256_F32(1.0f), MS_ADD256_F32(MS_MOV256_F32(1.0f), MS_LD256_F32(dst + i))));
+  }
+#endif
+
+#if defined(ENABLE_ARM64) || defined(ENABLE_SSE)
+  for (; i < length - 4; i += 4) {
+    simd_exp(-(MS_LDQ_F32(src + i)), dst + i);
+    MS_STQ_F32(dst + i, MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_ADDQ_F32(MS_MOVQ_F32(1.0f), MS_LDQ_F32(dst + i))));
   }
 #endif
   for (; i < length; ++i) {
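
Note: both vector paths compute sigmoid in two stores — simd_exp / simd_exp_avx writes exp(-x) into dst, then the second store finishes 1 / (1 + exp(-x)) in place. The unary minus on the loaded vector assumes MS_FLOAT32X8 / MS_FLOAT32X4 are native vector types (GCC/Clang vector extensions). A scalar sketch of the same two steps, assuming <math.h>:

#include <math.h>

/* Sigmoid via the identity the vector code uses:
 * first e = exp(-x), then 1 / (1 + e). */
static float sigmoid_ref(float x) {
  float e = expf(-x);
  return 1.0f / (1.0f + e);
}
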
@@ -102,26 +134,40 @@ float TanhOpt(float src) {
 int Tanh(const float *src, int length, float *dst) {
   int i = 0;
-#ifdef ENABLE_ARM64
-  static float32x4_t paramv[] = {{378.0f, 378.0f, 378.0f, 378.0f},
-                                 {17325.0f, 17325.0f, 17325.0f, 17325.0f},
-                                 {135135.0f, 135135.0f, 135135.0f, 135135.0f},
-                                 {28.0f, 28.0f, 28.0f, 28.0f},
-                                 {3150.0f, 3150.0f, 3150.0f, 3150.0f},
-                                 {62370.0f, 62370.0f, 62370.0f, 62370.0f}};
-  float32x4_t neg_one = {-1.0f, -1.0f, -1.0f, -1.0f};
-  float32x4_t pos_one = {1.0f, 1.0f, 1.0f, 1.0f};
-  int count = (length / C4NUM) * C4NUM;
-  for (; i < count; i += C4NUM) {
-    float32x4_t input = vld1q_f32(src + i);
-    float32x4_t square = vmulq_f32(input, input);
-    float32x4_t a = vmulq_f32(
-      vaddq_f32(vmulq_f32(vaddq_f32(vmulq_f32(vaddq_f32(square, paramv[0]), square), paramv[1]), square), paramv[2]),
-      input);
-    float32x4_t b = vaddq_f32(
-      vmulq_f32(vaddq_f32(vmulq_f32(vaddq_f32(vmulq_f32(paramv[3], square), paramv[4]), square), paramv[5]), square),
-      paramv[2]);
-    vst1q_f32(dst + i, vminq_f32(vmaxq_f32(vdivq_f32(a, b), neg_one), pos_one));
+#if defined(ENABLE_ARM) || defined(ENABLE_SSE) || defined(ENABLE_AVX)
+  const int cnt = 6;
+  float data[] = {378.0f, 17325.0f, 135135.0f, 28.0f, 3150.0f, 62370.0f};
+#endif
+
+#if defined(ENABLE_AVX)
+  MS_FLOAT32X8 neg_one_8 = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f};
+  MS_FLOAT32X8 pos_one_8 = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
+  MS_FLOAT32X8 param256[6];
+  for (int j = 0; j < cnt; ++j) {
+    param256[j] = MS_MOV256_F32(data[j]);
+  }
+  for (; i < length - 8; i += 8) {
+    MS_FLOAT32X8 input = MS_LD256_F32(src + i);
+    MS_FLOAT32X8 square = input * input;
+    MS_FLOAT32X8 a = (((square + param256[0]) * square + param256[1]) * square + param256[2]) * input;
+    MS_FLOAT32X8 b = ((param256[3] * square + param256[4]) * square + param256[5]) * square + param256[2];
+    MS_ST256_F32(dst + i, MS_MIN256_F32(MS_MAX256_F32(a / b, neg_one_8), pos_one_8));
+  }
+#endif
+
+#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
+  MS_FLOAT32X4 param[6];
+  MS_FLOAT32X4 neg_one = {-1.0f, -1.0f, -1.0f, -1.0f};
+  MS_FLOAT32X4 pos_one = {1.0f, 1.0f, 1.0f, 1.0f};
+  for (int j = 0; j < cnt; ++j) {
+    param[j] = MS_MOVQ_F32(data[j]);
+  }
+  for (; i < length - 4; i += 4) {
+    MS_FLOAT32X4 input = MS_LDQ_F32(src + i);
+    MS_FLOAT32X4 square = input * input;
+    MS_FLOAT32X4 a = (((square + param[0]) * square + param[1]) * square + param[2]) * input;
+    MS_FLOAT32X4 b = ((param[3] * square + param[4]) * square + param[5]) * square + param[2];
+    MS_STQ_F32(dst + i, MS_MINQ_F32(MS_MAXQ_F32(a / b, neg_one), pos_one));
   }
 #endif
   for (; i < length; ++i) {
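
Note: the constants 378, 17325, 135135, 28, 3150, 62370 are the coefficients of the [7/6] Padé approximant of tanh: with s = x², the numerator is x(s³ + 378s² + 17325s + 135135) and the denominator is 28s³ + 3150s² + 62370s + 135135, and the ratio is clamped to [-1, 1] because it overshoots for large |x|. Both widths now splat these from the shared scalar data[] table instead of the old static float32x4_t array. A scalar sketch of the approximation, for checking rather than as part of the patch:

/* Rational tanh approximation used by the vector paths: with s = x*x,
 * evaluate numerator and denominator by Horner's rule, then clamp the
 * ratio to [-1, 1]. */
static float tanh_approx_ref(float x) {
  float s = x * x;
  float a = (((s + 378.0f) * s + 17325.0f) * s + 135135.0f) * x;
  float b = ((28.0f * s + 3150.0f) * s + 62370.0f) * s + 135135.0f;
  float r = a / b;
  return r < -1.0f ? -1.0f : (r > 1.0f ? 1.0f : r);
}
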
@@ -142,12 +188,21 @@ int Swish(const float *src, int length, float *dst) {
     return NNACL_ERR;
   }
   int index = 0;
-#ifdef ENABLE_NEON
-  for (; index <= length - C4NUM; index += C4NUM) {
-    float32x4_t src_value = vld1q_f32(src + index);
-    float32x4_t sigmoid_value = vld1q_f32(dst + index);
-    float32x4_t result = vmulq_f32(src_value, sigmoid_value);
-    vst1q_f32(dst + index, result);
+#if defined(ENABLE_AVX)
+  for (; index <= length - 8; index += 8) {
+    MS_FLOAT32X8 src_value = MS_LD256_F32(src + index);
+    MS_FLOAT32X8 sigmoid_value = MS_LD256_F32(dst + index);
+    MS_FLOAT32X8 result = MS_MUL256_F32(src_value, sigmoid_value);
+    MS_ST256_F32(dst + index, result);
+  }
+#endif
+
+#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
+  for (; index <= length - 4; index += 4) {
+    MS_FLOAT32X4 src_value = MS_LDQ_F32(src + index);
+    MS_FLOAT32X4 sigmoid_value = MS_LDQ_F32(dst + index);
+    MS_FLOAT32X4 result = MS_MULQ_F32(src_value, sigmoid_value);
+    MS_STQ_F32(dst + index, result);
   }
 #endif
   for (; index < length; index++) {
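
Note: these loops read sigmoid_value back out of dst, so dst must already hold sigmoid(src) when they run — the variable name and the error-return context above point to a preceding Sigmoid call filling it. Also note the inclusive `index <= length - 8` bound here, unlike the strict bounds in the earlier hunks, so the wide loop does consume the final full block. A scalar sketch of the composed operation, assuming <math.h>:

#include <math.h>

/* Swish: x * sigmoid(x), folded into one expression. */
static float swish_ref(float x) {
  return x / (1.0f + expf(-x));
}
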