From 5ad63a48c546716580c295ac46e41a51c24a2299 Mon Sep 17 00:00:00 2001 From: lzk Date: Fri, 19 Feb 2021 23:57:06 -0800 Subject: [PATCH] prelu simd optimize --- mindspore/lite/nnacl/fp32/activation_fp32.c | 8 +- mindspore/lite/nnacl/fp32/add_fp32.c | 4 +- mindspore/lite/nnacl/fp32/mul_fp32.c | 4 +- mindspore/lite/nnacl/fp32/prelu_fp32.c | 167 +++++------------- mindspore/lite/nnacl/fp32/prelu_fp32.h | 28 +++ .../nnacl/intrinsics/ms_simd_instructions.h | 54 +++++- .../src/runtime/kernel/arm/fp32/prelu_fp32.cc | 2 +- 7 files changed, 131 insertions(+), 136 deletions(-) diff --git a/mindspore/lite/nnacl/fp32/activation_fp32.c b/mindspore/lite/nnacl/fp32/activation_fp32.c index 6c674a8150..817a20d7dc 100644 --- a/mindspore/lite/nnacl/fp32/activation_fp32.c +++ b/mindspore/lite/nnacl/fp32/activation_fp32.c @@ -77,8 +77,8 @@ int LRelu(const float *src, int length, float *dst, float alpha) { for (; i < length - 8; i += 8) { MS_FLOAT32X8 src_tmp = MS_LD256_F32(src + i); MS_FLOAT32X8 mul_tmp = MS_MUL256_N_F32(src_tmp, alpha); - MS_FLOAT32X8 mask = MS_CMP256_PS(src_tmp, MS_MOV256_F32(0.0f), 30); - MS_ST256_F32(dst + i, MS_BLEND256_PS(mul_tmp, src_tmp, mask)); + MS_FLOAT32X8 mask = MS_CMP256_F32(src_tmp, MS_MOV256_F32(0.0f), 30); + MS_ST256_F32(dst + i, MS_BLEND256_F32(mul_tmp, src_tmp, mask)); } #endif @@ -86,8 +86,8 @@ int LRelu(const float *src, int length, float *dst, float alpha) { for (; i < length - 4; i += 4) { MS_FLOAT32X4 src_tmp = MS_LDQ_F32(src + i); MS_FLOAT32X4 mul_tmp = MS_MULQ_N_F32(src_tmp, alpha); - MS_FLOAT32X4 mask = MS_CMPGTQ_PS(src_tmp, MS_MOVQ_F32(0.0f)); - MS_STQ_F32(dst + i, MS_BLENDQ_PS(mul_tmp, src_tmp, mask)); + MS_FLOAT32X4 mask = MS_CMPGTQ_F32(src_tmp, MS_MOVQ_F32(0.0f)); + MS_STQ_F32(dst + i, MS_BLENDQ_F32(mul_tmp, src_tmp, mask)); } #endif for (; i < length; ++i) { diff --git a/mindspore/lite/nnacl/fp32/add_fp32.c b/mindspore/lite/nnacl/fp32/add_fp32.c index b14ee6e30b..0b744e7d1a 100644 --- a/mindspore/lite/nnacl/fp32/add_fp32.c +++ b/mindspore/lite/nnacl/fp32/add_fp32.c @@ -262,7 +262,7 @@ int ElementAddRelu(const float *in0, const float *in1, float *out, int size) { MS_FLOAT32X8 vin0 = MS_LD256_F32(in0 + index); MS_FLOAT32X8 vin1 = MS_LD256_F32(in1 + index); MS_FLOAT32X8 vout = MS_ADD256_F32(vin0, vin1); - vout = MS_BLEND256_PS(zeros_8, vout, MS_CMP256_PS(vout, zeros_8, 30)); + vout = MS_BLEND256_F32(zeros_8, vout, MS_CMP256_F32(vout, zeros_8, 30)); MS_ST256_F32(out + index, vout); } #endif @@ -272,7 +272,7 @@ int ElementAddRelu(const float *in0, const float *in1, float *out, int size) { MS_FLOAT32X4 vin0 = MS_LDQ_F32(in0 + index); MS_FLOAT32X4 vin1 = MS_LDQ_F32(in1 + index); MS_FLOAT32X4 vout = MS_ADDQ_F32(vin0, vin1); - vout = MS_BLENDQ_PS(zeros, vout, MS_CMPGTQ_PS(vout, zeros)); + vout = MS_BLENDQ_F32(zeros, vout, MS_CMPGTQ_F32(vout, zeros)); MS_STQ_F32(out + index, vout); } #endif diff --git a/mindspore/lite/nnacl/fp32/mul_fp32.c b/mindspore/lite/nnacl/fp32/mul_fp32.c index 9ccb86c72f..8a504de094 100644 --- a/mindspore/lite/nnacl/fp32/mul_fp32.c +++ b/mindspore/lite/nnacl/fp32/mul_fp32.c @@ -54,7 +54,7 @@ int ElementMulRelu(const float *in0, const float *in1, float *out, int size) { MS_FLOAT32X8 vin0 = MS_LD256_F32(in0 + index); MS_FLOAT32X8 vin1 = MS_LD256_F32(in1 + index); MS_FLOAT32X8 vout = MS_MUL256_F32(vin0, vin1); - vout = MS_BLEND256_PS(zeros_8, vout, MS_CMP256_PS(vout, zeros_8, 30)); + vout = MS_BLEND256_F32(zeros_8, vout, MS_CMP256_F32(vout, zeros_8, 30)); MS_ST256_F32(out + index, vout); } #endif @@ -64,7 +64,7 @@ int ElementMulRelu(const float *in0, const float *in1, float *out, int size) { MS_FLOAT32X4 vin0 = MS_LDQ_F32(in0 + index); MS_FLOAT32X4 vin1 = MS_LDQ_F32(in1 + index); MS_FLOAT32X4 vout = MS_MULQ_F32(vin0, vin1); - vout = MS_BLENDQ_PS(zeros, vout, MS_CMPGTQ_PS(vout, zeros)); + vout = MS_BLENDQ_F32(zeros, vout, MS_CMPGTQ_F32(vout, zeros)); MS_STQ_F32(out + index, vout); } #endif diff --git a/mindspore/lite/nnacl/fp32/prelu_fp32.c b/mindspore/lite/nnacl/fp32/prelu_fp32.c index afe513bafd..65b1b1fd6c 100644 --- a/mindspore/lite/nnacl/fp32/prelu_fp32.c +++ b/mindspore/lite/nnacl/fp32/prelu_fp32.c @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,14 +14,8 @@ * limitations under the License. */ #include "nnacl/fp32/prelu_fp32.h" -#ifdef ENABLE_NEON -#include -#endif void PRelu(float *input, float *output, const PReluParameter *prelu_param_, int plane) { -#ifdef ENABLE_ARM - float32x4_t zero_value = vdupq_n_f32(0); -#endif int plane_tile = plane / TILE_NUM * TILE_NUM; int channel_num = prelu_param_->channel_num_; int plane_index = 0; @@ -29,37 +23,29 @@ void PRelu(float *input, float *output, const PReluParameter *prelu_param_, int float *in_plane_ptr = input + plane_index * channel_num; float *out_plane_ptr = output + plane_index * channel_num; int channel_index = 0; -#ifdef ENABLE_ARM +#if defined(ENABLE_AVX) + MS_FLOAT32X8 zero_value_8 = MS_MOV256_F32(0.0f); + MS_FLOAT32X8 one_value_8 = MS_MOV256_F32(1.0f); + float *negetive_slope_value_8 = prelu_param_->slope_; + int div_channel_c8 = prelu_param_->channel_num_ / C8NUM * C8NUM; + for (; channel_index < div_channel_c8; channel_index += C8NUM) { + MS_FLOAT32X8 slope_value_8 = MS_LD256_F32(negetive_slope_value_8 + channel_index); + LOAD256X8_F32(src, in_plane_ptr + channel_index, channel_num) + PRELU_CALCULATE_256X8(dst, src) + STORE256X8_F32(out_plane_ptr + channel_index, channel_num, dst) + } +#endif + // note: First AVX processing, then SSE processing on X86 platform +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 zero_value = MS_MOVQ_F32(0.0f); + MS_FLOAT32X4 one_value = MS_MOVQ_F32(1.0f); float *negetive_slope_value = prelu_param_->slope_; int div_channel = prelu_param_->channel_num_ / C4NUM * C4NUM; for (; channel_index < div_channel; channel_index += C4NUM) { - float32x4_t slope_value = vld1q_f32(negetive_slope_value + channel_index); - float32x4_t v1 = vld1q_f32(in_plane_ptr + channel_index + 0 * channel_num); - float32x4_t v2 = vld1q_f32(in_plane_ptr + channel_index + 1 * channel_num); - float32x4_t v3 = vld1q_f32(in_plane_ptr + channel_index + 2 * channel_num); - float32x4_t v4 = vld1q_f32(in_plane_ptr + channel_index + 3 * channel_num); - float32x4_t v5 = vld1q_f32(in_plane_ptr + channel_index + 4 * channel_num); - float32x4_t v6 = vld1q_f32(in_plane_ptr + channel_index + 5 * channel_num); - float32x4_t v7 = vld1q_f32(in_plane_ptr + channel_index + 6 * channel_num); - float32x4_t v8 = vld1q_f32(in_plane_ptr + channel_index + 7 * channel_num); - - float32x4_t r1 = vaddq_f32(vmulq_f32(vminq_f32(v1, zero_value), slope_value), vmaxq_f32(v1, zero_value)); - float32x4_t r2 = vaddq_f32(vmulq_f32(vminq_f32(v2, zero_value), slope_value), vmaxq_f32(v2, zero_value)); - float32x4_t r3 = vaddq_f32(vmulq_f32(vminq_f32(v3, zero_value), slope_value), vmaxq_f32(v3, zero_value)); - float32x4_t r4 = vaddq_f32(vmulq_f32(vminq_f32(v4, zero_value), slope_value), vmaxq_f32(v4, zero_value)); - float32x4_t r5 = vaddq_f32(vmulq_f32(vminq_f32(v5, zero_value), slope_value), vmaxq_f32(v5, zero_value)); - float32x4_t r6 = vaddq_f32(vmulq_f32(vminq_f32(v6, zero_value), slope_value), vmaxq_f32(v6, zero_value)); - float32x4_t r7 = vaddq_f32(vmulq_f32(vminq_f32(v7, zero_value), slope_value), vmaxq_f32(v7, zero_value)); - float32x4_t r8 = vaddq_f32(vmulq_f32(vminq_f32(v8, zero_value), slope_value), vmaxq_f32(v8, zero_value)); - - vst1q_f32(out_plane_ptr + channel_index + 0 * channel_num, r1); - vst1q_f32(out_plane_ptr + channel_index + 1 * channel_num, r2); - vst1q_f32(out_plane_ptr + channel_index + 2 * channel_num, r3); - vst1q_f32(out_plane_ptr + channel_index + 3 * channel_num, r4); - vst1q_f32(out_plane_ptr + channel_index + 4 * channel_num, r5); - vst1q_f32(out_plane_ptr + channel_index + 5 * channel_num, r6); - vst1q_f32(out_plane_ptr + channel_index + 6 * channel_num, r7); - vst1q_f32(out_plane_ptr + channel_index + 7 * channel_num, r8); + MS_FLOAT32X4 slope_value = MS_LDQ_F32(negetive_slope_value + channel_index); + LOAD128X8_F32(src, in_plane_ptr + channel_index, channel_num) + PRELU_CALCULATE_128X8(dst, src) + STORE128X8_F32(out_plane_ptr + channel_index, channel_num, dst) } #endif for (; channel_index < channel_num; channel_index++) { @@ -88,100 +74,41 @@ void PRelu(float *input, float *output, const PReluParameter *prelu_param_, int void PReluShareChannel(float *input, float *output, const PReluParameter *prelu_param_, int task_id) { for (int j = task_id; j < prelu_param_->tile_block_; j += prelu_param_->op_parameter_.thread_num_) { int cal_index; -#ifdef ENABLE_NEON - float32x4_t slope_value = vdupq_n_f32(prelu_param_->slope_[0]); - float32x4_t zero_value = vdupq_n_f32(0); -#endif -#ifdef ENABLE_ARM64 +#if defined(ENABLE_ARM64) || defined(ENABLE_AVX) cal_index = j * 64; - -#elif ENABLE_ARM32 - cal_index = j * 32; #else cal_index = j * 32; - const int cal_per_time = 32; #endif + float *input_ptr = input + cal_index; float *output_ptr = input + cal_index; -#ifdef ENABLE_ARM64 - float32x4_t v1 = vld1q_f32(input_ptr); - float32x4_t v2 = vld1q_f32(input_ptr + 4); - float32x4_t v3 = vld1q_f32(input_ptr + 8); - float32x4_t v4 = vld1q_f32(input_ptr + 12); - float32x4_t v5 = vld1q_f32(input_ptr + 16); - float32x4_t v6 = vld1q_f32(input_ptr + 20); - float32x4_t v7 = vld1q_f32(input_ptr + 24); - float32x4_t v8 = vld1q_f32(input_ptr + 28); - float32x4_t v9 = vld1q_f32(input_ptr + 32); - float32x4_t v10 = vld1q_f32(input_ptr + 36); - float32x4_t v11 = vld1q_f32(input_ptr + 40); - float32x4_t v12 = vld1q_f32(input_ptr + 44); - float32x4_t v13 = vld1q_f32(input_ptr + 48); - float32x4_t v14 = vld1q_f32(input_ptr + 52); - float32x4_t v15 = vld1q_f32(input_ptr + 56); - float32x4_t v16 = vld1q_f32(input_ptr + 60); - - float32x4_t t1 = vaddq_f32(vmulq_f32(vminq_f32(v1, zero_value), slope_value), vmaxq_f32(v1, zero_value)); - float32x4_t t2 = vaddq_f32(vmulq_f32(vminq_f32(v2, zero_value), slope_value), vmaxq_f32(v2, zero_value)); - float32x4_t t3 = vaddq_f32(vmulq_f32(vminq_f32(v3, zero_value), slope_value), vmaxq_f32(v3, zero_value)); - float32x4_t t4 = vaddq_f32(vmulq_f32(vminq_f32(v4, zero_value), slope_value), vmaxq_f32(v4, zero_value)); - float32x4_t t5 = vaddq_f32(vmulq_f32(vminq_f32(v5, zero_value), slope_value), vmaxq_f32(v5, zero_value)); - float32x4_t t6 = vaddq_f32(vmulq_f32(vminq_f32(v6, zero_value), slope_value), vmaxq_f32(v6, zero_value)); - float32x4_t t7 = vaddq_f32(vmulq_f32(vminq_f32(v7, zero_value), slope_value), vmaxq_f32(v7, zero_value)); - float32x4_t t8 = vaddq_f32(vmulq_f32(vminq_f32(v8, zero_value), slope_value), vmaxq_f32(v8, zero_value)); - float32x4_t t9 = vaddq_f32(vmulq_f32(vminq_f32(v9, zero_value), slope_value), vmaxq_f32(v9, zero_value)); - float32x4_t t10 = vaddq_f32(vmulq_f32(vminq_f32(v10, zero_value), slope_value), vmaxq_f32(v10, zero_value)); - float32x4_t t11 = vaddq_f32(vmulq_f32(vminq_f32(v11, zero_value), slope_value), vmaxq_f32(v11, zero_value)); - float32x4_t t12 = vaddq_f32(vmulq_f32(vminq_f32(v12, zero_value), slope_value), vmaxq_f32(v12, zero_value)); - float32x4_t t13 = vaddq_f32(vmulq_f32(vminq_f32(v13, zero_value), slope_value), vmaxq_f32(v13, zero_value)); - float32x4_t t14 = vaddq_f32(vmulq_f32(vminq_f32(v14, zero_value), slope_value), vmaxq_f32(v14, zero_value)); - float32x4_t t15 = vaddq_f32(vmulq_f32(vminq_f32(v15, zero_value), slope_value), vmaxq_f32(v15, zero_value)); - float32x4_t t16 = vaddq_f32(vmulq_f32(vminq_f32(v16, zero_value), slope_value), vmaxq_f32(v16, zero_value)); +#if defined(ENABLE_AVX) + MS_FLOAT32X8 zero_value_8 = MS_MOV256_F32(0); + MS_FLOAT32X8 one_value_8 = MS_MOV256_F32(1.0f); + MS_FLOAT32X8 slope_value_8 = MS_MOV256_F32(prelu_param_->slope_[0]); + LOAD256X8_F32(src, input_ptr, 8) + PRELU_CALCULATE_256X8(dst, src) + STORE256X8_F32(output_ptr, 8, dst) +#elif defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX)) + MS_FLOAT32X4 zero_value = MS_MOVQ_F32(0); + MS_FLOAT32X4 one_value = MS_MOVQ_F32(1.0f); + MS_FLOAT32X4 slope_value = MS_MOVQ_F32(prelu_param_->slope_[0]); - vst1q_f32(output_ptr, t1); - vst1q_f32(output_ptr + 4, t2); - vst1q_f32(output_ptr + 8, t3); - vst1q_f32(output_ptr + 12, t4); - vst1q_f32(output_ptr + 16, t5); - vst1q_f32(output_ptr + 20, t6); - vst1q_f32(output_ptr + 24, t7); - vst1q_f32(output_ptr + 28, t8); - vst1q_f32(output_ptr + 32, t9); - vst1q_f32(output_ptr + 36, t10); - vst1q_f32(output_ptr + 40, t11); - vst1q_f32(output_ptr + 44, t12); - vst1q_f32(output_ptr + 48, t13); - vst1q_f32(output_ptr + 52, t14); - vst1q_f32(output_ptr + 56, t15); - vst1q_f32(output_ptr + 60, t16); -#elif ENABLE_ARM32 - float32x4_t v1 = vld1q_f32(input_ptr); - float32x4_t v2 = vld1q_f32(input_ptr + 4); - float32x4_t v3 = vld1q_f32(input_ptr + 8); - float32x4_t v4 = vld1q_f32(input_ptr + 12); - float32x4_t v5 = vld1q_f32(input_ptr + 16); - float32x4_t v6 = vld1q_f32(input_ptr + 20); - float32x4_t v7 = vld1q_f32(input_ptr + 24); - float32x4_t v8 = vld1q_f32(input_ptr + 28); - - float32x4_t t1 = vaddq_f32(vmulq_f32(vminq_f32(v1, zero_value), slope_value), vmaxq_f32(v1, zero_value)); - float32x4_t t2 = vaddq_f32(vmulq_f32(vminq_f32(v2, zero_value), slope_value), vmaxq_f32(v2, zero_value)); - float32x4_t t3 = vaddq_f32(vmulq_f32(vminq_f32(v3, zero_value), slope_value), vmaxq_f32(v3, zero_value)); - float32x4_t t4 = vaddq_f32(vmulq_f32(vminq_f32(v4, zero_value), slope_value), vmaxq_f32(v4, zero_value)); - float32x4_t t5 = vaddq_f32(vmulq_f32(vminq_f32(v5, zero_value), slope_value), vmaxq_f32(v5, zero_value)); - float32x4_t t6 = vaddq_f32(vmulq_f32(vminq_f32(v6, zero_value), slope_value), vmaxq_f32(v6, zero_value)); - float32x4_t t7 = vaddq_f32(vmulq_f32(vminq_f32(v7, zero_value), slope_value), vmaxq_f32(v7, zero_value)); - float32x4_t t8 = vaddq_f32(vmulq_f32(vminq_f32(v8, zero_value), slope_value), vmaxq_f32(v8, zero_value)); + LOAD128X8_F32(src, input_ptr, 4) +#ifdef ENABLE_ARM64 + LOAD128X8_F32(src1, input_ptr + 32, 4) +#endif + PRELU_CALCULATE_128X8(dst, src) +#ifdef ENABLE_ARM64 + PRELU_CALCULATE_128X8(dst1, src1) +#endif + STORE128X8_F32(output_ptr, 4, dst) +#ifdef ENABLE_ARM64 + STORE128X8_F32(output_ptr + 32, 4, dst1) +#endif - vst1q_f32(output_ptr, t1); - vst1q_f32(output_ptr + 4, t2); - vst1q_f32(output_ptr + 8, t3); - vst1q_f32(output_ptr + 12, t4); - vst1q_f32(output_ptr + 16, t5); - vst1q_f32(output_ptr + 20, t6); - vst1q_f32(output_ptr + 24, t7); - vst1q_f32(output_ptr + 28, t8); #else + const int cal_per_time = 32; for (int i = 0; i < cal_per_time; ++i) { float data = input_ptr[i]; output_ptr[i] = (data < 0 ? data : 0) * prelu_param_->slope_[0] + (data > 0 ? data : 0); diff --git a/mindspore/lite/nnacl/fp32/prelu_fp32.h b/mindspore/lite/nnacl/fp32/prelu_fp32.h index c2977710fd..f9d16a2667 100644 --- a/mindspore/lite/nnacl/fp32/prelu_fp32.h +++ b/mindspore/lite/nnacl/fp32/prelu_fp32.h @@ -29,4 +29,32 @@ void PReluShareChannel(float *input, float *output, const PReluParameter *prelu_ } #endif +#define PRELU_CALCULATE_256X8(dst, src) \ + MS_FLOAT32X8 dst##1 = \ + MS_MUL256_F32(src##1, MS_BLEND256_F32(slope_value_8, one_value_8, MS_CMP256_F32(src##1, zero_value_8, 30))); \ + MS_FLOAT32X8 dst##2 = \ + MS_MUL256_F32(src##2, MS_BLEND256_F32(slope_value_8, one_value_8, MS_CMP256_F32(src##2, zero_value_8, 30))); \ + MS_FLOAT32X8 dst##3 = \ + MS_MUL256_F32(src##3, MS_BLEND256_F32(slope_value_8, one_value_8, MS_CMP256_F32(src##3, zero_value_8, 30))); \ + MS_FLOAT32X8 dst##4 = \ + MS_MUL256_F32(src##4, MS_BLEND256_F32(slope_value_8, one_value_8, MS_CMP256_F32(src##4, zero_value_8, 30))); \ + MS_FLOAT32X8 dst##5 = \ + MS_MUL256_F32(src##5, MS_BLEND256_F32(slope_value_8, one_value_8, MS_CMP256_F32(src##5, zero_value_8, 30))); \ + MS_FLOAT32X8 dst##6 = \ + MS_MUL256_F32(src##6, MS_BLEND256_F32(slope_value_8, one_value_8, MS_CMP256_F32(src##6, zero_value_8, 30))); \ + MS_FLOAT32X8 dst##7 = \ + MS_MUL256_F32(src##7, MS_BLEND256_F32(slope_value_8, one_value_8, MS_CMP256_F32(src##7, zero_value_8, 30))); \ + MS_FLOAT32X8 dst##8 = \ + MS_MUL256_F32(src##8, MS_BLEND256_F32(slope_value_8, one_value_8, MS_CMP256_F32(src##8, zero_value_8, 30))); + +#define PRELU_CALCULATE_128X8(dst, src) \ + MS_FLOAT32X4 dst##1 = MS_MULQ_F32(src##1, MS_BLENDQ_F32(slope_value, one_value, MS_CMPGTQ_F32(src##1, zero_value))); \ + MS_FLOAT32X4 dst##2 = MS_MULQ_F32(src##2, MS_BLENDQ_F32(slope_value, one_value, MS_CMPGTQ_F32(src##2, zero_value))); \ + MS_FLOAT32X4 dst##3 = MS_MULQ_F32(src##3, MS_BLENDQ_F32(slope_value, one_value, MS_CMPGTQ_F32(src##3, zero_value))); \ + MS_FLOAT32X4 dst##4 = MS_MULQ_F32(src##4, MS_BLENDQ_F32(slope_value, one_value, MS_CMPGTQ_F32(src##4, zero_value))); \ + MS_FLOAT32X4 dst##5 = MS_MULQ_F32(src##5, MS_BLENDQ_F32(slope_value, one_value, MS_CMPGTQ_F32(src##5, zero_value))); \ + MS_FLOAT32X4 dst##6 = MS_MULQ_F32(src##6, MS_BLENDQ_F32(slope_value, one_value, MS_CMPGTQ_F32(src##6, zero_value))); \ + MS_FLOAT32X4 dst##7 = MS_MULQ_F32(src##7, MS_BLENDQ_F32(slope_value, one_value, MS_CMPGTQ_F32(src##7, zero_value))); \ + MS_FLOAT32X4 dst##8 = MS_MULQ_F32(src##8, MS_BLENDQ_F32(slope_value, one_value, MS_CMPGTQ_F32(src##8, zero_value))); + #endif // MINDSPORE_LITE_NNACL_FP32_PRELU_H_ diff --git a/mindspore/lite/nnacl/intrinsics/ms_simd_instructions.h b/mindspore/lite/nnacl/intrinsics/ms_simd_instructions.h index 7a21730213..8830d6146a 100644 --- a/mindspore/lite/nnacl/intrinsics/ms_simd_instructions.h +++ b/mindspore/lite/nnacl/intrinsics/ms_simd_instructions.h @@ -60,10 +60,10 @@ inline static float32x4_t vrecp(float32x4_t v) { #define MS_SLLIQ_EPI32(src1, src2) vshlq_s32(src1, vmovq_n_s32(src2)) #define MS_CVTQPS_EPI32(src) vcvtq_s32_f32(src) #define MS_CVTQEPI32_PS(src) vcvtq_f32_s32(src) -#define MS_CMPGTQ_PS(src1, src2) vcgtq_f32(src1, src2) +#define MS_CMPGTQ_F32(src1, src2) vcgtq_f32(src1, src2) #define MS_CMPGTQ_EPI32(src1, src2) vcgtq_s32(src1, src2) -// Note: Compared with X86, the vbslq_f32 parameters are the opposite with _mm_blendv_ps -#define MS_BLENDQ_PS(src1, src2, src3) vbslq_f32(src3, src2, src1) +// Note: Compared with X86, the vbslq_f32 parameters are the opposite with _mm_blendv_f32 +#define MS_BLENDQ_F32(src1, src2, src3) vbslq_f32(src3, src2, src1) #define MS_BLENDQ_EPI32(src1, src2, src3) vbslq_s32(src3, src2, src1) #endif @@ -94,9 +94,9 @@ inline static float32x4_t vrecp(float32x4_t v) { #define MS_SLLI256_EPI32(src1, src2) _mm256_slli_epi32(src1, src2) #define MS_CVT256PS_EPI32(src) _mm256_cvttps_epi32(src) #define MS_CVT256EPI32_PS(src) _mm256_cvtepi32_ps(src) // truncate float to int -#define MS_CMP256_PS(src1, src2, src3) _mm256_cmp_ps(src1, src2, src3) +#define MS_CMP256_F32(src1, src2, src3) _mm256_cmp_ps(src1, src2, src3) #define MS_CMPGT256_EPI32(src1, src2) _mm256_cmpgt_epi32(src1, src2) -#define MS_BLEND256_PS(src1, src2, src3) _mm256_blendv_ps(src1, src2, src3) +#define MS_BLEND256_F32(src1, src2, src3) _mm256_blendv_ps(src1, src2, src3) #define MS_BLEND256_EPI32(src1, src2, src3) _mm256_blendv_epi8(src1, src2, src3) #endif @@ -127,10 +127,50 @@ inline static float32x4_t vrecp(float32x4_t v) { #define MS_SLLIQ_EPI32(src1, src2) _mm_slli_epi32(src1, src2) #define MS_CVTQPS_EPI32(src) _mm_cvttps_epi32(src) // truncate float to int #define MS_CVTQEPI32_PS(src) _mm_cvtepi32_ps(src) -#define MS_CMPGTQ_PS(src1, src2) _mm_cmpgt_ps(src1, src2) +#define MS_CMPGTQ_F32(src1, src2) _mm_cmpgt_ps(src1, src2) #define MS_CMPGTQ_EPI32(src1, src2) _mm_cmpgt_epi32(src1, src2) -#define MS_BLENDQ_PS(src1, src2, src3) _mm_blendv_ps(src1, src2, src3) +#define MS_BLENDQ_F32(src1, src2, src3) _mm_blendv_ps(src1, src2, src3) #define MS_BLENDQ_EPI32(src1, src2, src3) _mm_blendv_epi8(src1, src2, src3) #endif +#define LOAD256X8_F32(src, input_ptr, num) \ + MS_FLOAT32X8 src##1 = MS_LD256_F32(input_ptr + 0 * num); \ + MS_FLOAT32X8 src##2 = MS_LD256_F32(input_ptr + 1 * num); \ + MS_FLOAT32X8 src##3 = MS_LD256_F32(input_ptr + 2 * num); \ + MS_FLOAT32X8 src##4 = MS_LD256_F32(input_ptr + 3 * num); \ + MS_FLOAT32X8 src##5 = MS_LD256_F32(input_ptr + 4 * num); \ + MS_FLOAT32X8 src##6 = MS_LD256_F32(input_ptr + 5 * num); \ + MS_FLOAT32X8 src##7 = MS_LD256_F32(input_ptr + 6 * num); \ + MS_FLOAT32X8 src##8 = MS_LD256_F32(input_ptr + 7 * num); + +#define STORE256X8_F32(output_ptr, num, dst) \ + MS_ST256_F32(output_ptr + 0 * num, dst##1); \ + MS_ST256_F32(output_ptr + 1 * num, dst##2); \ + MS_ST256_F32(output_ptr + 2 * num, dst##3); \ + MS_ST256_F32(output_ptr + 3 * num, dst##4); \ + MS_ST256_F32(output_ptr + 4 * num, dst##5); \ + MS_ST256_F32(output_ptr + 5 * num, dst##6); \ + MS_ST256_F32(output_ptr + 6 * num, dst##7); \ + MS_ST256_F32(output_ptr + 7 * num, dst##8); + +#define LOAD128X8_F32(src, input_ptr, num) \ + MS_FLOAT32X4 src##1 = MS_LDQ_F32(input_ptr + 0 * num); \ + MS_FLOAT32X4 src##2 = MS_LDQ_F32(input_ptr + 1 * num); \ + MS_FLOAT32X4 src##3 = MS_LDQ_F32(input_ptr + 2 * num); \ + MS_FLOAT32X4 src##4 = MS_LDQ_F32(input_ptr + 3 * num); \ + MS_FLOAT32X4 src##5 = MS_LDQ_F32(input_ptr + 4 * num); \ + MS_FLOAT32X4 src##6 = MS_LDQ_F32(input_ptr + 5 * num); \ + MS_FLOAT32X4 src##7 = MS_LDQ_F32(input_ptr + 6 * num); \ + MS_FLOAT32X4 src##8 = MS_LDQ_F32(input_ptr + 7 * num); + +#define STORE128X8_F32(output_ptr, num, dst) \ + MS_STQ_F32(output_ptr + 0 * num, dst##1); \ + MS_STQ_F32(output_ptr + 1 * num, dst##2); \ + MS_STQ_F32(output_ptr + 2 * num, dst##3); \ + MS_STQ_F32(output_ptr + 3 * num, dst##4); \ + MS_STQ_F32(output_ptr + 4 * num, dst##5); \ + MS_STQ_F32(output_ptr + 5 * num, dst##6); \ + MS_STQ_F32(output_ptr + 6 * num, dst##7); \ + MS_STQ_F32(output_ptr + 7 * num, dst##8); + #endif // MINDSPORE_LITE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/prelu_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/prelu_fp32.cc index d7422f77bd..9547293bf7 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/prelu_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/prelu_fp32.cc @@ -91,7 +91,7 @@ int PReluCPUKernel::ProcessShareChannelInput() { auto input_tensor = in_tensors_.at(0); prelu_param_->input_num_ = input_tensor->ElementsNum(); int tile = 32; -#ifdef ENABLE_ARM64 +#if defined(ENABLE_ARM64) || defined(ENABLE_AVX) tile = 64; #endif prelu_param_->tile_block_ = UP_DIV(prelu_param_->input_num_, tile);