!12466 [ms][lite][cpu] prelu simd optimize

From: @lzkcode
Reviewed-by: @zhang_xue_tong,@zhanghaibo5
Signed-off-by: @zhang_xue_tong
pull/12466/MERGE
Committed by mindspore-ci-bot via Gitee
commit 02737b5e32

@@ -77,8 +77,8 @@ int LRelu(const float *src, int length, float *dst, float alpha) {
   for (; i < length - 8; i += 8) {
     MS_FLOAT32X8 src_tmp = MS_LD256_F32(src + i);
     MS_FLOAT32X8 mul_tmp = MS_MUL256_N_F32(src_tmp, alpha);
-    MS_FLOAT32X8 mask = MS_CMP256_PS(src_tmp, MS_MOV256_F32(0.0f), 30);
-    MS_ST256_F32(dst + i, MS_BLEND256_PS(mul_tmp, src_tmp, mask));
+    MS_FLOAT32X8 mask = MS_CMP256_F32(src_tmp, MS_MOV256_F32(0.0f), 30);
+    MS_ST256_F32(dst + i, MS_BLEND256_F32(mul_tmp, src_tmp, mask));
   }
 #endif
@@ -86,8 +86,8 @@ int LRelu(const float *src, int length, float *dst, float alpha) {
   for (; i < length - 4; i += 4) {
     MS_FLOAT32X4 src_tmp = MS_LDQ_F32(src + i);
     MS_FLOAT32X4 mul_tmp = MS_MULQ_N_F32(src_tmp, alpha);
-    MS_FLOAT32X4 mask = MS_CMPGTQ_PS(src_tmp, MS_MOVQ_F32(0.0f));
-    MS_STQ_F32(dst + i, MS_BLENDQ_PS(mul_tmp, src_tmp, mask));
+    MS_FLOAT32X4 mask = MS_CMPGTQ_F32(src_tmp, MS_MOVQ_F32(0.0f));
+    MS_STQ_F32(dst + i, MS_BLENDQ_F32(mul_tmp, src_tmp, mask));
   }
 #endif
   for (; i < length; ++i) {
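Note: the renamed `_F32` compare/blend macros implement LeakyReLU without branches: a greater-than-zero mask selects the original element, otherwise the pre-scaled `alpha * x` is written. A minimal scalar sketch of the same select logic (reference only, `lrelu_scalar_ref` is a hypothetical helper, not part of the patch):

    /* Sketch: scalar model of the vectorized LeakyReLU select.
     * mask = (x > 0); result = mask ? x : alpha * x, which matches
     * MS_BLEND*_F32(mul_tmp, src_tmp, mask) where the mask picks src_tmp. */
    static void lrelu_scalar_ref(const float *src, int length, float *dst, float alpha) {
      for (int i = 0; i < length; ++i) {
        float mul_tmp = src[i] * alpha;    /* slope branch, computed unconditionally */
        int mask = src[i] > 0.0f;          /* predicate 30 is _CMP_GT_OQ on AVX */
        dst[i] = mask ? src[i] : mul_tmp;  /* blend: keep src where mask is set */
      }
    }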

@@ -262,7 +262,7 @@ int ElementAddRelu(const float *in0, const float *in1, float *out, int size) {
     MS_FLOAT32X8 vin0 = MS_LD256_F32(in0 + index);
     MS_FLOAT32X8 vin1 = MS_LD256_F32(in1 + index);
     MS_FLOAT32X8 vout = MS_ADD256_F32(vin0, vin1);
-    vout = MS_BLEND256_PS(zeros_8, vout, MS_CMP256_PS(vout, zeros_8, 30));
+    vout = MS_BLEND256_F32(zeros_8, vout, MS_CMP256_F32(vout, zeros_8, 30));
     MS_ST256_F32(out + index, vout);
   }
 #endif
@@ -272,7 +272,7 @@ int ElementAddRelu(const float *in0, const float *in1, float *out, int size) {
     MS_FLOAT32X4 vin0 = MS_LDQ_F32(in0 + index);
     MS_FLOAT32X4 vin1 = MS_LDQ_F32(in1 + index);
     MS_FLOAT32X4 vout = MS_ADDQ_F32(vin0, vin1);
-    vout = MS_BLENDQ_PS(zeros, vout, MS_CMPGTQ_PS(vout, zeros));
+    vout = MS_BLENDQ_F32(zeros, vout, MS_CMPGTQ_F32(vout, zeros));
     MS_STQ_F32(out + index, vout);
   }
 #endif

@@ -54,7 +54,7 @@ int ElementMulRelu(const float *in0, const float *in1, float *out, int size) {
     MS_FLOAT32X8 vin0 = MS_LD256_F32(in0 + index);
     MS_FLOAT32X8 vin1 = MS_LD256_F32(in1 + index);
     MS_FLOAT32X8 vout = MS_MUL256_F32(vin0, vin1);
-    vout = MS_BLEND256_PS(zeros_8, vout, MS_CMP256_PS(vout, zeros_8, 30));
+    vout = MS_BLEND256_F32(zeros_8, vout, MS_CMP256_F32(vout, zeros_8, 30));
     MS_ST256_F32(out + index, vout);
   }
 #endif
@@ -64,7 +64,7 @@ int ElementMulRelu(const float *in0, const float *in1, float *out, int size) {
     MS_FLOAT32X4 vin0 = MS_LDQ_F32(in0 + index);
     MS_FLOAT32X4 vin1 = MS_LDQ_F32(in1 + index);
     MS_FLOAT32X4 vout = MS_MULQ_F32(vin0, vin1);
-    vout = MS_BLENDQ_PS(zeros, vout, MS_CMPGTQ_PS(vout, zeros));
+    vout = MS_BLENDQ_F32(zeros, vout, MS_CMPGTQ_F32(vout, zeros));
     MS_STQ_F32(out + index, vout);
   }
 #endif

@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -14,14 +14,8 @@
  * limitations under the License.
  */
 #include "nnacl/fp32/prelu_fp32.h"
-#ifdef ENABLE_NEON
-#include <arm_neon.h>
-#endif

 void PRelu(float *input, float *output, const PReluParameter *prelu_param_, int plane) {
-#ifdef ENABLE_ARM
-  float32x4_t zero_value = vdupq_n_f32(0);
-#endif
   int plane_tile = plane / TILE_NUM * TILE_NUM;
   int channel_num = prelu_param_->channel_num_;
   int plane_index = 0;
@@ -29,37 +23,29 @@ void PRelu(float *input, float *output, const PReluParameter *prelu_param_, int
     float *in_plane_ptr = input + plane_index * channel_num;
     float *out_plane_ptr = output + plane_index * channel_num;
     int channel_index = 0;
-#ifdef ENABLE_ARM
+#if defined(ENABLE_AVX)
+    MS_FLOAT32X8 zero_value_8 = MS_MOV256_F32(0.0f);
+    MS_FLOAT32X8 one_value_8 = MS_MOV256_F32(1.0f);
+    float *negetive_slope_value_8 = prelu_param_->slope_;
+    int div_channel_c8 = prelu_param_->channel_num_ / C8NUM * C8NUM;
+    for (; channel_index < div_channel_c8; channel_index += C8NUM) {
+      MS_FLOAT32X8 slope_value_8 = MS_LD256_F32(negetive_slope_value_8 + channel_index);
+      LOAD256X8_F32(src, in_plane_ptr + channel_index, channel_num)
+      PRELU_CALCULATE_256X8(dst, src)
+      STORE256X8_F32(out_plane_ptr + channel_index, channel_num, dst)
+    }
+#endif
+    // note: First AVX processing, then SSE processing on X86 platform
+#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
+    MS_FLOAT32X4 zero_value = MS_MOVQ_F32(0.0f);
+    MS_FLOAT32X4 one_value = MS_MOVQ_F32(1.0f);
     float *negetive_slope_value = prelu_param_->slope_;
     int div_channel = prelu_param_->channel_num_ / C4NUM * C4NUM;
     for (; channel_index < div_channel; channel_index += C4NUM) {
-      float32x4_t slope_value = vld1q_f32(negetive_slope_value + channel_index);
-      float32x4_t v1 = vld1q_f32(in_plane_ptr + channel_index + 0 * channel_num);
-      float32x4_t v2 = vld1q_f32(in_plane_ptr + channel_index + 1 * channel_num);
-      float32x4_t v3 = vld1q_f32(in_plane_ptr + channel_index + 2 * channel_num);
-      float32x4_t v4 = vld1q_f32(in_plane_ptr + channel_index + 3 * channel_num);
-      float32x4_t v5 = vld1q_f32(in_plane_ptr + channel_index + 4 * channel_num);
-      float32x4_t v6 = vld1q_f32(in_plane_ptr + channel_index + 5 * channel_num);
-      float32x4_t v7 = vld1q_f32(in_plane_ptr + channel_index + 6 * channel_num);
-      float32x4_t v8 = vld1q_f32(in_plane_ptr + channel_index + 7 * channel_num);
-      float32x4_t r1 = vaddq_f32(vmulq_f32(vminq_f32(v1, zero_value), slope_value), vmaxq_f32(v1, zero_value));
-      float32x4_t r2 = vaddq_f32(vmulq_f32(vminq_f32(v2, zero_value), slope_value), vmaxq_f32(v2, zero_value));
-      float32x4_t r3 = vaddq_f32(vmulq_f32(vminq_f32(v3, zero_value), slope_value), vmaxq_f32(v3, zero_value));
-      float32x4_t r4 = vaddq_f32(vmulq_f32(vminq_f32(v4, zero_value), slope_value), vmaxq_f32(v4, zero_value));
-      float32x4_t r5 = vaddq_f32(vmulq_f32(vminq_f32(v5, zero_value), slope_value), vmaxq_f32(v5, zero_value));
-      float32x4_t r6 = vaddq_f32(vmulq_f32(vminq_f32(v6, zero_value), slope_value), vmaxq_f32(v6, zero_value));
-      float32x4_t r7 = vaddq_f32(vmulq_f32(vminq_f32(v7, zero_value), slope_value), vmaxq_f32(v7, zero_value));
-      float32x4_t r8 = vaddq_f32(vmulq_f32(vminq_f32(v8, zero_value), slope_value), vmaxq_f32(v8, zero_value));
-      vst1q_f32(out_plane_ptr + channel_index + 0 * channel_num, r1);
-      vst1q_f32(out_plane_ptr + channel_index + 1 * channel_num, r2);
-      vst1q_f32(out_plane_ptr + channel_index + 2 * channel_num, r3);
-      vst1q_f32(out_plane_ptr + channel_index + 3 * channel_num, r4);
-      vst1q_f32(out_plane_ptr + channel_index + 4 * channel_num, r5);
-      vst1q_f32(out_plane_ptr + channel_index + 5 * channel_num, r6);
-      vst1q_f32(out_plane_ptr + channel_index + 6 * channel_num, r7);
-      vst1q_f32(out_plane_ptr + channel_index + 7 * channel_num, r8);
+      MS_FLOAT32X4 slope_value = MS_LDQ_F32(negetive_slope_value + channel_index);
+      LOAD128X8_F32(src, in_plane_ptr + channel_index, channel_num)
+      PRELU_CALCULATE_128X8(dst, src)
+      STORE128X8_F32(out_plane_ptr + channel_index, channel_num, dst)
     }
 #endif
     for (; channel_index < channel_num; channel_index++) {
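Note: per-channel PReLU applies a learned slope per channel to negative inputs. The macroized path computes `x * blend(slope, 1, x > 0)`, which equals the `min(x,0)*slope + max(x,0)` form used by the old NEON code. A scalar reference for one plane row (illustrative only, `prelu_plane_scalar_ref` is not part of the patch):

    /* Sketch: scalar reference for per-channel PReLU over one plane row.
     * out[c] = in[c] > 0 ? in[c] : slope[c] * in[c]
     * which is exactly in[c] * (in[c] > 0 ? 1.0f : slope[c]) -- the blend form. */
    static void prelu_plane_scalar_ref(const float *in, float *out, const float *slope, int channel_num) {
      for (int c = 0; c < channel_num; ++c) {
        float factor = (in[c] > 0.0f) ? 1.0f : slope[c];
        out[c] = in[c] * factor;
      }
    }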
@@ -88,100 +74,41 @@ void PRelu(float *input, float *output, const PReluParameter *prelu_param_, int
 void PReluShareChannel(float *input, float *output, const PReluParameter *prelu_param_, int task_id) {
   for (int j = task_id; j < prelu_param_->tile_block_; j += prelu_param_->op_parameter_.thread_num_) {
     int cal_index;
-#ifdef ENABLE_NEON
-    float32x4_t slope_value = vdupq_n_f32(prelu_param_->slope_[0]);
-    float32x4_t zero_value = vdupq_n_f32(0);
-#endif
-#ifdef ENABLE_ARM64
+#if defined(ENABLE_ARM64) || defined(ENABLE_AVX)
     cal_index = j * 64;
-#elif ENABLE_ARM32
-    cal_index = j * 32;
 #else
     cal_index = j * 32;
-    const int cal_per_time = 32;
 #endif
     float *input_ptr = input + cal_index;
     float *output_ptr = input + cal_index;
-#ifdef ENABLE_ARM64
-    float32x4_t v1 = vld1q_f32(input_ptr);
-    float32x4_t v2 = vld1q_f32(input_ptr + 4);
-    float32x4_t v3 = vld1q_f32(input_ptr + 8);
-    float32x4_t v4 = vld1q_f32(input_ptr + 12);
-    float32x4_t v5 = vld1q_f32(input_ptr + 16);
-    float32x4_t v6 = vld1q_f32(input_ptr + 20);
-    float32x4_t v7 = vld1q_f32(input_ptr + 24);
-    float32x4_t v8 = vld1q_f32(input_ptr + 28);
-    float32x4_t v9 = vld1q_f32(input_ptr + 32);
-    float32x4_t v10 = vld1q_f32(input_ptr + 36);
-    float32x4_t v11 = vld1q_f32(input_ptr + 40);
-    float32x4_t v12 = vld1q_f32(input_ptr + 44);
-    float32x4_t v13 = vld1q_f32(input_ptr + 48);
-    float32x4_t v14 = vld1q_f32(input_ptr + 52);
-    float32x4_t v15 = vld1q_f32(input_ptr + 56);
-    float32x4_t v16 = vld1q_f32(input_ptr + 60);
-    float32x4_t t1 = vaddq_f32(vmulq_f32(vminq_f32(v1, zero_value), slope_value), vmaxq_f32(v1, zero_value));
-    float32x4_t t2 = vaddq_f32(vmulq_f32(vminq_f32(v2, zero_value), slope_value), vmaxq_f32(v2, zero_value));
-    float32x4_t t3 = vaddq_f32(vmulq_f32(vminq_f32(v3, zero_value), slope_value), vmaxq_f32(v3, zero_value));
-    float32x4_t t4 = vaddq_f32(vmulq_f32(vminq_f32(v4, zero_value), slope_value), vmaxq_f32(v4, zero_value));
-    float32x4_t t5 = vaddq_f32(vmulq_f32(vminq_f32(v5, zero_value), slope_value), vmaxq_f32(v5, zero_value));
-    float32x4_t t6 = vaddq_f32(vmulq_f32(vminq_f32(v6, zero_value), slope_value), vmaxq_f32(v6, zero_value));
-    float32x4_t t7 = vaddq_f32(vmulq_f32(vminq_f32(v7, zero_value), slope_value), vmaxq_f32(v7, zero_value));
-    float32x4_t t8 = vaddq_f32(vmulq_f32(vminq_f32(v8, zero_value), slope_value), vmaxq_f32(v8, zero_value));
-    float32x4_t t9 = vaddq_f32(vmulq_f32(vminq_f32(v9, zero_value), slope_value), vmaxq_f32(v9, zero_value));
-    float32x4_t t10 = vaddq_f32(vmulq_f32(vminq_f32(v10, zero_value), slope_value), vmaxq_f32(v10, zero_value));
-    float32x4_t t11 = vaddq_f32(vmulq_f32(vminq_f32(v11, zero_value), slope_value), vmaxq_f32(v11, zero_value));
-    float32x4_t t12 = vaddq_f32(vmulq_f32(vminq_f32(v12, zero_value), slope_value), vmaxq_f32(v12, zero_value));
-    float32x4_t t13 = vaddq_f32(vmulq_f32(vminq_f32(v13, zero_value), slope_value), vmaxq_f32(v13, zero_value));
-    float32x4_t t14 = vaddq_f32(vmulq_f32(vminq_f32(v14, zero_value), slope_value), vmaxq_f32(v14, zero_value));
-    float32x4_t t15 = vaddq_f32(vmulq_f32(vminq_f32(v15, zero_value), slope_value), vmaxq_f32(v15, zero_value));
-    float32x4_t t16 = vaddq_f32(vmulq_f32(vminq_f32(v16, zero_value), slope_value), vmaxq_f32(v16, zero_value));
-    vst1q_f32(output_ptr, t1);
-    vst1q_f32(output_ptr + 4, t2);
-    vst1q_f32(output_ptr + 8, t3);
-    vst1q_f32(output_ptr + 12, t4);
-    vst1q_f32(output_ptr + 16, t5);
-    vst1q_f32(output_ptr + 20, t6);
-    vst1q_f32(output_ptr + 24, t7);
-    vst1q_f32(output_ptr + 28, t8);
-    vst1q_f32(output_ptr + 32, t9);
-    vst1q_f32(output_ptr + 36, t10);
-    vst1q_f32(output_ptr + 40, t11);
-    vst1q_f32(output_ptr + 44, t12);
-    vst1q_f32(output_ptr + 48, t13);
-    vst1q_f32(output_ptr + 52, t14);
-    vst1q_f32(output_ptr + 56, t15);
-    vst1q_f32(output_ptr + 60, t16);
-#elif ENABLE_ARM32
-    float32x4_t v1 = vld1q_f32(input_ptr);
-    float32x4_t v2 = vld1q_f32(input_ptr + 4);
-    float32x4_t v3 = vld1q_f32(input_ptr + 8);
-    float32x4_t v4 = vld1q_f32(input_ptr + 12);
-    float32x4_t v5 = vld1q_f32(input_ptr + 16);
-    float32x4_t v6 = vld1q_f32(input_ptr + 20);
-    float32x4_t v7 = vld1q_f32(input_ptr + 24);
-    float32x4_t v8 = vld1q_f32(input_ptr + 28);
-    float32x4_t t1 = vaddq_f32(vmulq_f32(vminq_f32(v1, zero_value), slope_value), vmaxq_f32(v1, zero_value));
-    float32x4_t t2 = vaddq_f32(vmulq_f32(vminq_f32(v2, zero_value), slope_value), vmaxq_f32(v2, zero_value));
-    float32x4_t t3 = vaddq_f32(vmulq_f32(vminq_f32(v3, zero_value), slope_value), vmaxq_f32(v3, zero_value));
-    float32x4_t t4 = vaddq_f32(vmulq_f32(vminq_f32(v4, zero_value), slope_value), vmaxq_f32(v4, zero_value));
-    float32x4_t t5 = vaddq_f32(vmulq_f32(vminq_f32(v5, zero_value), slope_value), vmaxq_f32(v5, zero_value));
-    float32x4_t t6 = vaddq_f32(vmulq_f32(vminq_f32(v6, zero_value), slope_value), vmaxq_f32(v6, zero_value));
-    float32x4_t t7 = vaddq_f32(vmulq_f32(vminq_f32(v7, zero_value), slope_value), vmaxq_f32(v7, zero_value));
-    float32x4_t t8 = vaddq_f32(vmulq_f32(vminq_f32(v8, zero_value), slope_value), vmaxq_f32(v8, zero_value));
-    vst1q_f32(output_ptr, t1);
-    vst1q_f32(output_ptr + 4, t2);
-    vst1q_f32(output_ptr + 8, t3);
-    vst1q_f32(output_ptr + 12, t4);
-    vst1q_f32(output_ptr + 16, t5);
-    vst1q_f32(output_ptr + 20, t6);
-    vst1q_f32(output_ptr + 24, t7);
-    vst1q_f32(output_ptr + 28, t8);
+#if defined(ENABLE_AVX)
+    MS_FLOAT32X8 zero_value_8 = MS_MOV256_F32(0);
+    MS_FLOAT32X8 one_value_8 = MS_MOV256_F32(1.0f);
+    MS_FLOAT32X8 slope_value_8 = MS_MOV256_F32(prelu_param_->slope_[0]);
+    LOAD256X8_F32(src, input_ptr, 8)
+    PRELU_CALCULATE_256X8(dst, src)
+    STORE256X8_F32(output_ptr, 8, dst)
+#elif defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
+    MS_FLOAT32X4 zero_value = MS_MOVQ_F32(0);
+    MS_FLOAT32X4 one_value = MS_MOVQ_F32(1.0f);
+    MS_FLOAT32X4 slope_value = MS_MOVQ_F32(prelu_param_->slope_[0]);
+    LOAD128X8_F32(src, input_ptr, 4)
+#ifdef ENABLE_ARM64
+    LOAD128X8_F32(src1, input_ptr + 32, 4)
+#endif
+    PRELU_CALCULATE_128X8(dst, src)
+#ifdef ENABLE_ARM64
+    PRELU_CALCULATE_128X8(dst1, src1)
+#endif
+    STORE128X8_F32(output_ptr, 4, dst)
+#ifdef ENABLE_ARM64
+    STORE128X8_F32(output_ptr + 32, 4, dst1)
+#endif
 #else
+    const int cal_per_time = 32;
     for (int i = 0; i < cal_per_time; ++i) {
       float data = input_ptr[i];
       output_ptr[i] = (data < 0 ? data : 0) * prelu_param_->slope_[0] + (data > 0 ? data : 0);
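Note: the scalar tail above makes the equivalence explicit: `min(x,0)*slope + max(x,0)` and `x * (x > 0 ? 1 : slope)` give the same result for every finite x, so the blend-based macros preserve the old numerics. A quick self-contained check (illustrative only, not part of the patch):

    /* Sketch: verify the old min/max PReLU form matches the new blend form. */
    #include <assert.h>
    #include <math.h>

    static float prelu_minmax(float x, float slope) {
      return fminf(x, 0.0f) * slope + fmaxf(x, 0.0f);  /* old NEON-style formula */
    }

    static float prelu_blend(float x, float slope) {
      return x * (x > 0.0f ? 1.0f : slope);            /* new macro-style formula */
    }

    int main(void) {
      const float slope = 0.25f;
      const float samples[] = {-3.0f, -0.5f, 0.0f, 0.5f, 3.0f};
      for (int i = 0; i < 5; ++i) {
        assert(fabsf(prelu_minmax(samples[i], slope) - prelu_blend(samples[i], slope)) < 1e-6f);
      }
      return 0;
    }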

@@ -29,4 +29,32 @@ void PReluShareChannel(float *input, float *output, const PReluParameter *prelu_
 }
 #endif
+
+#define PRELU_CALCULATE_256X8(dst, src)                                                                           \
+  MS_FLOAT32X8 dst##1 =                                                                                           \
+    MS_MUL256_F32(src##1, MS_BLEND256_F32(slope_value_8, one_value_8, MS_CMP256_F32(src##1, zero_value_8, 30)));  \
+  MS_FLOAT32X8 dst##2 =                                                                                           \
+    MS_MUL256_F32(src##2, MS_BLEND256_F32(slope_value_8, one_value_8, MS_CMP256_F32(src##2, zero_value_8, 30)));  \
+  MS_FLOAT32X8 dst##3 =                                                                                           \
+    MS_MUL256_F32(src##3, MS_BLEND256_F32(slope_value_8, one_value_8, MS_CMP256_F32(src##3, zero_value_8, 30)));  \
+  MS_FLOAT32X8 dst##4 =                                                                                           \
+    MS_MUL256_F32(src##4, MS_BLEND256_F32(slope_value_8, one_value_8, MS_CMP256_F32(src##4, zero_value_8, 30)));  \
+  MS_FLOAT32X8 dst##5 =                                                                                           \
+    MS_MUL256_F32(src##5, MS_BLEND256_F32(slope_value_8, one_value_8, MS_CMP256_F32(src##5, zero_value_8, 30)));  \
+  MS_FLOAT32X8 dst##6 =                                                                                           \
+    MS_MUL256_F32(src##6, MS_BLEND256_F32(slope_value_8, one_value_8, MS_CMP256_F32(src##6, zero_value_8, 30)));  \
+  MS_FLOAT32X8 dst##7 =                                                                                           \
+    MS_MUL256_F32(src##7, MS_BLEND256_F32(slope_value_8, one_value_8, MS_CMP256_F32(src##7, zero_value_8, 30)));  \
+  MS_FLOAT32X8 dst##8 =                                                                                           \
+    MS_MUL256_F32(src##8, MS_BLEND256_F32(slope_value_8, one_value_8, MS_CMP256_F32(src##8, zero_value_8, 30)));
+
+#define PRELU_CALCULATE_128X8(dst, src)                                                                            \
+  MS_FLOAT32X4 dst##1 = MS_MULQ_F32(src##1, MS_BLENDQ_F32(slope_value, one_value, MS_CMPGTQ_F32(src##1, zero_value))); \
+  MS_FLOAT32X4 dst##2 = MS_MULQ_F32(src##2, MS_BLENDQ_F32(slope_value, one_value, MS_CMPGTQ_F32(src##2, zero_value))); \
+  MS_FLOAT32X4 dst##3 = MS_MULQ_F32(src##3, MS_BLENDQ_F32(slope_value, one_value, MS_CMPGTQ_F32(src##3, zero_value))); \
+  MS_FLOAT32X4 dst##4 = MS_MULQ_F32(src##4, MS_BLENDQ_F32(slope_value, one_value, MS_CMPGTQ_F32(src##4, zero_value))); \
+  MS_FLOAT32X4 dst##5 = MS_MULQ_F32(src##5, MS_BLENDQ_F32(slope_value, one_value, MS_CMPGTQ_F32(src##5, zero_value))); \
+  MS_FLOAT32X4 dst##6 = MS_MULQ_F32(src##6, MS_BLENDQ_F32(slope_value, one_value, MS_CMPGTQ_F32(src##6, zero_value))); \
+  MS_FLOAT32X4 dst##7 = MS_MULQ_F32(src##7, MS_BLENDQ_F32(slope_value, one_value, MS_CMPGTQ_F32(src##7, zero_value))); \
+  MS_FLOAT32X4 dst##8 = MS_MULQ_F32(src##8, MS_BLENDQ_F32(slope_value, one_value, MS_CMPGTQ_F32(src##8, zero_value)));
 #endif  // MINDSPORE_LITE_NNACL_FP32_PRELU_H_

@@ -60,10 +60,10 @@ inline static float32x4_t vrecp(float32x4_t v) {
 #define MS_SLLIQ_EPI32(src1, src2) vshlq_s32(src1, vmovq_n_s32(src2))
 #define MS_CVTQPS_EPI32(src) vcvtq_s32_f32(src)
 #define MS_CVTQEPI32_PS(src) vcvtq_f32_s32(src)
-#define MS_CMPGTQ_PS(src1, src2) vcgtq_f32(src1, src2)
+#define MS_CMPGTQ_F32(src1, src2) vcgtq_f32(src1, src2)
 #define MS_CMPGTQ_EPI32(src1, src2) vcgtq_s32(src1, src2)
-// Note: Compared with X86, the vbslq_f32 parameters are the opposite with _mm_blendv_ps
-#define MS_BLENDQ_PS(src1, src2, src3) vbslq_f32(src3, src2, src1)
+// Note: Compared with X86, the vbslq_f32 parameters are the opposite with _mm_blendv_f32
+#define MS_BLENDQ_F32(src1, src2, src3) vbslq_f32(src3, src2, src1)
 #define MS_BLENDQ_EPI32(src1, src2, src3) vbslq_s32(src3, src2, src1)
 #endif
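Note: the renamed blend wrapper keeps the x86 argument order on every backend: the result takes `src2` where the mask is set and `src1` elsewhere. On NEON the wrapper swaps operands because `vbslq_f32` takes the mask first. A minimal scalar model of the shared semantics (plain C, not the real intrinsics):

    /* Sketch: per-lane meaning of MS_BLENDQ_F32 / MS_BLEND256_F32.
     * blend(a, b, mask) == mask ? b : a, matching _mm_blendv_ps(a, b, mask)
     * and vbslq_f32(mask, b, a). */
    static float blend_f32_scalar(float a, float b, int mask) {
      return mask ? b : a;
    }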
@@ -94,9 +94,9 @@ inline static float32x4_t vrecp(float32x4_t v) {
 #define MS_SLLI256_EPI32(src1, src2) _mm256_slli_epi32(src1, src2)
 #define MS_CVT256PS_EPI32(src) _mm256_cvttps_epi32(src)
 #define MS_CVT256EPI32_PS(src) _mm256_cvtepi32_ps(src)  // truncate float to int
-#define MS_CMP256_PS(src1, src2, src3) _mm256_cmp_ps(src1, src2, src3)
+#define MS_CMP256_F32(src1, src2, src3) _mm256_cmp_ps(src1, src2, src3)
 #define MS_CMPGT256_EPI32(src1, src2) _mm256_cmpgt_epi32(src1, src2)
-#define MS_BLEND256_PS(src1, src2, src3) _mm256_blendv_ps(src1, src2, src3)
+#define MS_BLEND256_F32(src1, src2, src3) _mm256_blendv_ps(src1, src2, src3)
 #define MS_BLEND256_EPI32(src1, src2, src3) _mm256_blendv_epi8(src1, src2, src3)
 #endif
@@ -127,10 +127,50 @@ inline static float32x4_t vrecp(float32x4_t v) {
 #define MS_SLLIQ_EPI32(src1, src2) _mm_slli_epi32(src1, src2)
 #define MS_CVTQPS_EPI32(src) _mm_cvttps_epi32(src)  // truncate float to int
 #define MS_CVTQEPI32_PS(src) _mm_cvtepi32_ps(src)
-#define MS_CMPGTQ_PS(src1, src2) _mm_cmpgt_ps(src1, src2)
+#define MS_CMPGTQ_F32(src1, src2) _mm_cmpgt_ps(src1, src2)
 #define MS_CMPGTQ_EPI32(src1, src2) _mm_cmpgt_epi32(src1, src2)
-#define MS_BLENDQ_PS(src1, src2, src3) _mm_blendv_ps(src1, src2, src3)
+#define MS_BLENDQ_F32(src1, src2, src3) _mm_blendv_ps(src1, src2, src3)
 #define MS_BLENDQ_EPI32(src1, src2, src3) _mm_blendv_epi8(src1, src2, src3)
 #endif
+
+#define LOAD256X8_F32(src, input_ptr, num)                 \
+  MS_FLOAT32X8 src##1 = MS_LD256_F32(input_ptr + 0 * num); \
+  MS_FLOAT32X8 src##2 = MS_LD256_F32(input_ptr + 1 * num); \
+  MS_FLOAT32X8 src##3 = MS_LD256_F32(input_ptr + 2 * num); \
+  MS_FLOAT32X8 src##4 = MS_LD256_F32(input_ptr + 3 * num); \
+  MS_FLOAT32X8 src##5 = MS_LD256_F32(input_ptr + 4 * num); \
+  MS_FLOAT32X8 src##6 = MS_LD256_F32(input_ptr + 5 * num); \
+  MS_FLOAT32X8 src##7 = MS_LD256_F32(input_ptr + 6 * num); \
+  MS_FLOAT32X8 src##8 = MS_LD256_F32(input_ptr + 7 * num);
+
+#define STORE256X8_F32(output_ptr, num, dst) \
+  MS_ST256_F32(output_ptr + 0 * num, dst##1); \
+  MS_ST256_F32(output_ptr + 1 * num, dst##2); \
+  MS_ST256_F32(output_ptr + 2 * num, dst##3); \
+  MS_ST256_F32(output_ptr + 3 * num, dst##4); \
+  MS_ST256_F32(output_ptr + 4 * num, dst##5); \
+  MS_ST256_F32(output_ptr + 5 * num, dst##6); \
+  MS_ST256_F32(output_ptr + 6 * num, dst##7); \
+  MS_ST256_F32(output_ptr + 7 * num, dst##8);
+
+#define LOAD128X8_F32(src, input_ptr, num)               \
+  MS_FLOAT32X4 src##1 = MS_LDQ_F32(input_ptr + 0 * num); \
+  MS_FLOAT32X4 src##2 = MS_LDQ_F32(input_ptr + 1 * num); \
+  MS_FLOAT32X4 src##3 = MS_LDQ_F32(input_ptr + 2 * num); \
+  MS_FLOAT32X4 src##4 = MS_LDQ_F32(input_ptr + 3 * num); \
+  MS_FLOAT32X4 src##5 = MS_LDQ_F32(input_ptr + 4 * num); \
+  MS_FLOAT32X4 src##6 = MS_LDQ_F32(input_ptr + 5 * num); \
+  MS_FLOAT32X4 src##7 = MS_LDQ_F32(input_ptr + 6 * num); \
+  MS_FLOAT32X4 src##8 = MS_LDQ_F32(input_ptr + 7 * num);
+
+#define STORE128X8_F32(output_ptr, num, dst) \
+  MS_STQ_F32(output_ptr + 0 * num, dst##1);  \
+  MS_STQ_F32(output_ptr + 1 * num, dst##2);  \
+  MS_STQ_F32(output_ptr + 2 * num, dst##3);  \
+  MS_STQ_F32(output_ptr + 3 * num, dst##4);  \
+  MS_STQ_F32(output_ptr + 4 * num, dst##5);  \
+  MS_STQ_F32(output_ptr + 5 * num, dst##6);  \
+  MS_STQ_F32(output_ptr + 6 * num, dst##7);  \
+  MS_STQ_F32(output_ptr + 7 * num, dst##8);
 #endif  // MINDSPORE_LITE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
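Note: these token-pasting macros unroll eight vector registers per invocation (`src##1` … `src##8`), each offset by `num` elements. For reference, what one 4-lane slice of the chain does once the macros are expanded, written as a small helper (a sketch that assumes the ms_simd_instructions.h wrappers are in scope; `prelu_one_slice` is a hypothetical name, not part of the patch):

    /* Sketch: the "##1" slice of LOAD128X8_F32 + PRELU_CALCULATE_128X8 + STORE128X8_F32,
     * written without macros. The remaining seven registers follow the same pattern
     * at offsets 1..7 * num. */
    static inline void prelu_one_slice(const float *in_ptr, float *out_ptr,
                                       MS_FLOAT32X4 slope_value, MS_FLOAT32X4 one_value,
                                       MS_FLOAT32X4 zero_value) {
      MS_FLOAT32X4 src1 = MS_LDQ_F32(in_ptr);
      MS_FLOAT32X4 dst1 =
          MS_MULQ_F32(src1, MS_BLENDQ_F32(slope_value, one_value, MS_CMPGTQ_F32(src1, zero_value)));
      MS_STQ_F32(out_ptr, dst1);
    }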

@@ -91,7 +91,7 @@ int PReluCPUKernel::ProcessShareChannelInput() {
   auto input_tensor = in_tensors_.at(0);
   prelu_param_->input_num_ = input_tensor->ElementsNum();
   int tile = 32;
-#ifdef ENABLE_ARM64
+#if defined(ENABLE_ARM64) || defined(ENABLE_AVX)
   tile = 64;
 #endif
   prelu_param_->tile_block_ = UP_DIV(prelu_param_->input_num_, tile);
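Note: each `tile_block_` corresponds to one iteration of `PReluShareChannel`, so the tile width must match what that iteration consumes: 64 floats on ARM64/AVX (eight 256-bit registers, or sixteen 128-bit ones), 32 floats otherwise. The block count is a ceiling division; a sketch with `UP_DIV` written out as assumed here:

    /* Sketch: assumed ceiling-division form of UP_DIV used to size tile_block_.
     * Example: 100 elements, tile = 64 -> UP_DIV(100, 64) = 2 tile blocks,
     * the last one only partially filled. */
    #define UP_DIV(x, y) (((x) + (y) - 1) / (y))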
