diff --git a/mindspore/lite/nnacl/fp32/activation.c b/mindspore/lite/nnacl/fp32/activation.c index 5305f96f4c..087b0f80b1 100644 --- a/mindspore/lite/nnacl/fp32/activation.c +++ b/mindspore/lite/nnacl/fp32/activation.c @@ -18,14 +18,31 @@ #include "nnacl/errorcode.h" int Fp32Relu(const float *src, int length, float *dst) { - for (int i = 0; i < length; ++i) { + int i = 0; +#ifdef ENABLE_ARM + float32x4_t zero_4 = vdupq_n_f32(0.0f); + for (; i < length - 4; i += 4) { + vst1q_f32(dst + i, vmaxq_f32(vld1q_f32(src + i), zero_4)); + } +#endif + for (; i < length; ++i) { dst[i] = src[i] > 0 ? src[i] : 0; } return NNACL_OK; } int Fp32Relu6(const float *src, int length, float *dst) { - for (int i = 0; i < length; ++i) { + int i = 0; +#ifdef ENABLE_ARM + float32x4_t zero_4 = vdupq_n_f32(0.0f); + float32x4_t six_4 = vdupq_n_f32(6.0f); + for (; i < length - 4; i += 4) { + float32x4_t dst_4 = vmaxq_f32(vld1q_f32(src + i), zero_4); + dst_4 = vminq_f32(dst_4, six_4); + vst1q_f32(dst + i, dst_4); + } +#endif + for (; i < length; ++i) { if (src[i] < 0) { dst[i] = 0; } else { @@ -36,7 +53,18 @@ int Fp32Relu6(const float *src, int length, float *dst) { } int LRelu(const float *src, int length, float *dst, float alpha) { - for (int i = 0; i < length; ++i) { + int i = 0; +#ifdef ENABLE_ARM64 + float32x4_t alpha_4 = vdupq_n_f32(alpha); + for (; i < length - 4; i += 4) { + float32x4_t src_4 = vld1q_f32(src + i); + float32x4_t mul_4 = vmulq_f32(src_4, alpha_4); + uint32x4_t flag = vclezq_f32(src_4); + float32x4_t dst_4 = vbslq_f32(flag, mul_4, src_4); + vst1q_f32(dst + i, dst_4); + } +#endif + for (; i < length; ++i) { dst[i] = src[i] > 0 ? src[i] : (src[i] * alpha); } return NNACL_OK;