From 37a30fb0fbcde2962943aa732887d098b0ad0582 Mon Sep 17 00:00:00 2001 From: lzk Date: Tue, 9 Feb 2021 18:47:27 -0800 Subject: [PATCH] fp32 optimize --- mindspore/lite/CMakeLists.txt | 2 + mindspore/lite/nnacl/fp32/activation_fp32.c | 149 +++-- mindspore/lite/nnacl/fp32/add_fp32.c | 277 ++++++--- mindspore/lite/nnacl/fp32/exp_fp32.h | 53 +- mindspore/lite/nnacl/fp32/mul_fp32.c | 428 ++++++++----- mindspore/lite/nnacl/fp32/winograd_utils.c | 603 ++++++++++--------- mindspore/lite/nnacl/op_base.h | 85 ++- mindspore/lite/src/ops/tensorlist_getitem.cc | 2 +- 8 files changed, 1034 insertions(+), 565 deletions(-) diff --git a/mindspore/lite/CMakeLists.txt b/mindspore/lite/CMakeLists.txt index 61e529d96a..3aec1ff050 100644 --- a/mindspore/lite/CMakeLists.txt +++ b/mindspore/lite/CMakeLists.txt @@ -231,6 +231,8 @@ endif() if(NOT PLATFORM_ARM32 AND NOT PLATFORM_ARM64) if("${X86_64_SIMD}" STREQUAL "sse") add_compile_definitions(ENABLE_SSE) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -msse4.1") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -msse4.1") endif() if("${X86_64_SIMD}" STREQUAL "avx") add_compile_definitions(ENABLE_SSE) diff --git a/mindspore/lite/nnacl/fp32/activation_fp32.c b/mindspore/lite/nnacl/fp32/activation_fp32.c index b3bb0d7539..b0d1fb68b8 100644 --- a/mindspore/lite/nnacl/fp32/activation_fp32.c +++ b/mindspore/lite/nnacl/fp32/activation_fp32.c @@ -20,10 +20,17 @@ int Fp32Relu(const float *src, int length, float *dst) { int i = 0; -#ifdef ENABLE_ARM - float32x4_t zero_4 = vdupq_n_f32(0.0f); +#if defined(ENABLE_AVX) + MS_FLOAT32X8 zero_8 = MS_MOV256_F32(0.0f); + for (; i < length - 8; i += 8) { + MS_ST256_F32(dst + i, MS_MAX256_F32(MS_LD256_F32(src + i), zero_8)); + } +#endif + +#if defined(ENABLE_SSE) || defined(ENABLE_ARM) + MS_FLOAT32X4 zero = MS_MOVQ_F32(0.0f); for (; i < length - 4; i += 4) { - vst1q_f32(dst + i, vmaxq_f32(vld1q_f32(src + i), zero_4)); + MS_STQ_F32(dst + i, MS_MAXQ_F32(MS_LDQ_F32(src + i), zero)); } #endif for (; i < length; ++i) { @@ -34,13 
+41,24 @@ int Fp32Relu(const float *src, int length, float *dst) { int Fp32Relu6(const float *src, int length, float *dst) { int i = 0; -#ifdef ENABLE_ARM - float32x4_t zero_4 = vdupq_n_f32(0.0f); - float32x4_t six_4 = vdupq_n_f32(6.0f); + +#if defined(ENABLE_AVX) + MS_FLOAT32X8 zero_8 = MS_MOV256_F32(0.0f); + MS_FLOAT32X8 six_8 = MS_MOV256_F32(6.0f); + for (; i < length - 8; i += 8) { + MS_FLOAT32X8 dst_tmp = MS_MAX256_F32(MS_LD256_F32(src + i), zero_8); + dst_tmp = MS_MIN256_F32(dst_tmp, six_8); + MS_ST256_F32(dst + i, dst_tmp); + } +#endif + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 zero = MS_MOVQ_F32(0.0f); + MS_FLOAT32X4 six = MS_MOVQ_F32(6.0f); for (; i < length - 4; i += 4) { - float32x4_t dst_4 = vmaxq_f32(vld1q_f32(src + i), zero_4); - dst_4 = vminq_f32(dst_4, six_4); - vst1q_f32(dst + i, dst_4); + MS_FLOAT32X4 dst_tmp = MS_MAXQ_F32(MS_LDQ_F32(src + i), zero); + dst_tmp = MS_MINQ_F32(dst_tmp, six); + MS_STQ_F32(dst + i, dst_tmp); } #endif for (; i < length; ++i) { @@ -55,14 +73,21 @@ int Fp32Relu6(const float *src, int length, float *dst) { int LRelu(const float *src, int length, float *dst, float alpha) { int i = 0; -#ifdef ENABLE_ARM64 - float32x4_t alpha_4 = vdupq_n_f32(alpha); +#if defined(ENABLE_AVX) + for (; i < length - 8; i += 8) { + MS_FLOAT32X8 src_tmp = MS_LD256_F32(src + i); + MS_FLOAT32X8 mul_tmp = MS_MUL256_N_F32(src_tmp, alpha); + MS_FLOAT32X8 mask = MS_CMP256_PS(src_tmp, MS_MOV256_F32(0.0f), 30); + MS_ST256_F32(dst + i, MS_BLEND256_PS(mul_tmp, src_tmp, mask)); + } +#endif + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) for (; i < length - 4; i += 4) { - float32x4_t src_4 = vld1q_f32(src + i); - float32x4_t mul_4 = vmulq_f32(src_4, alpha_4); - uint32x4_t flag = vclezq_f32(src_4); - float32x4_t dst_4 = vbslq_f32(flag, mul_4, src_4); - vst1q_f32(dst + i, dst_4); + MS_FLOAT32X4 src_tmp = MS_LDQ_F32(src + i); + MS_FLOAT32X4 mul_tmp = MS_MULQ_N_F32(src_tmp, alpha); + MS_FLOAT32X4 mask = MS_CMPGTQ_PS(src_tmp, 
MS_MOVQ_F32(0.0f)); + MS_STQ_F32(dst + i, MS_BLENDQ_PS(mul_tmp, src_tmp, mask)); } #endif for (; i < length; ++i) { @@ -73,11 +98,18 @@ int LRelu(const float *src, int length, float *dst, float alpha) { int Sigmoid(const float *src, int length, float *dst) { int i = 0; -#ifdef ENABLE_ARM64 - int count = (length / C4NUM) * C4NUM; - for (; i < count; i += C4NUM) { - simd_exp(vnegq_f32(vld1q_f32(src + i)), dst + i); - vst1q_f32(dst + i, vdivq_f32(vdupq_n_f32(1.0f), vaddq_f32(vdupq_n_f32(1.0f), vld1q_f32(dst + i)))); +#if defined(ENABLE_AVX) + for (; i < length - 8; i += 8) { + simd_exp_avx(-(MS_LD256_F32(src + i)), dst + i); + MS_ST256_F32(dst + i, + MS_DIV256_F32(MS_MOV256_F32(1.0f), MS_ADD256_F32(MS_MOV256_F32(1.0f), MS_LD256_F32(dst + i)))); + } +#endif + +#if defined(ENABLE_ARM64) || defined(ENABLE_SSE) + for (; i < length - 4; i += 4) { + simd_exp(-(MS_LDQ_F32(src + i)), dst + i); + MS_STQ_F32(dst + i, MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_ADDQ_F32(MS_MOVQ_F32(1.0f), MS_LDQ_F32(dst + i)))); } #endif for (; i < length; ++i) { @@ -102,26 +134,40 @@ float TanhOpt(float src) { int Tanh(const float *src, int length, float *dst) { int i = 0; -#ifdef ENABLE_ARM64 - static float32x4_t paramv[] = {{378.0f, 378.0f, 378.0f, 378.0f}, - {17325.0f, 17325.0f, 17325.0f, 17325.0f}, - {135135.0f, 135135.0f, 135135.0f, 135135.0f}, - {28.0f, 28.0f, 28.0f, 28.0f}, - {3150.0f, 3150.0f, 3150.0f, 3150.0f}, - {62370.0f, 62370.0f, 62370.0f, 62370.0f}}; - float32x4_t neg_one = {-1.0f, -1.0f, -1.0f, -1.0f}; - float32x4_t pos_one = {1.0f, 1.0f, 1.0f, 1.0f}; - int count = (length / C4NUM) * C4NUM; - for (; i < count; i += C4NUM) { - float32x4_t input = vld1q_f32(src + i); - float32x4_t square = vmulq_f32(input, input); - float32x4_t a = vmulq_f32( - vaddq_f32(vmulq_f32(vaddq_f32(vmulq_f32(vaddq_f32(square, paramv[0]), square), paramv[1]), square), paramv[2]), - input); - float32x4_t b = vaddq_f32( - vmulq_f32(vaddq_f32(vmulq_f32(vaddq_f32(vmulq_f32(paramv[3], square), paramv[4]), square), 
paramv[5]), square), - paramv[2]); - vst1q_f32(dst + i, vminq_f32(vmaxq_f32(vdivq_f32(a, b), neg_one), pos_one)); +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) || defined(ENABLE_AVX) + const int cnt = 6; + float data[] = {378.0f, 17325.0f, 135135.0f, 28.0f, 3150.0f, 62370.0f}; +#endif + +#if defined(ENABLE_AVX) + MS_FLOAT32X8 neg_one_8 = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f}; + MS_FLOAT32X8 pos_one_8 = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; + MS_FLOAT32X8 param256[6]; + for (int j = 0; j < cnt; ++j) { + param256[j] = MS_MOV256_F32(data[j]); + } + for (; i < length - 8; i += 8) { + MS_FLOAT32X8 input = MS_LD256_F32(src + i); + MS_FLOAT32X8 square = input * input; + MS_FLOAT32X8 a = (((square + param256[0]) * square + param256[1]) * square + param256[2]) * input; + MS_FLOAT32X8 b = ((param256[3] * square + param256[4]) * square + param256[5]) * square + param256[2]; + MS_ST256_F32(dst + i, MS_MIN256_F32(MS_MAX256_F32(a / b, neg_one_8), pos_one_8)); + } +#endif + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + MS_FLOAT32X4 param[6]; + MS_FLOAT32X4 neg_one = {-1.0f, -1.0f, -1.0f, -1.0f}; + MS_FLOAT32X4 pos_one = {1.0f, 1.0f, 1.0f, 1.0f}; + for (int j = 0; j < cnt; ++j) { + param[j] = MS_MOVQ_F32(data[j]); + } + for (; i < length - 4; i += 4) { + MS_FLOAT32X4 input = MS_LDQ_F32(src + i); + MS_FLOAT32X4 square = input * input; + MS_FLOAT32X4 a = (((square + param[0]) * square + param[1]) * square + param[2]) * input; + MS_FLOAT32X4 b = ((param[3] * square + param[4]) * square + param[5]) * square + param[2]; + MS_STQ_F32(dst + i, MS_MINQ_F32(MS_MAXQ_F32(a / b, neg_one), pos_one)); } #endif for (; i < length; ++i) { @@ -142,12 +188,21 @@ int Swish(const float *src, int length, float *dst) { return NNACL_ERR; } int index = 0; -#ifdef ENABLE_NEON - for (; index <= length - C4NUM; index += C4NUM) { - float32x4_t src_value = vld1q_f32(src + index); - float32x4_t sigmoid_value = vld1q_f32(dst + index); - float32x4_t result = 
vmulq_f32(src_value, sigmoid_value); - vst1q_f32(dst + index, result); +#if defined(ENABLE_AVX) + for (; index <= length - 8; index += 8) { + MS_FLOAT32X8 src_value = MS_LD256_F32(src + index); + MS_FLOAT32X8 sigmoid_value = MS_LD256_F32(dst + index); + MS_FLOAT32X8 result = MS_MUL256_F32(src_value, sigmoid_value); + MS_ST256_F32(dst + index, result); + } +#endif + +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) + for (; index <= length - 4; index += 4) { + MS_FLOAT32X4 src_value = MS_LDQ_F32(src + index); + MS_FLOAT32X4 sigmoid_value = MS_LDQ_F32(dst + index); + MS_FLOAT32X4 result = MS_MULQ_F32(src_value, sigmoid_value); + MS_STQ_F32(dst + index, result); } #endif for (; index < length; index++) { diff --git a/mindspore/lite/nnacl/fp32/add_fp32.c b/mindspore/lite/nnacl/fp32/add_fp32.c index 11b839766a..b14ee6e30b 100644 --- a/mindspore/lite/nnacl/fp32/add_fp32.c +++ b/mindspore/lite/nnacl/fp32/add_fp32.c @@ -18,28 +18,46 @@ #include "nnacl/fp32/arithmetic_fp32.h" int ElementOptAdd(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param) { -#ifdef ENABLE_NEON - float32x4_t vin0_opt = vdupq_n_f32(in0[0]); - float32x4_t vin1_opt = vdupq_n_f32(in1[0]); +#ifdef ENABLE_AVX + MS_FLOAT32X8 vin0_opt_8 = MS_MOV256_F32(in0[0]); + MS_FLOAT32X8 vin1_opt_8 = MS_MOV256_F32(in1[0]); +#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_FLOAT32X4 vin0_opt = MS_MOVQ_F32(in0[0]); + MS_FLOAT32X4 vin1_opt = MS_MOVQ_F32(in1[0]); #endif int index = 0; if (param->in_elements_num0_ == 1) { -#ifdef ENABLE_NEON - for (; index <= size - 4; index += C4NUM) { - float32x4_t vin1 = vld1q_f32(in1 + index); - float32x4_t vout = vaddq_f32(vin0_opt, vin1); - vst1q_f32(out + index, vout); +#ifdef ENABLE_AVX + for (; index <= size - C8NUM; index += C8NUM) { + MS_FLOAT32X8 vin1 = MS_LD256_F32(in1 + index); + MS_FLOAT32X8 vout = MS_ADD256_F32(vin0_opt_8, vin1); + MS_ST256_F32(out + index, vout); + } +#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + 
for (; index <= size - C4NUM; index += C4NUM) { + MS_FLOAT32X4 vin1 = MS_LDQ_F32(in1 + index); + MS_FLOAT32X4 vout = MS_ADDQ_F32(vin0_opt, vin1); + MS_STQ_F32(out + index, vout); } #endif for (; index < size; index++) { out[index] = in0[0] + in1[index]; } } else { -#ifdef ENABLE_NEON - for (; index <= size - 4; index += C4NUM) { - float32x4_t vin0 = vld1q_f32(in0 + index); - float32x4_t vout = vaddq_f32(vin0, vin1_opt); - vst1q_f32(out + index, vout); +#ifdef ENABLE_AVX + for (; index <= size - C8NUM; index += C8NUM) { + MS_FLOAT32X8 vin0 = MS_LD256_F32(in0 + index); + MS_FLOAT32X8 vout = MS_ADD256_F32(vin0, vin1_opt_8); + MS_ST256_F32(out + index, vout); + } +#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + for (; index <= size - C4NUM; index += C4NUM) { + MS_FLOAT32X4 vin0 = MS_LDQ_F32(in0 + index); + MS_FLOAT32X4 vout = MS_ADDQ_F32(vin0, vin1_opt); + MS_STQ_F32(out + index, vout); } #endif for (; index < size; index++) { @@ -50,28 +68,46 @@ int ElementOptAdd(const float *in0, const float *in1, float *out, int size, cons } int ElementOptAddInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param) { -#ifdef ENABLE_NEON - int32x4_t vin0_opt = vdupq_n_s32(in0[0]); - int32x4_t vin1_opt = vdupq_n_s32(in1[0]); +#ifdef ENABLE_AVX + MS_INT32X8 vin0_opt_8 = MS_MOV256_EPI32(in0[0]); + MS_INT32X8 vin1_opt_8 = MS_MOV256_EPI32(in1[0]); +#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_INT32X4 vin0_opt = MS_MOVQ_EPI32(in0[0]); + MS_INT32X4 vin1_opt = MS_MOVQ_EPI32(in1[0]); #endif int index = 0; if (param->in_elements_num0_ == 1) { -#ifdef ENABLE_NEON - for (; index <= size - 4; index += C4NUM) { - int32x4_t vin1 = vld1q_s32(in1 + index); - int32x4_t vout = vaddq_s32(vin0_opt, vin1); - vst1q_s32(out + index, vout); +#ifdef ENABLE_AVX + for (; index <= size - C8NUM; index += C8NUM) { + MS_INT32X8 vin1 = MS_LD256_EPI32(in1 + index); + MS_INT32X8 vout = MS_ADD256_EPI32(vin0_opt_8, vin1); + MS_ST256_EPI32(out + index, vout); + } 
+#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + for (; index <= size - C4NUM; index += C4NUM) { + MS_INT32X4 vin1 = MS_LDQ_EPI32(in1 + index); + MS_INT32X4 vout = MS_ADDQ_EPI32(vin0_opt, vin1); + MS_STQ_EPI32(out + index, vout); } #endif for (; index < size; index++) { out[index] = in0[0] + in1[index]; } } else { -#ifdef ENABLE_NEON - for (; index <= size - 4; index += C4NUM) { - int32x4_t vin0 = vld1q_s32(in0 + index); - int32x4_t vout = vaddq_s32(vin0, vin1_opt); - vst1q_s32(out + index, vout); +#ifdef ENABLE_AVX + for (; index <= size - C8NUM; index += C8NUM) { + MS_INT32X8 vin0 = MS_LD256_EPI32(in0 + index); + MS_INT32X8 vout = MS_ADD256_EPI32(vin0, vin1_opt_8); + MS_ST256_EPI32(out + index, vout); + } +#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + for (; index <= size - C4NUM; index += C4NUM) { + MS_INT32X4 vin0 = MS_LDQ_EPI32(in0 + index); + MS_INT32X4 vout = MS_ADDQ_EPI32(vin0, vin1_opt); + MS_STQ_EPI32(out + index, vout); } #endif for (; index < size; index++) { @@ -82,29 +118,48 @@ int ElementOptAddInt(const int *in0, const int *in1, int *out, int size, const A } int ElementOptAddRelu(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param) { -#ifdef ENABLE_NEON - float32x4_t vin0_opt = vdupq_n_f32(in0[0]); - float32x4_t vin1_opt = vdupq_n_f32(in1[0]); - float32x4_t zeros = vdupq_n_f32(0.0f); +#ifdef ENABLE_AVX + MS_FLOAT32X8 vin0_opt_8 = MS_MOV256_F32(in0[0]); + MS_FLOAT32X8 vin1_opt_8 = MS_MOV256_F32(in1[0]); + MS_FLOAT32X8 zeros_8 = MS_MOV256_F32(0.0f); +#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_FLOAT32X4 vin0_opt = MS_MOVQ_F32(in0[0]); + MS_FLOAT32X4 vin1_opt = MS_MOVQ_F32(in1[0]); + MS_FLOAT32X4 zeros = MS_MOVQ_F32(0.0f); #endif int index = 0; if (param->in_elements_num0_ == 1) { -#ifdef ENABLE_NEON - for (; index <= size - 4; index += C4NUM) { - float32x4_t vin1 = vld1q_f32(in1 + index); - float32x4_t vout = vmaxq_f32(vaddq_f32(vin0_opt, vin1), zeros); - vst1q_f32(out + 
index, vout); +#ifdef ENABLE_AVX + for (; index <= size - C8NUM; index += C8NUM) { + MS_FLOAT32X8 vin1 = MS_LD256_F32(in1 + index); + MS_FLOAT32X8 vout = MS_MAX256_F32(MS_ADD256_F32(vin0_opt_8, vin1), zeros_8); + MS_ST256_F32(out + index, vout); + } +#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + for (; index <= size - C4NUM; index += C4NUM) { + MS_FLOAT32X4 vin1 = MS_LDQ_F32(in1 + index); + MS_FLOAT32X4 vout = MS_MAXQ_F32(MS_ADDQ_F32(vin0_opt, vin1), zeros); + MS_STQ_F32(out + index, vout); } #endif for (; index < size; index++) { out[index] = MSMAX(in0[0] + in1[index], 0); } } else { -#ifdef ENABLE_NEON - for (; index <= size - 4; index += C4NUM) { - float32x4_t vin0 = vld1q_f32(in0 + index); - float32x4_t vout = vmaxq_f32(vaddq_f32(vin0, vin1_opt), zeros); - vst1q_f32(out + index, vout); +#ifdef ENABLE_AVX + for (; index <= size - C8NUM; index += C8NUM) { + MS_FLOAT32X8 vin0 = MS_LD256_F32(in0 + index); + MS_FLOAT32X8 vout = MS_MAX256_F32(MS_ADD256_F32(vin0, vin1_opt_8), zeros_8); + MS_ST256_F32(out + index, vout); + } +#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + for (; index <= size - C4NUM; index += C4NUM) { + MS_FLOAT32X4 vin0 = MS_LDQ_F32(in0 + index); + MS_FLOAT32X4 vout = MS_MAXQ_F32(MS_ADDQ_F32(vin0, vin1_opt), zeros); + MS_STQ_F32(out + index, vout); } #endif for (; index < size; index++) { @@ -115,30 +170,50 @@ int ElementOptAddRelu(const float *in0, const float *in1, float *out, int size, } int ElementOptAddRelu6(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param) { -#ifdef ENABLE_NEON - float32x4_t vin0_opt = vdupq_n_f32(in0[0]); - float32x4_t vin1_opt = vdupq_n_f32(in1[0]); - float32x4_t zeros = vdupq_n_f32(0.0f); - float32x4_t bounds = vdupq_n_f32(6.0f); +#ifdef ENABLE_AVX + MS_FLOAT32X8 vin0_opt_8 = MS_MOV256_F32(in0[0]); + MS_FLOAT32X8 vin1_opt_8 = MS_MOV256_F32(in1[0]); + MS_FLOAT32X8 zeros_8 = MS_MOV256_F32(0.0f); + MS_FLOAT32X8 bounds_8 = MS_MOV256_F32(6.0f); +#endif +#if 
defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_FLOAT32X4 vin0_opt = MS_MOVQ_F32(in0[0]); + MS_FLOAT32X4 vin1_opt = MS_MOVQ_F32(in1[0]); + MS_FLOAT32X4 zeros = MS_MOVQ_F32(0.0f); + MS_FLOAT32X4 bounds = MS_MOVQ_F32(6.0f); #endif int index = 0; if (param->in_elements_num0_ == 1) { -#ifdef ENABLE_NEON - for (; index <= size - 4; index += C4NUM) { - float32x4_t vin1 = vld1q_f32(in1 + index); - float32x4_t vout = vminq_f32(vmaxq_f32(vaddq_f32(vin0_opt, vin1), zeros), bounds); - vst1q_f32(out + index, vout); +#ifdef ENABLE_AVX + for (; index <= size - C8NUM; index += C8NUM) { + MS_FLOAT32X8 vin1 = MS_LD256_F32(in1 + index); + MS_FLOAT32X8 vout = MS_MIN256_F32(MS_MAX256_F32(MS_ADD256_F32(vin0_opt_8, vin1), zeros_8), bounds_8); + MS_ST256_F32(out + index, vout); + } +#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + for (; index <= size - C4NUM; index += C4NUM) { + MS_FLOAT32X4 vin1 = MS_LDQ_F32(in1 + index); + MS_FLOAT32X4 vout = MS_MINQ_F32(MS_MAXQ_F32(MS_ADDQ_F32(vin0_opt, vin1), zeros), bounds); + MS_STQ_F32(out + index, vout); } #endif for (; index < size; index++) { out[index] = MSMIN(MSMAX(in0[0] + in1[index], 0), 6); } } else { -#ifdef ENABLE_NEON - for (; index <= size - 4; index += C4NUM) { - float32x4_t vin0 = vld1q_f32(in0 + index); - float32x4_t vout = vminq_f32(vmaxq_f32(vaddq_f32(vin0, vin1_opt), zeros), bounds); - vst1q_f32(out + index, vout); +#ifdef ENABLE_AVX + for (; index <= size - C8NUM; index += C8NUM) { + MS_FLOAT32X8 vin0 = MS_LD256_F32(in0 + index); + MS_FLOAT32X8 vout = MS_MIN256_F32(MS_MAX256_F32(MS_ADD256_F32(vin0, vin1_opt_8), zeros_8), bounds_8); + MS_ST256_F32(out + index, vout); + } +#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + for (; index <= size - C4NUM; index += C4NUM) { + MS_FLOAT32X4 vin0 = MS_LDQ_F32(in0 + index); + MS_FLOAT32X4 vout = MS_MINQ_F32(MS_MAXQ_F32(MS_ADDQ_F32(vin0, vin1_opt), zeros), bounds); + MS_STQ_F32(out + index, vout); } #endif for (; index < size; index++) { @@ -157,12 +232,20 @@ int 
BroadcastAdd(const float *in0, const float *in1, float *tile_in0, float *til int ElementAdd(const float *in0, const float *in1, float *out, int size) { int index = 0; -#ifdef ENABLE_NEON - for (; index <= size - 4; index += C4NUM) { - float32x4_t vin0 = vld1q_f32(in0 + index); - float32x4_t vin1 = vld1q_f32(in1 + index); - float32x4_t vout = vaddq_f32(vin0, vin1); - vst1q_f32(out + index, vout); +#ifdef ENABLE_AVX + for (; index <= size - C8NUM; index += C8NUM) { + MS_FLOAT32X8 vin0 = MS_LD256_F32(in0 + index); + MS_FLOAT32X8 vin1 = MS_LD256_F32(in1 + index); + MS_FLOAT32X8 vout = MS_ADD256_F32(vin0, vin1); + MS_ST256_F32(out + index, vout); + } +#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + for (; index <= size - C4NUM; index += C4NUM) { + MS_FLOAT32X4 vin0 = MS_LDQ_F32(in0 + index); + MS_FLOAT32X4 vin1 = MS_LDQ_F32(in1 + index); + MS_FLOAT32X4 vout = MS_ADDQ_F32(vin0, vin1); + MS_STQ_F32(out + index, vout); } #endif for (; index < size; index++) { @@ -173,14 +256,24 @@ int ElementAdd(const float *in0, const float *in1, float *out, int size) { int ElementAddRelu(const float *in0, const float *in1, float *out, int size) { int index = 0; -#ifdef ENABLE_NEON - float32x4_t zeros = vdupq_n_f32(0.0f); - for (; index <= size - 4; index += C4NUM) { - float32x4_t vin0 = vld1q_f32(in0 + index); - float32x4_t vin1 = vld1q_f32(in1 + index); - float32x4_t vout = vaddq_f32(vin0, vin1); - vout = vbslq_f32(vcgtq_f32(vout, zeros), vout, zeros); - vst1q_f32(out + index, vout); +#ifdef ENABLE_AVX + MS_FLOAT32X8 zeros_8 = MS_MOV256_F32(0.0f); + for (; index <= size - C8NUM; index += C8NUM) { + MS_FLOAT32X8 vin0 = MS_LD256_F32(in0 + index); + MS_FLOAT32X8 vin1 = MS_LD256_F32(in1 + index); + MS_FLOAT32X8 vout = MS_ADD256_F32(vin0, vin1); + vout = MS_BLEND256_PS(zeros_8, vout, MS_CMP256_PS(vout, zeros_8, 30)); + MS_ST256_F32(out + index, vout); + } +#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_FLOAT32X4 zeros = MS_MOVQ_F32(0.0f); + for (; index <= size - 
C4NUM; index += C4NUM) { + MS_FLOAT32X4 vin0 = MS_LDQ_F32(in0 + index); + MS_FLOAT32X4 vin1 = MS_LDQ_F32(in1 + index); + MS_FLOAT32X4 vout = MS_ADDQ_F32(vin0, vin1); + vout = MS_BLENDQ_PS(zeros, vout, MS_CMPGTQ_PS(vout, zeros)); + MS_STQ_F32(out + index, vout); } #endif for (; index < size; index++) { @@ -192,14 +285,24 @@ int ElementAddRelu(const float *in0, const float *in1, float *out, int size) { int ElementAddRelu6(const float *in0, const float *in1, float *out, int size) { int index = 0; -#ifdef ENABLE_NEON - float32x4_t zeros = vdupq_n_f32(0.0f); - float32x4_t bounds = vdupq_n_f32(6.0f); - for (; index <= size - 4; index += C4NUM) { - float32x4_t vin0 = vld1q_f32(in0 + index); - float32x4_t vin1 = vld1q_f32(in1 + index); - float32x4_t vout = vminq_f32(vmaxq_f32(vaddq_f32(vin0, vin1), zeros), bounds); - vst1q_f32(out + index, vout); +#ifdef ENABLE_AVX + MS_FLOAT32X8 zeros_8 = MS_MOV256_F32(0.0f); + MS_FLOAT32X8 bounds_8 = MS_MOV256_F32(6.0f); + for (; index <= size - C8NUM; index += C8NUM) { + MS_FLOAT32X8 vin0 = MS_LD256_F32(in0 + index); + MS_FLOAT32X8 vin1 = MS_LD256_F32(in1 + index); + MS_FLOAT32X8 vout = MS_MIN256_F32(MS_MAX256_F32(MS_ADD256_F32(vin0, vin1), zeros_8), bounds_8); + MS_ST256_F32(out + index, vout); + } +#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_FLOAT32X4 zeros = MS_MOVQ_F32(0.0f); + MS_FLOAT32X4 bounds = MS_MOVQ_F32(6.0f); + for (; index <= size - C4NUM; index += C4NUM) { + MS_FLOAT32X4 vin0 = MS_LDQ_F32(in0 + index); + MS_FLOAT32X4 vin1 = MS_LDQ_F32(in1 + index); + MS_FLOAT32X4 vout = MS_MINQ_F32(MS_MAXQ_F32(MS_ADDQ_F32(vin0, vin1), zeros), bounds); + MS_STQ_F32(out + index, vout); } #endif for (; index < size; index++) { @@ -210,12 +313,20 @@ int ElementAddRelu6(const float *in0, const float *in1, float *out, int size) { int ElementAddInt(const int *in0, const int *in1, int *out, int size) { int index = 0; -#ifdef ENABLE_NEON - for (; index <= size - 4; index += C4NUM) { - int32x4_t vin0 = vld1q_s32(in0 + index); - 
int32x4_t vin1 = vld1q_s32(in1 + index); - int32x4_t vout = vaddq_s32(vin0, vin1); - vst1q_s32(out + index, vout); +#ifdef ENABLE_AVX + for (; index <= size - C8NUM; index += C8NUM) { + MS_INT32X8 vin0 = MS_LD256_EPI32(in0 + index); + MS_INT32X8 vin1 = MS_LD256_EPI32(in1 + index); + MS_INT32X8 vout = MS_ADD256_EPI32(vin0, vin1); + MS_ST256_EPI32(out + index, vout); + } +#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + for (; index <= size - C4NUM; index += C4NUM) { + MS_INT32X4 vin0 = MS_LDQ_EPI32(in0 + index); + MS_INT32X4 vin1 = MS_LDQ_EPI32(in1 + index); + MS_INT32X4 vout = MS_ADDQ_EPI32(vin0, vin1); + MS_STQ_EPI32(out + index, vout); } #endif for (; index < size; index++) { diff --git a/mindspore/lite/nnacl/fp32/exp_fp32.h b/mindspore/lite/nnacl/fp32/exp_fp32.h index 5c4ea5e41b..09ae4c5afb 100644 --- a/mindspore/lite/nnacl/fp32/exp_fp32.h +++ b/mindspore/lite/nnacl/fp32/exp_fp32.h @@ -38,26 +38,49 @@ extern "C" { int Exp(const float *input_data, float *output_data, const ExpParameter *parameter, int task_id); void ExpFp32(const float *src, float *dst, int num); -#ifdef ENABLE_ARM64 -static inline void simd_exp(float32x4_t input4, float *dst) { - static float32x4_t maxv = {88.0f, 88.0f, 88.0f, 88.0f}; - static float32x4_t minv = {-88.0f, -88.0f, -88.0f, -88.0f}; - static float32x4_t paramv[] = {{0.693147f, 0.693147f, 0.693147f, 0.693147f}, +#if defined(ENABLE_ARM64) || defined(ENABLE_SSE) +static inline void simd_exp(MS_FLOAT32X4 input, float *dst) { + static MS_FLOAT32X4 maxv = {88.0f, 88.0f, 88.0f, 88.0f}; + static MS_FLOAT32X4 minv = {-88.0f, -88.0f, -88.0f, -88.0f}; + static MS_FLOAT32X4 param[] = {{0.693147f, 0.693147f, 0.693147f, 0.693147f}, {1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120}, {1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24}, {1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6}, {0.5f, 0.5f, 0.5f, 0.5f}, {1.0f, 1.0f, 1.0f, 1.0f}}; - input4 = vmaxq_f32(minv, vminq_f32(maxv, input4)); - int32x4_t integer4 = vcvtq_s32_f32(vdivq_f32(input4, paramv[0])); 
- float32x4_t decimal4 = vsubq_f32(input4, vmulq_f32(vcvtq_f32_s32(integer4), paramv[0])); - int32x4_t int_exp4 = vshlq_s32(vaddq_s32(integer4, vdupq_n_s32(127)), vdupq_n_s32(23)); - vst1q_f32(dst, vld1q_f32((float32_t *)(&int_exp4))); - float32x4_t decimal_exp4 = vaddq_f32(paramv[2], vmulq_f32(decimal4, paramv[1])); - decimal_exp4 = vmulq_f32(decimal4, vaddq_f32(paramv[3], vmulq_f32(decimal4, decimal_exp4))); - decimal_exp4 = vaddq_f32(paramv[5], vmulq_f32(decimal4, vaddq_f32(paramv[4], decimal_exp4))); - decimal_exp4 = vaddq_f32(paramv[5], vmulq_f32(decimal4, decimal_exp4)); - vst1q_f32(dst, vmulq_f32(vld1q_f32(dst), decimal_exp4)); + + input = MS_MAXQ_F32(minv, MS_MINQ_F32(input, maxv)); + MS_INT32X4 integer = MS_CVTQPS_EPI32(input / param[0]); + MS_FLOAT32X4 decimal = input - MS_CVTQEPI32_PS(integer) * param[0]; + MS_INT32X4 int_exp = MS_SLLIQ_EPI32(MS_ADDQ_EPI32(integer, MS_MOVQ_EPI32(127)), 23); + memcpy(dst, &int_exp, sizeof(int32_t) * 4); + MS_FLOAT32X4 decimal_exp = + param[5] + + decimal * (param[5] + decimal * (param[4] + decimal * (param[3] + decimal * (param[2] + decimal * param[1])))); + MS_STQ_F32(dst, decimal_exp * MS_LDQ_F32(dst)); +} +#endif + +#if defined(ENABLE_AVX) +static inline void simd_exp_avx(MS_FLOAT32X8 input, float *dst) { + static MS_FLOAT32X8 maxv = {88.0f, 88.0f, 88.0f, 88.0f, 88.0f, 88.0f, 88.0f, 88.0f}; + static MS_FLOAT32X8 minv = {-88.0f, -88.0f, -88.0f, -88.0f, -88.0f, -88.0f, -88.0f, -88.0f}; + static MS_FLOAT32X8 param[] = { + {0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f}, + {1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120}, + {1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24}, + {1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6}, + {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}, + {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}}; + input = MS_MAX256_F32(minv, 
MS_MIN256_F32(input, maxv)); + MS_INT32X8 integer = MS_CVT256PS_EPI32(input / param[0]); + MS_FLOAT32X8 decimal = input - MS_CVT256EPI32_PS(integer) * param[0]; + MS_INT32X8 int_exp = MS_SLLI256_EPI32(MS_ADD256_EPI32(integer, MS_MOV256_EPI32(127)), 23); + memcpy(dst, &int_exp, sizeof(int32_t) * 8); + MS_FLOAT32X8 decimal_exp = + param[5] + + decimal * (param[5] + decimal * (param[4] + decimal * (param[3] + decimal * (param[2] + decimal * param[1])))); + MS_ST256_F32(dst, decimal_exp * MS_LD256_F32(dst)); } #endif diff --git a/mindspore/lite/nnacl/fp32/mul_fp32.c b/mindspore/lite/nnacl/fp32/mul_fp32.c index 4fea298205..9ccb86c72f 100644 --- a/mindspore/lite/nnacl/fp32/mul_fp32.c +++ b/mindspore/lite/nnacl/fp32/mul_fp32.c @@ -24,12 +24,20 @@ int BroadcastMul(const float *in0, const float *in1, float *tile_in0, float *til int ElementMul(const float *in0, const float *in1, float *out, int size) { int index = 0; -#ifdef ENABLE_NEON - for (; index <= size - 4; index += C4NUM) { - float32x4_t vin0 = vld1q_f32(in0 + index); - float32x4_t vin1 = vld1q_f32(in1 + index); - float32x4_t vout = vmulq_f32(vin0, vin1); - vst1q_f32(out + index, vout); +#if defined(ENABLE_AVX) + for (; index <= size - C8NUM; index += C8NUM) { + MS_FLOAT32X8 vin0 = MS_LD256_F32(in0 + index); + MS_FLOAT32X8 vin1 = MS_LD256_F32(in1 + index); + MS_FLOAT32X8 vout = MS_MUL256_F32(vin0, vin1); + MS_ST256_F32(out + index, vout); + } +#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + for (; index <= size - C4NUM; index += C4NUM) { + MS_FLOAT32X4 vin0 = MS_LDQ_F32(in0 + index); + MS_FLOAT32X4 vin1 = MS_LDQ_F32(in1 + index); + MS_FLOAT32X4 vout = MS_MULQ_F32(vin0, vin1); + MS_STQ_F32(out + index, vout); } #endif for (; index < size; index++) { @@ -40,14 +48,24 @@ int ElementMul(const float *in0, const float *in1, float *out, int size) { int ElementMulRelu(const float *in0, const float *in1, float *out, int size) { int index = 0; -#ifdef ENABLE_NEON - float32x4_t zeros = vdupq_n_f32(0.0f); - for (; 
index <= size - 4; index += C4NUM) { - float32x4_t vin0 = vld1q_f32(in0 + index); - float32x4_t vin1 = vld1q_f32(in1 + index); - float32x4_t vout = vmulq_f32(vin0, vin1); - vout = vbslq_f32(vcgtq_f32(vout, zeros), vout, zeros); - vst1q_f32(out + index, vout); +#if defined(ENABLE_AVX) + MS_FLOAT32X8 zeros_8 = MS_MOV256_F32(0.0f); + for (; index <= size - C8NUM; index += C8NUM) { + MS_FLOAT32X8 vin0 = MS_LD256_F32(in0 + index); + MS_FLOAT32X8 vin1 = MS_LD256_F32(in1 + index); + MS_FLOAT32X8 vout = MS_MUL256_F32(vin0, vin1); + vout = MS_BLEND256_PS(zeros_8, vout, MS_CMP256_PS(vout, zeros_8, 30)); + MS_ST256_F32(out + index, vout); + } +#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_FLOAT32X4 zeros = MS_MOVQ_F32(0.0f); + for (; index <= size - C4NUM; index += C4NUM) { + MS_FLOAT32X4 vin0 = MS_LDQ_F32(in0 + index); + MS_FLOAT32X4 vin1 = MS_LDQ_F32(in1 + index); + MS_FLOAT32X4 vout = MS_MULQ_F32(vin0, vin1); + vout = MS_BLENDQ_PS(zeros, vout, MS_CMPGTQ_PS(vout, zeros)); + MS_STQ_F32(out + index, vout); } #endif for (; index < size; index++) { @@ -59,14 +77,24 @@ int ElementMulRelu(const float *in0, const float *in1, float *out, int size) { int ElementMulRelu6(const float *in0, const float *in1, float *out, int size) { int index = 0; -#ifdef ENABLE_NEON - float32x4_t zeros = vdupq_n_f32(0.0f); - float32x4_t bounds = vdupq_n_f32(6.0f); - for (; index <= size - 4; index += C4NUM) { - float32x4_t vin0 = vld1q_f32(in0 + index); - float32x4_t vin1 = vld1q_f32(in1 + index); - float32x4_t vout = vminq_f32(vmaxq_f32(vmulq_f32(vin0, vin1), zeros), bounds); - vst1q_f32(out + index, vout); +#if defined(ENABLE_AVX) + MS_FLOAT32X8 zeros_8 = MS_MOV256_F32(0.0f); + MS_FLOAT32X8 bounds_8 = MS_MOV256_F32(6.0f); + for (; index <= size - C8NUM; index += C8NUM) { + MS_FLOAT32X8 vin0 = MS_LD256_F32(in0 + index); + MS_FLOAT32X8 vin1 = MS_LD256_F32(in1 + index); + MS_FLOAT32X8 vout = MS_MIN256_F32(MS_MAX256_F32(MS_MUL256_F32(vin0, vin1), zeros_8), bounds_8); + MS_ST256_F32(out + 
index, vout); + } +#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_FLOAT32X4 zeros = MS_MOVQ_F32(0.0f); + MS_FLOAT32X4 bounds = MS_MOVQ_F32(6.0f); + for (; index <= size - C4NUM; index += C4NUM) { + MS_FLOAT32X4 vin0 = MS_LDQ_F32(in0 + index); + MS_FLOAT32X4 vin1 = MS_LDQ_F32(in1 + index); + MS_FLOAT32X4 vout = MS_MINQ_F32(MS_MAXQ_F32(MS_MULQ_F32(vin0, vin1), zeros), bounds); + MS_STQ_F32(out + index, vout); } #endif for (; index < size; index++) { @@ -77,12 +105,20 @@ int ElementMulRelu6(const float *in0, const float *in1, float *out, int size) { int ElementMulInt(const int *in0, const int *in1, int *out, int size) { int index = 0; -#ifdef ENABLE_NEON - for (; index <= size - 4; index += C4NUM) { - int32x4_t vin0 = vld1q_s32(in0 + index); - int32x4_t vin1 = vld1q_s32(in1 + index); - int32x4_t vout = vmulq_s32(vin0, vin1); - vst1q_s32(out + index, vout); +#if defined(ENABLE_AVX) + for (; index <= size - C8NUM; index += C8NUM) { + MS_INT32X8 vin0 = MS_LD256_EPI32(in0 + index); + MS_INT32X8 vin1 = MS_LD256_EPI32(in1 + index); + MS_INT32X8 vout = MS_MUL256_EPI32(vin0, vin1); + MS_ST256_EPI32(out + index, vout); + } +#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + for (; index <= size - C4NUM; index += C4NUM) { + MS_INT32X4 vin0 = MS_LDQ_EPI32(in0 + index); + MS_INT32X4 vin1 = MS_LDQ_EPI32(in1 + index); + MS_INT32X4 vout = MS_MULQ_EPI32(vin0, vin1); + MS_STQ_EPI32(out + index, vout); } #endif for (; index < size; index++) { @@ -93,18 +129,28 @@ int ElementMulInt(const int *in0, const int *in1, int *out, int size) { int ElementMulReluInt(const int *in0, const int *in1, int *out, int size) { int index = 0; -#ifdef ENABLE_NEON - int32x4_t zeros = vdupq_n_s32(0); - for (; index <= size - 4; index += C4NUM) { - int32x4_t vin0 = vld1q_s32(in0 + index); - int32x4_t vin1 = vld1q_s32(in1 + index); - int32x4_t vout = vmulq_s32(vin0, vin1); - vout = vbslq_s32(vcgtq_s32(vout, zeros), vout, zeros); - vst1q_s32(out + index, vout); +#if defined(ENABLE_AVX) + 
MS_INT32X8 zeros_8 = MS_MOV256_EPI32(0); + for (; index <= size - C8NUM; index += C8NUM) { + MS_INT32X8 vin0 = MS_LD256_EPI32(in0 + index); + MS_INT32X8 vin1 = MS_LD256_EPI32(in1 + index); + MS_INT32X8 vout = MS_MUL256_EPI32(vin0, vin1); + vout = MS_BLEND256_EPI32(zeros_8, vout, MS_CMPGT256_EPI32(vout, zeros_8)); + MS_ST256_EPI32(out + index, vout); + } +#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_INT32X4 zeros = MS_MOVQ_EPI32(0); + for (; index <= size - C4NUM; index += C4NUM) { + MS_INT32X4 vin0 = MS_LDQ_EPI32(in0 + index); + MS_INT32X4 vin1 = MS_LDQ_EPI32(in1 + index); + MS_INT32X4 vout = MS_MULQ_EPI32(vin0, vin1); + vout = MS_BLENDQ_EPI32(zeros, vout, MS_CMPGTQ_EPI32(vout, zeros)); + MS_STQ_EPI32(out + index, vout); } #endif for (; index < size; index++) { - float res = in0[index] * in1[index]; + int res = in0[index] * in1[index]; out[index] = res > 0 ? res : 0; } return NNACL_OK; @@ -112,14 +158,24 @@ int ElementMulReluInt(const int *in0, const int *in1, int *out, int size) { int ElementMulRelu6Int(const int *in0, const int *in1, int *out, int size) { int index = 0; -#ifdef ENABLE_NEON - int32x4_t zeros = vdupq_n_s32(0); - int32x4_t bounds = vdupq_n_s32(6); - for (; index <= size - 4; index += C4NUM) { - int32x4_t vin0 = vld1q_s32(in0 + index); - int32x4_t vin1 = vld1q_s32(in1 + index); - int32x4_t vout = vminq_s32(vmaxq_s32(vmulq_s32(vin0, vin1), zeros), bounds); - vst1q_s32(out + index, vout); +#if defined(ENABLE_AVX) + MS_INT32X8 zeros_8 = MS_MOV256_EPI32(0); + MS_INT32X8 bounds_8 = MS_MOV256_EPI32(6); + for (; index <= size - C8NUM; index += C8NUM) { + MS_INT32X8 vin0 = MS_LD256_EPI32(in0 + index); + MS_INT32X8 vin1 = MS_LD256_EPI32(in1 + index); + MS_INT32X8 vout = MS_MIN256_EPI32(MS_MAX256_EPI32(MS_MUL256_EPI32(vin0, vin1), zeros_8), bounds_8); + MS_ST256_EPI32(out + index, vout); + } +#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_INT32X4 zeros = MS_MOVQ_EPI32(0); + MS_INT32X4 bounds = MS_MOVQ_EPI32(6); + for (; index <= 
size - C4NUM; index += C4NUM) { + MS_INT32X4 vin0 = MS_LDQ_EPI32(in0 + index); + MS_INT32X4 vin1 = MS_LDQ_EPI32(in1 + index); + MS_INT32X4 vout = MS_MINQ_EPI32(MS_MAXQ_EPI32(MS_MULQ_EPI32(vin0, vin1), zeros), bounds); + MS_STQ_EPI32(out + index, vout); } #endif for (; index < size; index++) { @@ -129,28 +185,42 @@ int ElementMulRelu6Int(const int *in0, const int *in1, int *out, int size) { } int ElementOptMul(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param) { -#ifdef ENABLE_NEON - float32x4_t vin0_opt = vdupq_n_f32(in0[0]); - float32x4_t vin1_opt = vdupq_n_f32(in1[0]); -#endif int index = 0; if (param->in_elements_num0_ == 1) { -#ifdef ENABLE_NEON - for (; index <= size - 4; index += C4NUM) { - float32x4_t vin1 = vld1q_f32(in1 + index); - float32x4_t vout = vmulq_f32(vin0_opt, vin1); - vst1q_f32(out + index, vout); +#if defined(ENABLE_AVX) + MS_FLOAT32X8 vin0_opt_8 = MS_MOV256_F32(in0[0]); + for (; index <= size - C8NUM; index += C8NUM) { + MS_FLOAT32X8 vin1 = MS_LD256_F32(in1 + index); + MS_FLOAT32X8 vout = MS_MUL256_F32(vin0_opt_8, vin1); + MS_ST256_F32(out + index, vout); + } +#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_FLOAT32X4 vin0_opt = MS_MOVQ_F32(in0[0]); + for (; index <= size - C4NUM; index += C4NUM) { + MS_FLOAT32X4 vin1 = MS_LDQ_F32(in1 + index); + MS_FLOAT32X4 vout = MS_MULQ_F32(vin0_opt, vin1); + MS_STQ_F32(out + index, vout); } #endif for (; index < size; index++) { out[index] = in0[0] * in1[index]; } } else { -#ifdef ENABLE_NEON - for (; index <= size - 4; index += C4NUM) { - float32x4_t vin0 = vld1q_f32(in0 + index); - float32x4_t vout = vmulq_f32(vin0, vin1_opt); - vst1q_f32(out + index, vout); +#if defined(ENABLE_AVX) + MS_FLOAT32X8 vin1_opt_8 = MS_MOV256_F32(in1[0]); + for (; index <= size - C8NUM; index += C8NUM) { + MS_FLOAT32X8 vin0 = MS_LD256_F32(in0 + index); + MS_FLOAT32X8 vout = MS_MUL256_F32(vin0, vin1_opt_8); + MS_ST256_F32(out + index, vout); + } +#endif +#if 
defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_FLOAT32X4 vin1_opt = MS_MOVQ_F32(in1[0]); + for (; index <= size - C4NUM; index += C4NUM) { + MS_FLOAT32X4 vin0 = MS_LDQ_F32(in0 + index); + MS_FLOAT32X4 vout = MS_MULQ_F32(vin0, vin1_opt); + MS_STQ_F32(out + index, vout); } #endif for (; index < size; index++) { @@ -161,29 +231,46 @@ int ElementOptMul(const float *in0, const float *in1, float *out, int size, cons } int ElementOptMulRelu(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param) { -#ifdef ENABLE_NEON - float32x4_t vin0_opt = vdupq_n_f32(in0[0]); - float32x4_t vin1_opt = vdupq_n_f32(in1[0]); - float32x4_t zeros = vdupq_n_f32(0.0f); -#endif int index = 0; if (param->in_elements_num0_ == 1) { -#ifdef ENABLE_NEON - for (; index <= size - 4; index += C4NUM) { - float32x4_t vin1 = vld1q_f32(in1 + index); - float32x4_t vout = vmaxq_f32(vmulq_f32(vin0_opt, vin1), zeros); - vst1q_f32(out + index, vout); +#if defined(ENABLE_AVX) + MS_FLOAT32X8 vin0_opt_8 = MS_MOV256_F32(in0[0]); + MS_FLOAT32X8 zeros_8 = MS_MOV256_F32(0.0f); + for (; index <= size - C8NUM; index += C8NUM) { + MS_FLOAT32X8 vin1 = MS_LD256_F32(in1 + index); + MS_FLOAT32X8 vout = MS_MAX256_F32(MS_MUL256_F32(vin0_opt_8, vin1), zeros_8); + MS_ST256_F32(out + index, vout); + } +#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_FLOAT32X4 vin0_opt = MS_MOVQ_F32(in0[0]); + MS_FLOAT32X4 zeros = MS_MOVQ_F32(0.0f); + for (; index <= size - C4NUM; index += C4NUM) { + MS_FLOAT32X4 vin1 = MS_LDQ_F32(in1 + index); + MS_FLOAT32X4 vout = MS_MAXQ_F32(MS_MULQ_F32(vin0_opt, vin1), zeros); + MS_STQ_F32(out + index, vout); } #endif for (; index < size; index++) { out[index] = MSMAX(in0[0] * in1[index], 0); } } else { -#ifdef ENABLE_NEON - for (; index <= size - 4; index += C4NUM) { - float32x4_t vin0 = vld1q_f32(in0 + index); - float32x4_t vout = vmaxq_f32(vmulq_f32(vin0, vin1_opt), zeros); - vst1q_f32(out + index, vout); +#if defined(ENABLE_AVX) + MS_FLOAT32X8 vin1_opt_8 = 
MS_MOV256_F32(in1[0]); + MS_FLOAT32X8 zeros_8 = MS_MOV256_F32(0.0f); + for (; index <= size - C8NUM; index += C8NUM) { + MS_FLOAT32X8 vin0 = MS_LD256_F32(in0 + index); + MS_FLOAT32X8 vout = MS_MAX256_F32(MS_MUL256_F32(vin0, vin1_opt_8), zeros_8); + MS_ST256_F32(out + index, vout); + } +#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_FLOAT32X4 vin1_opt = MS_MOVQ_F32(in1[0]); + MS_FLOAT32X4 zeros = MS_MOVQ_F32(0.0f); + for (; index <= size - C4NUM; index += C4NUM) { + MS_FLOAT32X4 vin0 = MS_LDQ_F32(in0 + index); + MS_FLOAT32X4 vout = MS_MAXQ_F32(MS_MULQ_F32(vin0, vin1_opt), zeros); + MS_STQ_F32(out + index, vout); } #endif for (; index < size; index++) { @@ -194,30 +281,50 @@ int ElementOptMulRelu(const float *in0, const float *in1, float *out, int size, } int ElementOptMulRelu6(const float *in0, const float *in1, float *out, int size, const ArithmeticParameter *param) { -#ifdef ENABLE_NEON - float32x4_t vin0_opt = vdupq_n_f32(in0[0]); - float32x4_t vin1_opt = vdupq_n_f32(in1[0]); - float32x4_t zeros = vdupq_n_f32(0.0f); - float32x4_t bounds = vdupq_n_f32(6.0f); -#endif int index = 0; if (param->in_elements_num0_ == 1) { -#ifdef ENABLE_NEON - for (; index <= size - 4; index += C4NUM) { - float32x4_t vin1 = vld1q_f32(in1 + index); - float32x4_t vout = vminq_f32(vmaxq_f32(vmulq_f32(vin0_opt, vin1), zeros), bounds); - vst1q_f32(out + index, vout); +#if defined(ENABLE_AVX) + MS_FLOAT32X8 vin0_opt_8 = MS_MOV256_F32(in0[0]); + MS_FLOAT32X8 zeros_8 = MS_MOV256_F32(0.0f); + MS_FLOAT32X8 bounds_8 = MS_MOV256_F32(6.0f); + for (; index <= size - C8NUM; index += C8NUM) { + MS_FLOAT32X8 vin1 = MS_LD256_F32(in1 + index); + MS_FLOAT32X8 vout = MS_MIN256_F32(MS_MAX256_F32(MS_MUL256_F32(vin0_opt_8, vin1), zeros_8), bounds_8); + MS_ST256_F32(out + index, vout); + } +#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_FLOAT32X4 vin0_opt = MS_MOVQ_F32(in0[0]); + MS_FLOAT32X4 zeros = MS_MOVQ_F32(0.0f); + MS_FLOAT32X4 bounds = MS_MOVQ_F32(6.0f); + for (; index <= 
size - C4NUM; index += C4NUM) { + MS_FLOAT32X4 vin1 = MS_LDQ_F32(in1 + index); + MS_FLOAT32X4 vout = MS_MINQ_F32(MS_MAXQ_F32(MS_MULQ_F32(vin0_opt, vin1), zeros), bounds); + MS_STQ_F32(out + index, vout); } #endif for (; index < size; index++) { out[index] = MSMIN(MSMAX(in0[0] * in1[index], 0), 6); } } else { -#ifdef ENABLE_NEON - for (; index <= size - 4; index += C4NUM) { - float32x4_t vin0 = vld1q_f32(in0 + index); - float32x4_t vout = vminq_f32(vmaxq_f32(vmulq_f32(vin0, vin1_opt), zeros), bounds); - vst1q_f32(out + index, vout); +#if defined(ENABLE_AVX) + MS_FLOAT32X8 vin1_opt_8 = MS_MOV256_F32(in1[0]); + MS_FLOAT32X8 zeros_8 = MS_MOV256_F32(0.0f); + MS_FLOAT32X8 bounds_8 = MS_MOV256_F32(6.0f); + for (; index <= size - C8NUM; index += C8NUM) { + MS_FLOAT32X8 vin0 = MS_LD256_F32(in0 + index); + MS_FLOAT32X8 vout = MS_MIN256_F32(MS_MAX256_F32(MS_MUL256_F32(vin0, vin1_opt_8), zeros_8), bounds_8); + MS_ST256_F32(out + index, vout); + } +#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_FLOAT32X4 vin1_opt = MS_MOVQ_F32(in1[0]); + MS_FLOAT32X4 zeros = MS_MOVQ_F32(0.0f); + MS_FLOAT32X4 bounds = MS_MOVQ_F32(6.0f); + for (; index <= size - C4NUM; index += C4NUM) { + MS_FLOAT32X4 vin0 = MS_LDQ_F32(in0 + index); + MS_FLOAT32X4 vout = MS_MINQ_F32(MS_MAXQ_F32(MS_MULQ_F32(vin0, vin1_opt), zeros), bounds); + MS_STQ_F32(out + index, vout); } #endif for (; index < size; index++) { @@ -228,28 +335,42 @@ int ElementOptMulRelu6(const float *in0, const float *in1, float *out, int size, } int ElementOptMulInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param) { -#ifdef ENABLE_NEON - int32x4_t vin0_opt = vdupq_n_s32(in0[0]); - int32x4_t vin1_opt = vdupq_n_s32(in1[0]); -#endif int index = 0; if (param->in_elements_num0_ == 1) { -#ifdef ENABLE_NEON - for (; index <= size - 4; index += C4NUM) { - int32x4_t vin1 = vld1q_s32(in1 + index); - int32x4_t vout = vmulq_s32(vin0_opt, vin1); - vst1q_s32(out + index, vout); +#if defined(ENABLE_AVX) + 
MS_INT32X8 vin0_opt_8 = MS_MOV256_EPI32(in0[0]); + for (; index <= size - C8NUM; index += C8NUM) { + MS_INT32X8 vin1 = MS_LD256_EPI32(in1 + index); + MS_INT32X8 vout = MS_MUL256_EPI32(vin0_opt_8, vin1); + MS_ST256_EPI32(out + index, vout); + } +#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_INT32X4 vin0_opt = MS_MOVQ_EPI32(in0[0]); + for (; index <= size - C4NUM; index += C4NUM) { + MS_INT32X4 vin1 = MS_LDQ_EPI32(in1 + index); + MS_INT32X4 vout = MS_MULQ_EPI32(vin0_opt, vin1); + MS_STQ_EPI32(out + index, vout); } #endif for (; index < size; index++) { out[index] = in0[0] * in1[index]; } } else { -#ifdef ENABLE_NEON - for (; index <= size - 4; index += C4NUM) { - int32x4_t vin0 = vld1q_s32(in0 + index); - int32x4_t vout = vmulq_s32(vin0, vin1_opt); - vst1q_s32(out + index, vout); +#if defined(ENABLE_AVX) + MS_INT32X8 vin1_opt_8 = MS_MOV256_EPI32(in1[0]); + for (; index <= size - C8NUM; index += C8NUM) { + MS_INT32X8 vin0 = MS_LD256_EPI32(in0 + index); + MS_INT32X8 vout = MS_MUL256_EPI32(vin0, vin1_opt_8); + MS_ST256_EPI32(out + index, vout); + } +#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_INT32X4 vin1_opt = MS_MOVQ_EPI32(in1[0]); + for (; index <= size - C4NUM; index += C4NUM) { + MS_INT32X4 vin0 = MS_LDQ_EPI32(in0 + index); + MS_INT32X4 vout = MS_MULQ_EPI32(vin0, vin1_opt); + MS_STQ_EPI32(out + index, vout); } #endif for (; index < size; index++) { @@ -260,29 +381,46 @@ int ElementOptMulInt(const int *in0, const int *in1, int *out, int size, const A } int ElementOptMulReluInt(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param) { -#ifdef ENABLE_NEON - int32x4_t vin0_opt = vdupq_n_s32(in0[0]); - int32x4_t vin1_opt = vdupq_n_s32(in1[0]); - int32x4_t zeros = vdupq_n_s32(0); -#endif int index = 0; if (param->in_elements_num0_ == 1) { -#ifdef ENABLE_NEON - for (; index <= size - 4; index += C4NUM) { - int32x4_t vin1 = vld1q_s32(in1 + index); - int32x4_t vout = vmaxq_s32(vmulq_s32(vin0_opt, vin1), zeros); - 
vst1q_s32(out + index, vout); +#if defined(ENABLE_AVX) + MS_INT32X8 vin0_opt_8 = MS_MOV256_EPI32(in0[0]); + MS_INT32X8 zeros_8 = MS_MOV256_EPI32(0); + for (; index <= size - C8NUM; index += C8NUM) { + MS_INT32X8 vin1 = MS_LD256_EPI32(in1 + index); + MS_INT32X8 vout = MS_MAX256_EPI32(MS_MUL256_EPI32(vin0_opt_8, vin1), zeros_8); + MS_ST256_EPI32(out + index, vout); + } +#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_INT32X4 vin0_opt = MS_MOVQ_EPI32(in0[0]); + MS_INT32X4 zeros = MS_MOVQ_EPI32(0); + for (; index <= size - C4NUM; index += C4NUM) { + MS_INT32X4 vin1 = MS_LDQ_EPI32(in1 + index); + MS_INT32X4 vout = MS_MAXQ_EPI32(MS_MULQ_EPI32(vin0_opt, vin1), zeros); + MS_STQ_EPI32(out + index, vout); } #endif for (; index < size; index++) { out[index] = MSMAX(in0[0] * in1[index], 0); } } else { -#ifdef ENABLE_NEON - for (; index <= size - 4; index += C4NUM) { - int32x4_t vin0 = vld1q_s32(in0 + index); - int32x4_t vout = vmaxq_s32(vmulq_s32(vin0, vin1_opt), zeros); - vst1q_s32(out + index, vout); +#if defined(ENABLE_AVX) + MS_INT32X8 vin1_opt_8 = MS_MOV256_EPI32(in1[0]); + MS_INT32X8 zeros_8 = MS_MOV256_EPI32(0); + for (; index <= size - C8NUM; index += C8NUM) { + MS_INT32X8 vin0 = MS_LD256_EPI32(in0 + index); + MS_INT32X8 vout = MS_MAX256_EPI32(MS_MUL256_EPI32(vin0, vin1_opt_8), zeros_8); + MS_ST256_EPI32(out + index, vout); + } +#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_INT32X4 vin1_opt = MS_MOVQ_EPI32(in1[0]); + MS_INT32X4 zeros = MS_MOVQ_EPI32(0); + for (; index <= size - C4NUM; index += C4NUM) { + MS_INT32X4 vin0 = MS_LDQ_EPI32(in0 + index); + MS_INT32X4 vout = MS_MAXQ_EPI32(MS_MULQ_EPI32(vin0, vin1_opt), zeros); + MS_STQ_EPI32(out + index, vout); } #endif for (; index < size; index++) { @@ -293,30 +431,50 @@ int ElementOptMulReluInt(const int *in0, const int *in1, int *out, int size, con } int ElementOptMulRelu6Int(const int *in0, const int *in1, int *out, int size, const ArithmeticParameter *param) { -#ifdef ENABLE_NEON - int32x4_t 
vin0_opt = vdupq_n_s32(in0[0]); - int32x4_t vin1_opt = vdupq_n_s32(in1[0]); - int32x4_t zeros = vdupq_n_s32(0); - int32x4_t bounds = vdupq_n_s32(6); -#endif int index = 0; if (param->in_elements_num0_ == 1) { -#ifdef ENABLE_NEON - for (; index <= size - 4; index += C4NUM) { - int32x4_t vin1 = vld1q_s32(in1 + index); - int32x4_t vout = vminq_s32(vmaxq_s32(vmulq_s32(vin0_opt, vin1), zeros), bounds); - vst1q_s32(out + index, vout); +#if defined(ENABLE_AVX) + MS_INT32X8 vin0_opt_8 = MS_MOV256_EPI32(in0[0]); + MS_INT32X8 zeros_8 = MS_MOV256_EPI32(0); + MS_INT32X8 bounds_8 = MS_MOV256_EPI32(6); + for (; index <= size - C8NUM; index += C8NUM) { + MS_INT32X8 vin1 = MS_LD256_EPI32(in1 + index); + MS_INT32X8 vout = MS_MIN256_EPI32(MS_MAX256_EPI32(MS_MUL256_EPI32(vin0_opt_8, vin1), zeros_8), bounds_8); + MS_ST256_EPI32(out + index, vout); + } +#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_INT32X4 vin0_opt = MS_MOVQ_EPI32(in0[0]); + MS_INT32X4 zeros = MS_MOVQ_EPI32(0); + MS_INT32X4 bounds = MS_MOVQ_EPI32(6); + for (; index <= size - C4NUM; index += C4NUM) { + MS_INT32X4 vin1 = MS_LDQ_EPI32(in1 + index); + MS_INT32X4 vout = MS_MINQ_EPI32(MS_MAXQ_EPI32(MS_MULQ_EPI32(vin0_opt, vin1), zeros), bounds); + MS_STQ_EPI32(out + index, vout); } #endif for (; index < size; index++) { out[index] = MSMIN(MSMAX(in0[0] * in1[index], 0), 6); } } else { -#ifdef ENABLE_NEON - for (; index <= size - 4; index += C4NUM) { - int32x4_t vin0 = vld1q_s32(in0 + index); - int32x4_t vout = vminq_s32(vmaxq_s32(vmulq_s32(vin0, vin1_opt), zeros), bounds); - vst1q_s32(out + index, vout); +#if defined(ENABLE_AVX) + MS_INT32X8 vin1_opt_8 = MS_MOV256_EPI32(in1[0]); + MS_INT32X8 zeros_8 = MS_MOV256_EPI32(0); + MS_INT32X8 bounds_8 = MS_MOV256_EPI32(6); + for (; index <= size - C8NUM; index += C8NUM) { + MS_INT32X8 vin0 = MS_LD256_EPI32(in0 + index); + MS_INT32X8 vout = MS_MIN256_EPI32(MS_MAX256_EPI32(MS_MUL256_EPI32(vin0, vin1_opt_8), zeros_8), bounds_8); + MS_ST256_EPI32(out + index, vout); + } 
+#endif +#if defined(ENABLE_NEON) || defined(ENABLE_SSE) + MS_INT32X4 vin1_opt = MS_MOVQ_EPI32(in1[0]); + MS_INT32X4 zeros = MS_MOVQ_EPI32(0); + MS_INT32X4 bounds = MS_MOVQ_EPI32(6); + for (; index <= size - C4NUM; index += C4NUM) { + MS_INT32X4 vin0 = MS_LDQ_EPI32(in0 + index); + MS_INT32X4 vout = MS_MINQ_EPI32(MS_MAXQ_EPI32(MS_MULQ_EPI32(vin0, vin1_opt), zeros), bounds); + MS_STQ_EPI32(out + index, vout); } #endif for (; index < size; index++) { diff --git a/mindspore/lite/nnacl/fp32/winograd_utils.c b/mindspore/lite/nnacl/fp32/winograd_utils.c index 941920a852..7582e5c98d 100644 --- a/mindspore/lite/nnacl/fp32/winograd_utils.c +++ b/mindspore/lite/nnacl/fp32/winograd_utils.c @@ -237,28 +237,30 @@ void InputTransform6x6Unit(const float *src_data, float *dst_data, int src_step, int offset = l * 6; MS_FLOAT32X4 tmp1 = MS_SUBQ_F32(src[3 + offset], src[1 + offset]); MS_FLOAT32X4 tmp2 = MS_SUBQ_F32(src[4 + offset], src[2 + offset]); - t[l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_F32(src[offset], 4), MS_MULQ_F32(src[2 + offset], 5)), src[4 + offset]); - t[6 + l] = MS_ADDQ_F32(MS_MULQ_F32(MS_ADDQ_F32(src[1 + offset], src[2 + offset]), -4), + t[l] = + MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(src[offset], 4), MS_MULQ_N_F32(src[2 + offset], 5)), src[4 + offset]); + t[6 + l] = MS_ADDQ_F32(MS_MULQ_N_F32(MS_ADDQ_F32(src[1 + offset], src[2 + offset]), -4), MS_ADDQ_F32(src[3 + offset], src[4 + offset])); - t[12 + l] = MS_ADDQ_F32(MS_MULQ_F32(MS_SUBQ_F32(src[1 + offset], src[2 + offset]), 4), + t[12 + l] = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(src[1 + offset], src[2 + offset]), 4), MS_SUBQ_F32(src[4 + offset], src[3 + offset])); - t[18 + l] = MS_ADDQ_F32(MS_MULQ_F32(tmp1, 2), tmp2); - t[24 + l] = MS_ADDQ_F32(MS_MULQ_F32(tmp1, -2), tmp2); + t[18 + l] = MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 2), tmp2); + t[24 + l] = MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, -2), tmp2); t[30 + l] = - MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_F32(src[1 + offset], 4), MS_MULQ_F32(src[3 + offset], 5)), src[5 + offset]); + 
MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(src[1 + offset], 4), MS_MULQ_N_F32(src[3 + offset], 5)), src[5 + offset]); } for (int l = 0; l < 6; ++l) { int offset = l * 6; MS_FLOAT32X4 tmp1 = MS_SUBQ_F32(t[3 + offset], t[1 + offset]); MS_FLOAT32X4 tmp2 = MS_SUBQ_F32(t[4 + offset], t[2 + offset]); - m[l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_F32(t[offset], 4), MS_MULQ_F32(t[2 + offset], 5)), t[4 + offset]); - m[6 + l] = MS_ADDQ_F32(MS_MULQ_F32(MS_ADDQ_F32(t[1 + offset], t[2 + offset]), -4), + m[l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(t[offset], 4), MS_MULQ_N_F32(t[2 + offset], 5)), t[4 + offset]); + m[6 + l] = MS_ADDQ_F32(MS_MULQ_N_F32(MS_ADDQ_F32(t[1 + offset], t[2 + offset]), -4), MS_ADDQ_F32(t[3 + offset], t[4 + offset])); - m[12 + l] = MS_ADDQ_F32(MS_MULQ_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), 4), + m[12 + l] = MS_ADDQ_F32(MS_MULQ_N_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), 4), MS_SUBQ_F32(t[4 + offset], t[3 + offset])); - m[18 + l] = MS_ADDQ_F32(MS_MULQ_F32(tmp1, 2), tmp2); - m[24 + l] = MS_ADDQ_F32(MS_MULQ_F32(tmp1, -2), tmp2); - m[30 + l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_F32(t[1 + offset], 4), MS_MULQ_F32(t[3 + offset], 5)), t[5 + offset]); + m[18 + l] = MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 2), tmp2); + m[24 + l] = MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, -2), tmp2); + m[30 + l] = + MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(t[1 + offset], 4), MS_MULQ_N_F32(t[3 + offset], 5)), t[5 + offset]); } for (int i = 0; i < 36; i++) { MS_STQ_F32(dst_data + i * dst_step, m[i]); @@ -311,46 +313,51 @@ void InputTransform8x8Unit_block4(const float *src_data, float *dst_data, int sr Load64Data; for (int l = 0; l < 8; ++l) { int offset = l * 8; - t[l] = MS_SUBQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_F32(src[offset], 0.5625), MS_MULQ_F32(src[2 + offset], 3.0625)), - MS_MULQ_F32(src[4 + offset], 3.5)), - src[6 + offset]); - MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(MS_MULQ_F32(src[1 + offset], 1.125), MS_MULQ_F32(src[5 + offset], 0.5)); - MS_FLOAT32X4 tmp2 = MS_SUBQ_F32(MS_MULQ_F32(src[2 + offset], 
2.25), MS_MULQ_F32(src[4 + offset], 3.25)); - t[8 + l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_F32(src[3 + offset], 1.625)), src[6 + offset]); - t[16 + l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_F32(src[3 + offset], 1.625)), src[6 + offset]); - tmp1 = MS_ADDQ_F32(MS_MULQ_F32(src[1 + offset], 0.5625), src[5 + offset]); - tmp2 = MS_SUBQ_F32(MS_MULQ_F32(src[2 + offset], 0.5625), MS_MULQ_F32(src[4 + offset], 2.5)); - t[24 + l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_F32(src[3 + offset], 2.5)), src[6 + offset]); - t[32 + l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_F32(src[3 + offset], 2.5)), src[6 + offset]); - tmp1 = MS_ADDQ_F32(MS_MULQ_F32(src[1 + offset], 0.375), MS_MULQ_F32(src[5 + offset], 1.5)); - tmp2 = MS_SUBQ_F32(MS_MULQ_F32(src[2 + offset], 0.25), MS_MULQ_F32(src[4 + offset], 1.25)); - t[40 + l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_F32(src[3 + offset], 1.875)), src[6 + offset]); - t[48 + l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_F32(src[3 + offset], 1.875)), src[6 + offset]); - t[56 + l] = - MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(MS_MULQ_F32(src[1 + offset], -0.5625), MS_MULQ_F32(src[3 + offset], 3.0625)), - MS_MULQ_F32(src[5 + offset], 3.5)), - src[7 + offset]); + t[l] = + MS_SUBQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(src[offset], 0.5625), MS_MULQ_N_F32(src[2 + offset], 3.0625)), + MS_MULQ_N_F32(src[4 + offset], 3.5)), + src[6 + offset]); + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(MS_MULQ_N_F32(src[1 + offset], 1.125), MS_MULQ_N_F32(src[5 + offset], 0.5)); + MS_FLOAT32X4 tmp2 = MS_SUBQ_F32(MS_MULQ_N_F32(src[2 + offset], 2.25), MS_MULQ_N_F32(src[4 + offset], 3.25)); + t[8 + l] = + MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_N_F32(src[3 + offset], 1.625)), src[6 + offset]); + t[16 + l] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_N_F32(src[3 + offset], 1.625)), src[6 + offset]); + tmp1 = MS_ADDQ_F32(MS_MULQ_N_F32(src[1 + 
offset], 0.5625), src[5 + offset]); + tmp2 = MS_SUBQ_F32(MS_MULQ_N_F32(src[2 + offset], 0.5625), MS_MULQ_N_F32(src[4 + offset], 2.5)); + t[24 + l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_N_F32(src[3 + offset], 2.5)), src[6 + offset]); + t[32 + l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_N_F32(src[3 + offset], 2.5)), src[6 + offset]); + tmp1 = MS_ADDQ_F32(MS_MULQ_N_F32(src[1 + offset], 0.375), MS_MULQ_N_F32(src[5 + offset], 1.5)); + tmp2 = MS_SUBQ_F32(MS_MULQ_N_F32(src[2 + offset], 0.25), MS_MULQ_N_F32(src[4 + offset], 1.25)); + t[40 + l] = + MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_N_F32(src[3 + offset], 1.875)), src[6 + offset]); + t[48 + l] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_N_F32(src[3 + offset], 1.875)), src[6 + offset]); + t[56 + l] = MS_ADDQ_F32( + MS_SUBQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(src[1 + offset], -0.5625), MS_MULQ_N_F32(src[3 + offset], 3.0625)), + MS_MULQ_N_F32(src[5 + offset], 3.5)), + src[7 + offset]); } for (int l = 0; l < 8; ++l) { int offset = l * 8; - m[l] = MS_SUBQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_F32(t[offset], 0.5625), MS_MULQ_F32(t[2 + offset], 3.0625)), - MS_MULQ_F32(t[4 + offset], 3.5)), + m[l] = MS_SUBQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(MS_MULQ_N_F32(t[offset], 0.5625), MS_MULQ_N_F32(t[2 + offset], 3.0625)), + MS_MULQ_N_F32(t[4 + offset], 3.5)), t[6 + offset]); - MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(MS_MULQ_F32(t[1 + offset], 1.125), MS_MULQ_F32(t[5 + offset], 0.5)); - MS_FLOAT32X4 tmp2 = MS_SUBQ_F32(MS_MULQ_F32(t[2 + offset], 2.25), MS_MULQ_F32(t[4 + offset], 3.25)); - m[8 + l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_F32(t[3 + offset], 1.625)), t[6 + offset]); - m[16 + l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_F32(t[3 + offset], 1.625)), t[6 + offset]); - tmp1 = MS_ADDQ_F32(MS_MULQ_F32(t[1 + offset], 0.5625), t[5 + offset]); - tmp2 = MS_SUBQ_F32(MS_MULQ_F32(t[2 + offset], 0.5625), MS_MULQ_F32(t[4 + offset], 2.5)); - m[24 + l] = 
MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_F32(t[3 + offset], 2.5)), t[6 + offset]); - m[32 + l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_F32(t[3 + offset], 2.5)), t[6 + offset]); - tmp1 = MS_ADDQ_F32(MS_MULQ_F32(t[1 + offset], 0.375), MS_MULQ_F32(t[5 + offset], 1.5)); - tmp2 = MS_SUBQ_F32(MS_MULQ_F32(t[2 + offset], 0.25), MS_MULQ_F32(t[4 + offset], 1.25)); - m[40 + l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_F32(t[3 + offset], 1.875)), t[6 + offset]); - m[48 + l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_F32(t[3 + offset], 1.875)), t[6 + offset]); + MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(MS_MULQ_N_F32(t[1 + offset], 1.125), MS_MULQ_N_F32(t[5 + offset], 0.5)); + MS_FLOAT32X4 tmp2 = MS_SUBQ_F32(MS_MULQ_N_F32(t[2 + offset], 2.25), MS_MULQ_N_F32(t[4 + offset], 3.25)); + m[8 + l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_N_F32(t[3 + offset], 1.625)), t[6 + offset]); + m[16 + l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_N_F32(t[3 + offset], 1.625)), t[6 + offset]); + tmp1 = MS_ADDQ_F32(MS_MULQ_N_F32(t[1 + offset], 0.5625), t[5 + offset]); + tmp2 = MS_SUBQ_F32(MS_MULQ_N_F32(t[2 + offset], 0.5625), MS_MULQ_N_F32(t[4 + offset], 2.5)); + m[24 + l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_N_F32(t[3 + offset], 2.5)), t[6 + offset]); + m[32 + l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_N_F32(t[3 + offset], 2.5)), t[6 + offset]); + tmp1 = MS_ADDQ_F32(MS_MULQ_N_F32(t[1 + offset], 0.375), MS_MULQ_N_F32(t[5 + offset], 1.5)); + tmp2 = MS_SUBQ_F32(MS_MULQ_N_F32(t[2 + offset], 0.25), MS_MULQ_N_F32(t[4 + offset], 1.25)); + m[40 + l] = MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(tmp1, tmp2), MS_MULQ_N_F32(t[3 + offset], 1.875)), t[6 + offset]); + m[48 + l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(tmp2, tmp1), MS_MULQ_N_F32(t[3 + offset], 1.875)), t[6 + offset]); m[56 + l] = - MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(MS_MULQ_F32(t[1 + offset], -0.5625), MS_MULQ_F32(t[3 + 
offset], 3.0625)), - MS_MULQ_F32(t[5 + offset], 3.5)), + MS_ADDQ_F32(MS_SUBQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(t[1 + offset], -0.5625), MS_MULQ_N_F32(t[3 + offset], 3.0625)), + MS_MULQ_N_F32(t[5 + offset], 3.5)), t[7 + offset]); } for (int i = 0; i < 64; i++) { @@ -880,7 +887,7 @@ void OutputTransform6x2Unit(const float *src_data, float *dst_data, const float MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], src[1 + offset]), src[2 + offset]), src[3 + offset]), src[4 + offset]); t[l + 6] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(src[1 + offset], src[2 + offset]), - MS_MULQ_F32(MS_SUBQ_F32(src[3 + offset], src[4 + offset]), 2)), + MS_MULQ_N_F32(MS_SUBQ_F32(src[3 + offset], src[4 + offset]), 2)), src[5 + offset]); } for (int l = 0; l < 2; ++l) { @@ -890,7 +897,7 @@ void OutputTransform6x2Unit(const float *src_data, float *dst_data, const float t[4 + offset]), bias_ptr); m[l + 2] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), - MS_MULQ_F32(MS_SUBQ_F32(t[3 + offset], t[4 + offset]), 2)), + MS_MULQ_N_F32(MS_SUBQ_F32(t[3 + offset], t[4 + offset]), 2)), t[5 + offset]), bias_ptr); } @@ -953,7 +960,7 @@ void OutputTransform6x2ReluUnit(const float *src_data, float *dst_data, const fl MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], src[1 + offset]), src[2 + offset]), src[3 + offset]), src[4 + offset]); t[l + 6] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(src[1 + offset], src[2 + offset]), - MS_MULQ_F32(MS_SUBQ_F32(src[3 + offset], src[4 + offset]), 2)), + MS_MULQ_N_F32(MS_SUBQ_F32(src[3 + offset], src[4 + offset]), 2)), src[5 + offset]); } for (int l = 0; l < 2; ++l) { @@ -963,7 +970,7 @@ void OutputTransform6x2ReluUnit(const float *src_data, float *dst_data, const fl t[4 + offset]), bias_ptr); m[l + 2] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), - MS_MULQ_F32(MS_SUBQ_F32(t[3 + offset], t[4 + offset]), 2)), + MS_MULQ_N_F32(MS_SUBQ_F32(t[3 + offset], t[4 + offset]), 2)), t[5 + offset]), bias_ptr); 
m[l] = MS_MAXQ_F32(zero, m[l]); @@ -1031,7 +1038,7 @@ void OutputTransform6x2Relu6Unit(const float *src_data, float *dst_data, const f MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], src[1 + offset]), src[2 + offset]), src[3 + offset]), src[4 + offset]); t[l + 6] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(src[1 + offset], src[2 + offset]), - MS_MULQ_F32(MS_SUBQ_F32(src[3 + offset], src[4 + offset]), 2)), + MS_MULQ_N_F32(MS_SUBQ_F32(src[3 + offset], src[4 + offset]), 2)), src[5 + offset]); } for (int l = 0; l < 2; ++l) { @@ -1041,7 +1048,7 @@ void OutputTransform6x2Relu6Unit(const float *src_data, float *dst_data, const f t[4 + offset]), bias_ptr); m[l + 2] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), - MS_MULQ_F32(MS_SUBQ_F32(t[3 + offset], t[4 + offset]), 2)), + MS_MULQ_N_F32(MS_SUBQ_F32(t[3 + offset], t[4 + offset]), 2)), t[5 + offset]), bias_ptr); m[l] = MS_MAXQ_F32(zero, m[l]); @@ -1110,18 +1117,18 @@ void OutputTransform6x3Unit(const float *src_data, float *dst_data, const float MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); t[l] = MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2); t[l + 6] = MS_ADDQ_F32(MS_SUBQ_F32(src[1 + offset], src[2 + offset]), - MS_MULQ_F32(MS_SUBQ_F32(src[3 + offset], src[4 + offset]), 2)); - t[l + 12] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_F32(tmp2, 4)), src[5 + offset]); + MS_MULQ_N_F32(MS_SUBQ_F32(src[3 + offset], src[4 + offset]), 2)); + t[l + 12] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)), src[5 + offset]); } for (int l = 0; l < 3; ++l) { int offset = l * 6; MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), bias_ptr); - m[l + 3] = MS_ADDQ_F32( - MS_ADDQ_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), MS_MULQ_F32(MS_SUBQ_F32(t[3 + offset], t[4 + offset]), 2)), - bias_ptr); - m[l + 6] = 
MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_F32(tmp2, 4)), t[5 + offset]), bias_ptr); + m[l + 3] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), + MS_MULQ_N_F32(MS_SUBQ_F32(t[3 + offset], t[4 + offset]), 2)), + bias_ptr); + m[l + 6] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)), t[5 + offset]), bias_ptr); } if (r_c == C4NUM && r_h == 3 && r_w == 3) { Store9Data; @@ -1184,18 +1191,18 @@ void OutputTransform6x3ReluUnit(const float *src_data, float *dst_data, const fl MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); t[l] = MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2); t[l + 6] = MS_ADDQ_F32(MS_SUBQ_F32(src[1 + offset], src[2 + offset]), - MS_MULQ_F32(MS_SUBQ_F32(src[3 + offset], src[4 + offset]), 2)); - t[l + 12] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_F32(tmp2, 4)), src[5 + offset]); + MS_MULQ_N_F32(MS_SUBQ_F32(src[3 + offset], src[4 + offset]), 2)); + t[l + 12] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)), src[5 + offset]); } for (int l = 0; l < 3; ++l) { int offset = l * 6; MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), bias_ptr); - m[l + 3] = MS_ADDQ_F32( - MS_ADDQ_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), MS_MULQ_F32(MS_SUBQ_F32(t[3 + offset], t[4 + offset]), 2)), - bias_ptr); - m[l + 6] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_F32(tmp2, 4)), t[5 + offset]), bias_ptr); + m[l + 3] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), + MS_MULQ_N_F32(MS_SUBQ_F32(t[3 + offset], t[4 + offset]), 2)), + bias_ptr); + m[l + 6] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)), t[5 + offset]), bias_ptr); m[l] = MS_MAXQ_F32(zero, m[l]); m[l + 3] = MS_MAXQ_F32(zero, m[l + 3]); m[l + 6] = MS_MAXQ_F32(zero, m[l + 6]); @@ -1264,18 +1271,18 @@ void OutputTransform6x3Relu6Unit(const float *src_data, float 
*dst_data, const f MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(src[3 + offset], src[4 + offset]); t[l] = MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2); t[l + 6] = MS_ADDQ_F32(MS_SUBQ_F32(src[1 + offset], src[2 + offset]), - MS_MULQ_F32(MS_SUBQ_F32(src[3 + offset], src[4 + offset]), 2)); - t[l + 12] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_F32(tmp2, 4)), src[5 + offset]); + MS_MULQ_N_F32(MS_SUBQ_F32(src[3 + offset], src[4 + offset]), 2)); + t[l + 12] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)), src[5 + offset]); } for (int l = 0; l < 3; ++l) { int offset = l * 6; MS_FLOAT32X4 tmp1 = MS_ADDQ_F32(t[1 + offset], t[2 + offset]); MS_FLOAT32X4 tmp2 = MS_ADDQ_F32(t[3 + offset], t[4 + offset]); m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), bias_ptr); - m[l + 3] = MS_ADDQ_F32( - MS_ADDQ_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), MS_MULQ_F32(MS_SUBQ_F32(t[3 + offset], t[4 + offset]), 2)), - bias_ptr); - m[l + 6] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_F32(tmp2, 4)), t[5 + offset]), bias_ptr); + m[l + 3] = MS_ADDQ_F32(MS_ADDQ_F32(MS_SUBQ_F32(t[1 + offset], t[2 + offset]), + MS_MULQ_N_F32(MS_SUBQ_F32(t[3 + offset], t[4 + offset]), 2)), + bias_ptr); + m[l + 6] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)), t[5 + offset]), bias_ptr); m[l] = MS_MAXQ_F32(zero, m[l]); m[l] = MS_MINQ_F32(six, m[l]); m[l + 3] = MS_MAXQ_F32(zero, m[l + 3]); @@ -1347,9 +1354,9 @@ void OutputTransform6x4Unit(const float *src_data, float *dst_data, const float MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); t[l] = MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2); - t[l + 6] = MS_ADDQ_F32(tmp3, MS_MULQ_F32(tmp4, 2)); - t[l + 12] = MS_ADDQ_F32(tmp1, MS_MULQ_F32(tmp2, 4)); - t[l + 18] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_F32(tmp4, 8)), src[5 + offset]); + t[l + 6] = MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)); + t[l + 12] = MS_ADDQ_F32(tmp1, 
MS_MULQ_N_F32(tmp2, 4)); + t[l + 18] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)), src[5 + offset]); } for (int l = 0; l < 4; ++l) { int offset = l * 6; @@ -1358,9 +1365,9 @@ void OutputTransform6x4Unit(const float *src_data, float *dst_data, const float MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), bias_ptr); - m[l + 4] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_F32(tmp4, 2)), bias_ptr); - m[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_F32(tmp2, 4)), bias_ptr); - m[l + 12] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_F32(tmp4, 8)), t[5 + offset]), bias_ptr); + m[l + 4] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)), bias_ptr); + m[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)), bias_ptr); + m[l + 12] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)), t[5 + offset]), bias_ptr); } if (r_c == C4NUM && r_h == 4 && r_w == 4) { Store16Data; @@ -1426,9 +1433,9 @@ void OutputTransform6x4ReluUnit(const float *src_data, float *dst_data, const fl MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); t[l] = MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2); - t[l + 6] = MS_ADDQ_F32(tmp3, MS_MULQ_F32(tmp4, 2)); - t[l + 12] = MS_ADDQ_F32(tmp1, MS_MULQ_F32(tmp2, 4)); - t[l + 18] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_F32(tmp4, 8)), src[5 + offset]); + t[l + 6] = MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)); + t[l + 12] = MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)); + t[l + 18] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)), src[5 + offset]); } for (int l = 0; l < 4; ++l) { int offset = l * 6; @@ -1437,9 +1444,9 @@ void OutputTransform6x4ReluUnit(const float *src_data, float *dst_data, const fl MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[3 
+ offset], t[4 + offset]); m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), bias_ptr); - m[l + 4] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_F32(tmp4, 2)), bias_ptr); - m[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_F32(tmp2, 4)), bias_ptr); - m[l + 12] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_F32(tmp4, 8)), t[5 + offset]), bias_ptr); + m[l + 4] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)), bias_ptr); + m[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)), bias_ptr); + m[l + 12] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)), t[5 + offset]), bias_ptr); m[l] = MS_MAXQ_F32(zero, m[l]); m[l + 4] = MS_MAXQ_F32(zero, m[l + 4]); m[l + 8] = MS_MAXQ_F32(zero, m[l + 8]); @@ -1512,9 +1519,9 @@ void OutputTransform6x4Relu6Unit(const float *src_data, float *dst_data, const f MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); t[l] = MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2); - t[l + 6] = MS_ADDQ_F32(tmp3, MS_MULQ_F32(tmp4, 2)); - t[l + 12] = MS_ADDQ_F32(tmp1, MS_MULQ_F32(tmp2, 4)); - t[l + 18] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_F32(tmp4, 8)), src[5 + offset]); + t[l + 6] = MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)); + t[l + 12] = MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)); + t[l + 18] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)), src[5 + offset]); } for (int l = 0; l < 4; ++l) { int offset = l * 6; @@ -1523,9 +1530,9 @@ void OutputTransform6x4Relu6Unit(const float *src_data, float *dst_data, const f MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), bias_ptr); - m[l + 4] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_F32(tmp4, 2)), bias_ptr); - m[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_F32(tmp2, 4)), bias_ptr); - m[l + 12] = 
MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_F32(tmp4, 8)), t[5 + offset]), bias_ptr); + m[l + 4] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)), bias_ptr); + m[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)), bias_ptr); + m[l + 12] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)), t[5 + offset]), bias_ptr); m[l] = MS_MAXQ_F32(zero, m[l]); m[l] = MS_MINQ_F32(six, m[l]); m[l + 4] = MS_MAXQ_F32(zero, m[l + 4]); @@ -1601,10 +1608,10 @@ void OutputTransform6x5Unit(const float *src_data, float *dst_data, const float MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); t[l] = MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2); - t[l + 6] = MS_ADDQ_F32(tmp3, MS_MULQ_F32(tmp4, 2)); - t[l + 12] = MS_ADDQ_F32(tmp1, MS_MULQ_F32(tmp2, 4)); - t[l + 18] = MS_ADDQ_F32(tmp3, MS_MULQ_F32(tmp4, 8)); - t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_F32(tmp2, 16)), src[5 + offset]); + t[l + 6] = MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)); + t[l + 12] = MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)); + t[l + 18] = MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 16)), src[5 + offset]); } for (int l = 0; l < 5; ++l) { int offset = l * 6; @@ -1613,10 +1620,10 @@ void OutputTransform6x5Unit(const float *src_data, float *dst_data, const float MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), bias_ptr); - m[l + 5] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_F32(tmp4, 2)), bias_ptr); - m[l + 10] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_F32(tmp2, 4)), bias_ptr); - m[l + 15] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_F32(tmp4, 8)), bias_ptr); - m[l + 20] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_F32(tmp2, 16)), t[5 + offset]), bias_ptr); + m[l + 5] = 
MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)), bias_ptr); + m[l + 10] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)), bias_ptr); + m[l + 15] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)), bias_ptr); + m[l + 20] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 16)), t[5 + offset]), bias_ptr); } if (r_c == C4NUM && r_h == 5 && r_w == 5) { Store25Data; @@ -1684,10 +1691,10 @@ void OutputTransform6x5ReluUnit(const float *src_data, float *dst_data, const fl MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); t[l] = MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2); - t[l + 6] = MS_ADDQ_F32(tmp3, MS_MULQ_F32(tmp4, 2)); - t[l + 12] = MS_ADDQ_F32(tmp1, MS_MULQ_F32(tmp2, 4)); - t[l + 18] = MS_ADDQ_F32(tmp3, MS_MULQ_F32(tmp4, 8)); - t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_F32(tmp2, 16)), src[5 + offset]); + t[l + 6] = MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)); + t[l + 12] = MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)); + t[l + 18] = MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 16)), src[5 + offset]); } for (int l = 0; l < 5; ++l) { int offset = l * 6; @@ -1696,10 +1703,10 @@ void OutputTransform6x5ReluUnit(const float *src_data, float *dst_data, const fl MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), bias_ptr); - m[l + 5] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_F32(tmp4, 2)), bias_ptr); - m[l + 10] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_F32(tmp2, 4)), bias_ptr); - m[l + 15] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_F32(tmp4, 8)), bias_ptr); - m[l + 20] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_F32(tmp2, 16)), t[5 + offset]), bias_ptr); + m[l + 5] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)), bias_ptr); + m[l + 10] = 
MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)), bias_ptr); + m[l + 15] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)), bias_ptr); + m[l + 20] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 16)), t[5 + offset]), bias_ptr); m[l] = MS_MAXQ_F32(zero, m[l]); m[l + 5] = MS_MAXQ_F32(zero, m[l + 5]); m[l + 10] = MS_MAXQ_F32(zero, m[l + 10]); @@ -1775,10 +1782,10 @@ void OutputTransform6x5Relu6Unit(const float *src_data, float *dst_data, const f MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(src[1 + offset], src[2 + offset]); MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); t[l] = MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2); - t[l + 6] = MS_ADDQ_F32(tmp3, MS_MULQ_F32(tmp4, 2)); - t[l + 12] = MS_ADDQ_F32(tmp1, MS_MULQ_F32(tmp2, 4)); - t[l + 18] = MS_ADDQ_F32(tmp3, MS_MULQ_F32(tmp4, 8)); - t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_F32(tmp2, 16)), src[5 + offset]); + t[l + 6] = MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)); + t[l + 12] = MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 4)); + t[l + 18] = MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 16)), src[5 + offset]); } for (int l = 0; l < 5; ++l) { int offset = l * 6; @@ -1787,10 +1794,10 @@ void OutputTransform6x5Relu6Unit(const float *src_data, float *dst_data, const f MS_FLOAT32X4 tmp3 = MS_SUBQ_F32(t[1 + offset], t[2 + offset]); MS_FLOAT32X4 tmp4 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), bias_ptr); - m[l + 5] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_F32(tmp4, 2)), bias_ptr); - m[l + 10] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_F32(tmp2, 4)), bias_ptr); - m[l + 15] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_F32(tmp4, 8)), bias_ptr); - m[l + 20] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_F32(tmp2, 16)), t[5 + offset]), bias_ptr); + m[l + 5] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 2)), bias_ptr); + m[l + 10] = MS_ADDQ_F32(MS_ADDQ_F32(tmp1, 
MS_MULQ_N_F32(tmp2, 4)), bias_ptr); + m[l + 15] = MS_ADDQ_F32(MS_ADDQ_F32(tmp3, MS_MULQ_N_F32(tmp4, 8)), bias_ptr); + m[l + 20] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(tmp1, MS_MULQ_N_F32(tmp2, 16)), t[5 + offset]), bias_ptr); m[l] = MS_MAXQ_F32(zero, m[l]); m[l] = MS_MINQ_F32(six, m[l]); m[l + 5] = MS_MAXQ_F32(zero, m[l + 5]); @@ -1873,7 +1880,7 @@ void OutputTransform8x2Unit(const float *src_data, float *dst_data, const float MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); t[l + 8] = - MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.5), tmp5), MS_MULQ_F32(tmp6, 1.5)), src[7 + offset]); + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), src[7 + offset]); } for (int l = 0; l < 2; ++l) { int offset = l * 8; @@ -1885,7 +1892,7 @@ void OutputTransform8x2Unit(const float *src_data, float *dst_data, const float MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); m[l + 2] = MS_ADDQ_F32( - MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.5), tmp5), MS_MULQ_F32(tmp6, 1.5)), t[7 + offset]), + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), t[7 + offset]), bias_ptr); } if (r_c == C4NUM && r_h == 2 && r_w == 2) { @@ -1954,7 +1961,7 @@ void OutputTransform8x2ReluUnit(const float *src_data, float *dst_data, const fl MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); t[l + 8] = - MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.5), tmp5), MS_MULQ_F32(tmp6, 1.5)), src[7 + offset]); + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), src[7 + offset]); } for (int l = 0; l < 2; ++l) { int offset = l * 8; @@ -1966,7 +1973,7 @@ void 
OutputTransform8x2ReluUnit(const float *src_data, float *dst_data, const fl MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); m[l + 2] = MS_ADDQ_F32( - MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.5), tmp5), MS_MULQ_F32(tmp6, 1.5)), t[7 + offset]), + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), t[7 + offset]), bias_ptr); m[l] = MS_MAXQ_F32(zero, m[l]); m[l + 2] = MS_MAXQ_F32(zero, m[l + 2]); @@ -2040,7 +2047,7 @@ void OutputTransform8x2Relu6Unit(const float *src_data, float *dst_data, const f MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); t[l + 8] = - MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.5), tmp5), MS_MULQ_F32(tmp6, 1.5)), src[7 + offset]); + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), src[7 + offset]); } for (int l = 0; l < 2; ++l) { int offset = l * 8; @@ -2052,7 +2059,7 @@ void OutputTransform8x2Relu6Unit(const float *src_data, float *dst_data, const f MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); m[l + 2] = MS_ADDQ_F32( - MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.5), tmp5), MS_MULQ_F32(tmp6, 1.5)), t[7 + offset]), + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), t[7 + offset]), bias_ptr); m[l] = MS_MAXQ_F32(zero, m[l]); m[l] = MS_MINQ_F32(six, m[l]); @@ -2126,9 +2133,9 @@ void OutputTransform8x3Unit(const float *src_data, float *dst_data, const float MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), 
tmp3); - t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.5), tmp5), MS_MULQ_F32(tmp6, 1.5)); - t[l + 16] = - MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.25), tmp2), MS_MULQ_F32(tmp3, 2.25)), src[7 + offset]); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), + src[7 + offset]); } for (int l = 0; l < 3; ++l) { int offset = l * 8; @@ -2139,9 +2146,10 @@ void OutputTransform8x3Unit(const float *src_data, float *dst_data, const float MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); - m[l + 3] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.5), tmp5), MS_MULQ_F32(tmp6, 1.5)), bias_ptr); + m[l + 3] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); m[l + 6] = MS_ADDQ_F32( - MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.25), tmp2), MS_MULQ_F32(tmp3, 2.25)), t[7 + offset]), + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), t[7 + offset]), bias_ptr); } if (r_c == C4NUM && r_h == 3 && r_w == 3) { @@ -2213,9 +2221,9 @@ void OutputTransform8x3ReluUnit(const float *src_data, float *dst_data, const fl MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); - t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.5), tmp5), MS_MULQ_F32(tmp6, 1.5)); - t[l + 16] = - MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.25), tmp2), MS_MULQ_F32(tmp3, 2.25)), src[7 + offset]); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), 
MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), + src[7 + offset]); } for (int l = 0; l < 3; ++l) { int offset = l * 8; @@ -2226,9 +2234,10 @@ void OutputTransform8x3ReluUnit(const float *src_data, float *dst_data, const fl MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); - m[l + 3] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.5), tmp5), MS_MULQ_F32(tmp6, 1.5)), bias_ptr); + m[l + 3] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); m[l + 6] = MS_ADDQ_F32( - MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.25), tmp2), MS_MULQ_F32(tmp3, 2.25)), t[7 + offset]), + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), t[7 + offset]), bias_ptr); m[l] = MS_MAXQ_F32(zero, m[l]); m[l + 3] = MS_MAXQ_F32(zero, m[l + 3]); @@ -2306,9 +2315,9 @@ void OutputTransform8x3Relu6Unit(const float *src_data, float *dst_data, const f MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); - t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.5), tmp5), MS_MULQ_F32(tmp6, 1.5)); - t[l + 16] = - MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.25), tmp2), MS_MULQ_F32(tmp3, 2.25)), src[7 + offset]); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), + src[7 + offset]); } for (int l = 0; l < 3; ++l) { int offset = l * 8; @@ -2319,9 +2328,10 @@ void OutputTransform8x3Relu6Unit(const float 
*src_data, float *dst_data, const f MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); - m[l + 3] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.5), tmp5), MS_MULQ_F32(tmp6, 1.5)), bias_ptr); + m[l + 3] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); m[l + 6] = MS_ADDQ_F32( - MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.25), tmp2), MS_MULQ_F32(tmp3, 2.25)), t[7 + offset]), + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), t[7 + offset]), bias_ptr); m[l] = MS_MAXQ_F32(zero, m[l]); m[l] = MS_MINQ_F32(six, m[l]); @@ -2401,10 +2411,10 @@ void OutputTransform8x4Unit(const float *src_data, float *dst_data, const float MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); - t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.5), tmp5), MS_MULQ_F32(tmp6, 1.5)); - t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.25), tmp2), MS_MULQ_F32(tmp3, 2.25)); - t[l + 24] = - MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.125), tmp5), MS_MULQ_F32(tmp6, 3.375)), src[7 + offset]); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), + src[7 + offset]); } for (int l = 0; l < 4; ++l) { int offset = l * 8; @@ -2415,11 +2425,14 @@ void OutputTransform8x4Unit(const float *src_data, float *dst_data, const float MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + 
offset]); MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); - m[l + 4] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.5), tmp5), MS_MULQ_F32(tmp6, 1.5)), bias_ptr); - m[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.25), tmp2), MS_MULQ_F32(tmp3, 2.25)), bias_ptr); - m[l + 12] = MS_ADDQ_F32( - MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.125), tmp5), MS_MULQ_F32(tmp6, 3.375)), t[7 + offset]), - bias_ptr); + m[l + 4] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 8] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 12] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), + t[7 + offset]), + bias_ptr); } if (r_c == C4NUM && r_h == 4 && r_w == 4) { Store16Data; @@ -2494,10 +2507,10 @@ void OutputTransform8x4ReluUnit(const float *src_data, float *dst_data, const fl MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); - t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.5), tmp5), MS_MULQ_F32(tmp6, 1.5)); - t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.25), tmp2), MS_MULQ_F32(tmp3, 2.25)); - t[l + 24] = - MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.125), tmp5), MS_MULQ_F32(tmp6, 3.375)), src[7 + offset]); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), + src[7 + offset]); } for (int 
l = 0; l < 4; ++l) { int offset = l * 8; @@ -2508,11 +2521,14 @@ void OutputTransform8x4ReluUnit(const float *src_data, float *dst_data, const fl MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); - m[l + 4] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.5), tmp5), MS_MULQ_F32(tmp6, 1.5)), bias_ptr); - m[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.25), tmp2), MS_MULQ_F32(tmp3, 2.25)), bias_ptr); - m[l + 12] = MS_ADDQ_F32( - MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.125), tmp5), MS_MULQ_F32(tmp6, 3.375)), t[7 + offset]), - bias_ptr); + m[l + 4] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 8] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 12] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), + t[7 + offset]), + bias_ptr); m[l] = MS_MAXQ_F32(zero, m[l]); m[l + 4] = MS_MAXQ_F32(zero, m[l + 4]); m[l + 8] = MS_MAXQ_F32(zero, m[l + 8]); @@ -2575,9 +2591,9 @@ void OutputTransform8x4ReluUnit(const float *src_data, float *dst_data, const fl #endif } +#if defined(ENABLE_ARM) || defined(ENABLE_SSE) void OutputTransform8x4Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { -#if defined(ENABLE_ARM) || defined(ENABLE_SSE) MS_FLOAT32X4 src[64]; MS_FLOAT32X4 t[32]; MS_FLOAT32X4 m[16]; @@ -2594,10 +2610,10 @@ void OutputTransform8x4Relu6Unit(const float *src_data, float *dst_data, const f MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); t[l] = 
MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); - t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.5), tmp5), MS_MULQ_F32(tmp6, 1.5)); - t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.25), tmp2), MS_MULQ_F32(tmp3, 2.25)); - t[l + 24] = - MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.125), tmp5), MS_MULQ_F32(tmp6, 3.375)), src[7 + offset]); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), + src[7 + offset]); } for (int l = 0; l < 4; ++l) { int offset = l * 8; @@ -2608,11 +2624,14 @@ void OutputTransform8x4Relu6Unit(const float *src_data, float *dst_data, const f MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); - m[l + 4] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.5), tmp5), MS_MULQ_F32(tmp6, 1.5)), bias_ptr); - m[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.25), tmp2), MS_MULQ_F32(tmp3, 2.25)), bias_ptr); - m[l + 12] = MS_ADDQ_F32( - MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.125), tmp5), MS_MULQ_F32(tmp6, 3.375)), t[7 + offset]), - bias_ptr); + m[l + 4] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 8] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); + m[l + 12] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), + t[7 + offset]), + bias_ptr); m[l] = MS_MAXQ_F32(zero, m[l]); m[l] = MS_MINQ_F32(six, m[l]); m[l + 4] = MS_MAXQ_F32(zero, m[l + 4]); @@ 
-2635,7 +2654,10 @@ void OutputTransform8x4Relu6Unit(const float *src_data, float *dst_data, const f } } } +} #else +void OutputTransform8x4Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, + int dst_step, int out_c, int r_w, int r_h, int r_c) { float src[64]; float t[32]; float m[16]; @@ -2677,8 +2699,8 @@ void OutputTransform8x4Relu6Unit(const float *src_data, float *dst_data, const f } } } -#endif } +#endif void OutputTransform8x5Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { @@ -2697,10 +2719,10 @@ void OutputTransform8x5Unit(const float *src_data, float *dst_data, const float MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); - t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.5), tmp5), MS_MULQ_F32(tmp6, 1.5)); - t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.25), tmp2), MS_MULQ_F32(tmp3, 2.25)); - t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.125), tmp5), MS_MULQ_F32(tmp6, 3.375)); - t[l + 32] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.0625), tmp2), MS_MULQ_F32(tmp3, 5.0625)), + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), src[7 + offset]); } for (int l = 0; l < 5; ++l) { @@ -2712,13 +2734,16 @@ void OutputTransform8x5Unit(const float *src_data, float *dst_data, const float MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + 
offset], t[6 + offset]); m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); - m[l + 5] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.5), tmp5), MS_MULQ_F32(tmp6, 1.5)), bias_ptr); - m[l + 10] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.25), tmp2), MS_MULQ_F32(tmp3, 2.25)), bias_ptr); + m[l + 5] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 10] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); m[l + 15] = - MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.125), tmp5), MS_MULQ_F32(tmp6, 3.375)), bias_ptr); - m[l + 20] = MS_ADDQ_F32( - MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.0625), tmp2), MS_MULQ_F32(tmp3, 5.0625)), t[7 + offset]), - bias_ptr); + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 20] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), + t[7 + offset]), + bias_ptr); } if (r_c == C4NUM && r_h == 5 && r_w == 5) { Store25Data; @@ -2797,10 +2822,10 @@ void OutputTransform8x5ReluUnit(const float *src_data, float *dst_data, const fl MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); - t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.5), tmp5), MS_MULQ_F32(tmp6, 1.5)); - t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.25), tmp2), MS_MULQ_F32(tmp3, 2.25)); - t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.125), tmp5), MS_MULQ_F32(tmp6, 3.375)); - t[l + 32] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.0625), tmp2), MS_MULQ_F32(tmp3, 5.0625)), + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), 
MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), src[7 + offset]); } for (int l = 0; l < 5; ++l) { @@ -2812,13 +2837,16 @@ void OutputTransform8x5ReluUnit(const float *src_data, float *dst_data, const fl MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); - m[l + 5] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.5), tmp5), MS_MULQ_F32(tmp6, 1.5)), bias_ptr); - m[l + 10] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.25), tmp2), MS_MULQ_F32(tmp3, 2.25)), bias_ptr); + m[l + 5] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 10] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); m[l + 15] = - MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.125), tmp5), MS_MULQ_F32(tmp6, 3.375)), bias_ptr); - m[l + 20] = MS_ADDQ_F32( - MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.0625), tmp2), MS_MULQ_F32(tmp3, 5.0625)), t[7 + offset]), - bias_ptr); + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 20] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), + t[7 + offset]), + bias_ptr); m[l] = MS_MAXQ_F32(zero, m[l]); m[l + 5] = MS_MAXQ_F32(zero, m[l + 5]); m[l + 10] = MS_MAXQ_F32(zero, m[l + 10]); @@ -2908,10 +2936,10 @@ void OutputTransform8x5Relu6Unit(const float *src_data, float *dst_data, const f MS_FLOAT32X4 tmp5 = 
MS_SUBQ_F32(src[3 + offset], src[4 + offset]); MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); - t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.5), tmp5), MS_MULQ_F32(tmp6, 1.5)); - t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.25), tmp2), MS_MULQ_F32(tmp3, 2.25)); - t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.125), tmp5), MS_MULQ_F32(tmp6, 3.375)); - t[l + 32] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.0625), tmp2), MS_MULQ_F32(tmp3, 5.0625)), + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), src[7 + offset]); } for (int l = 0; l < 5; ++l) { @@ -2923,13 +2951,16 @@ void OutputTransform8x5Relu6Unit(const float *src_data, float *dst_data, const f MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); - m[l + 5] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.5), tmp5), MS_MULQ_F32(tmp6, 1.5)), bias_ptr); - m[l + 10] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.25), tmp2), MS_MULQ_F32(tmp3, 2.25)), bias_ptr); + m[l + 5] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 10] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); m[l + 15] = - MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.125), tmp5), MS_MULQ_F32(tmp6, 3.375)), bias_ptr); - m[l + 20] 
= MS_ADDQ_F32( - MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.0625), tmp2), MS_MULQ_F32(tmp3, 5.0625)), t[7 + offset]), - bias_ptr); + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), bias_ptr); + m[l + 20] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), + t[7 + offset]), + bias_ptr); m[l] = MS_MAXQ_F32(zero, m[l]); m[l] = MS_MINQ_F32(six, m[l]); m[l + 5] = MS_MAXQ_F32(zero, m[l + 5]); @@ -3023,11 +3054,11 @@ void OutputTransform8x6Unit(const float *src_data, float *dst_data, const float MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); - t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.5), tmp5), MS_MULQ_F32(tmp6, 1.5)); - t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.25), tmp2), MS_MULQ_F32(tmp3, 2.25)); - t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.125), tmp5), MS_MULQ_F32(tmp6, 3.375)); - t[l + 32] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.0625), tmp2), MS_MULQ_F32(tmp3, 5.0625)); - t[l + 40] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.03125), tmp5), MS_MULQ_F32(tmp6, 7.59375)), + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)); + t[l + 40] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)), src[7 + offset]); } for (int l = 0; l < 6; ++l) { @@ -3039,16 +3070,18 @@ void OutputTransform8x6Unit(const float *src_data, float *dst_data, 
const float MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); - m[l + 6] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.5), tmp5), MS_MULQ_F32(tmp6, 1.5)), bias_ptr); - m[l + 12] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.25), tmp2), MS_MULQ_F32(tmp3, 2.25)), bias_ptr); + m[l + 6] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 12] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); m[l + 18] = - MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.125), tmp5), MS_MULQ_F32(tmp6, 3.375)), bias_ptr); + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), bias_ptr); m[l + 24] = - MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.0625), tmp2), MS_MULQ_F32(tmp3, 5.0625)), bias_ptr); - m[l + 30] = - MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.03125), tmp5), MS_MULQ_F32(tmp6, 7.59375)), - t[7 + offset]), - bias_ptr); + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), bias_ptr); + m[l + 30] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)), + t[7 + offset]), + bias_ptr); } if (r_c == C4NUM && r_h == 6 && r_w == 6) { for (int i = 0; i < 6; i++) { @@ -3143,11 +3176,11 @@ void OutputTransform8x6ReluUnit(const float *src_data, float *dst_data, const fl MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); - t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.5), tmp5), MS_MULQ_F32(tmp6, 1.5)); - 
t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.25), tmp2), MS_MULQ_F32(tmp3, 2.25)); - t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.125), tmp5), MS_MULQ_F32(tmp6, 3.375)); - t[l + 32] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.0625), tmp2), MS_MULQ_F32(tmp3, 5.0625)); - t[l + 40] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.03125), tmp5), MS_MULQ_F32(tmp6, 7.59375)), + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)); + t[l + 40] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)), src[7 + offset]); } for (int l = 0; l < 6; ++l) { @@ -3159,16 +3192,18 @@ void OutputTransform8x6ReluUnit(const float *src_data, float *dst_data, const fl MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); - m[l + 6] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.5), tmp5), MS_MULQ_F32(tmp6, 1.5)), bias_ptr); - m[l + 12] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.25), tmp2), MS_MULQ_F32(tmp3, 2.25)), bias_ptr); + m[l + 6] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 12] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); m[l + 18] = - MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.125), tmp5), MS_MULQ_F32(tmp6, 3.375)), bias_ptr); + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), 
MS_MULQ_N_F32(tmp6, 3.375)), bias_ptr); m[l + 24] = - MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.0625), tmp2), MS_MULQ_F32(tmp3, 5.0625)), bias_ptr); - m[l + 30] = - MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.03125), tmp5), MS_MULQ_F32(tmp6, 7.59375)), - t[7 + offset]), - bias_ptr); + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), bias_ptr); + m[l + 30] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)), + t[7 + offset]), + bias_ptr); m[l] = MS_MAXQ_F32(zero, m[l]); m[l + 6] = MS_MAXQ_F32(zero, m[l + 6]); m[l + 12] = MS_MAXQ_F32(zero, m[l + 12]); @@ -3272,11 +3307,11 @@ void OutputTransform8x6Relu6Unit(const float *src_data, float *dst_data, const f MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); - t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.5), tmp5), MS_MULQ_F32(tmp6, 1.5)); - t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.25), tmp2), MS_MULQ_F32(tmp3, 2.25)); - t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.125), tmp5), MS_MULQ_F32(tmp6, 3.375)); - t[l + 32] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.0625), tmp2), MS_MULQ_F32(tmp3, 5.0625)); - t[l + 40] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.03125), tmp5), MS_MULQ_F32(tmp6, 7.59375)), + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)); + t[l + 40] = 
MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)), src[7 + offset]); } for (int l = 0; l < 6; ++l) { @@ -3288,16 +3323,18 @@ void OutputTransform8x6Relu6Unit(const float *src_data, float *dst_data, const f MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); - m[l + 6] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.5), tmp5), MS_MULQ_F32(tmp6, 1.5)), bias_ptr); - m[l + 12] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.25), tmp2), MS_MULQ_F32(tmp3, 2.25)), bias_ptr); + m[l + 6] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 12] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); m[l + 18] = - MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.125), tmp5), MS_MULQ_F32(tmp6, 3.375)), bias_ptr); + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), bias_ptr); m[l + 24] = - MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.0625), tmp2), MS_MULQ_F32(tmp3, 5.0625)), bias_ptr); - m[l + 30] = - MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.03125), tmp5), MS_MULQ_F32(tmp6, 7.59375)), - t[7 + offset]), - bias_ptr); + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), bias_ptr); + m[l + 30] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)), + t[7 + offset]), + bias_ptr); m[l] = MS_MAXQ_F32(zero, m[l]); m[l] = MS_MINQ_F32(six, m[l]); m[l + 6] = MS_MAXQ_F32(zero, m[l + 6]); @@ -3406,13 +3443,13 @@ void OutputTransform8x7Unit(const float *src_data, float *dst_data, const float MS_FLOAT32X4 tmp5 = 
MS_SUBQ_F32(src[3 + offset], src[4 + offset]); MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); - t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.5), tmp5), MS_MULQ_F32(tmp6, 1.5)); - t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.25), tmp2), MS_MULQ_F32(tmp3, 2.25)); - t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.125), tmp5), MS_MULQ_F32(tmp6, 3.375)); - t[l + 32] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.0625), tmp2), MS_MULQ_F32(tmp3, 5.0625)); - t[l + 40] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.03125), tmp5), MS_MULQ_F32(tmp6, 7.59375)); - t[l + 48] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.015625), tmp2), MS_MULQ_F32(tmp3, 11.390625)), - src[7 + offset]); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)); + t[l + 40] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)); + t[l + 48] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.015625), tmp2), MS_MULQ_N_F32(tmp3, 11.390625)), src[7 + offset]); } for (int l = 0; l < 7; ++l) { int offset = l * 8; @@ -3423,18 +3460,20 @@ void OutputTransform8x7Unit(const float *src_data, float *dst_data, const float MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); - m[l + 7] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.5), tmp5), MS_MULQ_F32(tmp6, 1.5)), bias_ptr); - m[l + 14] = 
MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.25), tmp2), MS_MULQ_F32(tmp3, 2.25)), bias_ptr); + m[l + 7] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 14] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); m[l + 21] = - MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.125), tmp5), MS_MULQ_F32(tmp6, 3.375)), bias_ptr); + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), bias_ptr); m[l + 28] = - MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.0625), tmp2), MS_MULQ_F32(tmp3, 5.0625)), bias_ptr); + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), bias_ptr); m[l + 35] = - MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.03125), tmp5), MS_MULQ_F32(tmp6, 7.59375)), bias_ptr); - m[l + 42] = - MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.015625), tmp2), MS_MULQ_F32(tmp3, 11.390625)), - t[7 + offset]), - bias_ptr); + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)), bias_ptr); + m[l + 42] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.015625), tmp2), MS_MULQ_N_F32(tmp3, 11.390625)), + t[7 + offset]), + bias_ptr); } if (r_c == C4NUM && r_h == 7 && r_w == 7) { for (int i = 0; i < 7; i++) { @@ -3534,13 +3573,13 @@ void OutputTransform8x7ReluUnit(const float *src_data, float *dst_data, const fl MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); - t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.5), tmp5), MS_MULQ_F32(tmp6, 1.5)); - t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.25), tmp2), MS_MULQ_F32(tmp3, 2.25)); - t[l + 24] = 
MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.125), tmp5), MS_MULQ_F32(tmp6, 3.375)); - t[l + 32] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.0625), tmp2), MS_MULQ_F32(tmp3, 5.0625)); - t[l + 40] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.03125), tmp5), MS_MULQ_F32(tmp6, 7.59375)); - t[l + 48] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.015625), tmp2), MS_MULQ_F32(tmp3, 11.390625)), - src[7 + offset]); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)); + t[l + 40] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)); + t[l + 48] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.015625), tmp2), MS_MULQ_N_F32(tmp3, 11.390625)), src[7 + offset]); } for (int l = 0; l < 7; ++l) { int offset = l * 8; @@ -3551,18 +3590,20 @@ void OutputTransform8x7ReluUnit(const float *src_data, float *dst_data, const fl MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); - m[l + 7] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.5), tmp5), MS_MULQ_F32(tmp6, 1.5)), bias_ptr); - m[l + 14] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.25), tmp2), MS_MULQ_F32(tmp3, 2.25)), bias_ptr); + m[l + 7] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 14] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); m[l + 21] = - 
MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.125), tmp5), MS_MULQ_F32(tmp6, 3.375)), bias_ptr); + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), bias_ptr); m[l + 28] = - MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.0625), tmp2), MS_MULQ_F32(tmp3, 5.0625)), bias_ptr); + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), bias_ptr); m[l + 35] = - MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.03125), tmp5), MS_MULQ_F32(tmp6, 7.59375)), bias_ptr); - m[l + 42] = - MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.015625), tmp2), MS_MULQ_F32(tmp3, 11.390625)), - t[7 + offset]), - bias_ptr); + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)), bias_ptr); + m[l + 42] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.015625), tmp2), MS_MULQ_N_F32(tmp3, 11.390625)), + t[7 + offset]), + bias_ptr); m[l] = MS_MAXQ_F32(zero, m[l]); m[l + 7] = MS_MAXQ_F32(zero, m[l + 7]); m[l + 14] = MS_MAXQ_F32(zero, m[l + 14]); @@ -3672,13 +3713,13 @@ void OutputTransform8x7Relu6Unit(const float *src_data, float *dst_data, const f MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(src[3 + offset], src[4 + offset]); MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(src[5 + offset], src[6 + offset]); t[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(src[offset], tmp1), tmp2), tmp3); - t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.5), tmp5), MS_MULQ_F32(tmp6, 1.5)); - t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.25), tmp2), MS_MULQ_F32(tmp3, 2.25)); - t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.125), tmp5), MS_MULQ_F32(tmp6, 3.375)); - t[l + 32] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.0625), tmp2), MS_MULQ_F32(tmp3, 5.0625)); - t[l + 40] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.03125), tmp5), MS_MULQ_F32(tmp6, 7.59375)); - t[l + 48] = 
MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.015625), tmp2), MS_MULQ_F32(tmp3, 11.390625)), - src[7 + offset]); + t[l + 8] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)); + t[l + 16] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)); + t[l + 24] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)); + t[l + 32] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)); + t[l + 40] = MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)); + t[l + 48] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.015625), tmp2), MS_MULQ_N_F32(tmp3, 11.390625)), src[7 + offset]); } for (int l = 0; l < 7; ++l) { int offset = l * 8; @@ -3689,18 +3730,20 @@ void OutputTransform8x7Relu6Unit(const float *src_data, float *dst_data, const f MS_FLOAT32X4 tmp5 = MS_SUBQ_F32(t[3 + offset], t[4 + offset]); MS_FLOAT32X4 tmp6 = MS_SUBQ_F32(t[5 + offset], t[6 + offset]); m[l] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(t[offset], tmp1), tmp2), tmp3), bias_ptr); - m[l + 7] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.5), tmp5), MS_MULQ_F32(tmp6, 1.5)), bias_ptr); - m[l + 14] = MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.25), tmp2), MS_MULQ_F32(tmp3, 2.25)), bias_ptr); + m[l + 7] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.5), tmp5), MS_MULQ_N_F32(tmp6, 1.5)), bias_ptr); + m[l + 14] = + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.25), tmp2), MS_MULQ_N_F32(tmp3, 2.25)), bias_ptr); m[l + 21] = - MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.125), tmp5), MS_MULQ_F32(tmp6, 3.375)), bias_ptr); + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.125), tmp5), MS_MULQ_N_F32(tmp6, 3.375)), bias_ptr); m[l + 28] = - MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.0625), tmp2), MS_MULQ_F32(tmp3, 5.0625)), bias_ptr); + 
MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.0625), tmp2), MS_MULQ_N_F32(tmp3, 5.0625)), bias_ptr); m[l + 35] = - MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp4, 0.03125), tmp5), MS_MULQ_F32(tmp6, 7.59375)), bias_ptr); - m[l + 42] = - MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_F32(tmp1, 0.015625), tmp2), MS_MULQ_F32(tmp3, 11.390625)), - t[7 + offset]), - bias_ptr); + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp4, 0.03125), tmp5), MS_MULQ_N_F32(tmp6, 7.59375)), bias_ptr); + m[l + 42] = MS_ADDQ_F32( + MS_ADDQ_F32(MS_ADDQ_F32(MS_ADDQ_F32(MS_MULQ_N_F32(tmp1, 0.015625), tmp2), MS_MULQ_N_F32(tmp3, 11.390625)), + t[7 + offset]), + bias_ptr); m[l] = MS_MAXQ_F32(zero, m[l]); m[l] = MS_MINQ_F32(six, m[l]); m[l + 7] = MS_MAXQ_F32(zero, m[l + 7]); diff --git a/mindspore/lite/nnacl/op_base.h b/mindspore/lite/nnacl/op_base.h index 1ea2d711b0..007debd86e 100644 --- a/mindspore/lite/nnacl/op_base.h +++ b/mindspore/lite/nnacl/op_base.h @@ -107,28 +107,105 @@ typedef enum CalFixedMultiplierMode { #ifdef ENABLE_ARM #define MS_FLOAT32X4 float32x4_t +#define MS_INT32X4 int32x4_t #define MS_LDQ_F32 vld1q_f32 +#define MS_LDQ_EPI32 vld1q_s32 #define MS_ADDQ_F32 vaddq_f32 +#define MS_ADDQ_EPI32 vaddq_s32 #define MS_MOVQ_F32 vmovq_n_f32 +#define MS_MOVQ_EPI32 vmovq_n_s32 #define MS_DUPQ_F32 vdupq_n_f32 // It is recommended to replace with MS_MOVQ_F32. 
#define MS_SUBQ_F32 vsubq_f32 #define MS_MLAQ_F32(src1, src2, src3) vmlaq_f32(src1, src2, src3) #define MS_STQ_F32 vst1q_f32 +#define MS_STQ_EPI32 vst1q_s32 #define MS_MAXQ_F32 vmaxq_f32 +#define MS_MAXQ_EPI32 vmaxq_s32 #define MS_MINQ_F32 vminq_f32 -#define MS_MULQ_F32(src1, src2) vmulq_n_f32(src1, src2) -#elif defined(ENABLE_SSE) +#define MS_MINQ_EPI32 vminq_s32 +#define MS_MULQ_F32(src1, src2) vmulq_f32(src1, src2) +#define MS_MULQ_EPI32(src1, src2) vmulq_s32(src1, src2) +#ifdef ENABLE_ARM64 +#define MS_DIVQ_F32(src1, src2) vdivq_f32(src1, src2) +#else +#define MS_DIVQ_F32(src1, src2) vmulq_f32(src1, vrecpeq_f32(src2)) +#endif +#define MS_MULQ_N_F32(src1, src2) vmulq_n_f32(src1, src2) +#define MS_MULQ_N_EPI32(src1, src2) vmulq_n_s32(src1, src2) +#define MS_DIVQ_N_F32(src1, src2) MS_DIVQ_F32(src1, vmovq_n_f32(src2)) +#define MS_SLLIQ_EPI32(src1, src2) vshlq_s32(src1, vmovq_n_s32(src2)) +#define MS_CVTQPS_EPI32(src) vcvtq_s32_f32(src) +#define MS_CVTQEPI32_PS(src) vcvtq_f32_s32(src) +#define MS_CMPGTQ_PS(src1, src2) vcgtq_f32(src1, src2) +#define MS_CMPGTQ_EPI32(src1, src2) vcgtq_s32(src1, src2) +// Note: Compared with X86, the vbslq_f32 parameters are the opposite of _mm_blendv_ps +#define MS_BLENDQ_PS(src1, src2, src3) vbslq_f32(src3, src2, src1) +#define MS_BLENDQ_EPI32(src1, src2, src3) vbslq_s32(src3, src2, src1) + +#elif defined(ENABLE_AVX) +#define MS_FLOAT32X8 __m256 +#define MS_INT32X8 __m256i +#define MS_LD256_F32 _mm256_loadu_ps +#define MS_LD256_EPI32(src) _mm256_loadu_si256((__m256i const *)(src)) +#define MS_ADD256_F32 _mm256_add_ps +#define MS_ADD256_EPI32 _mm256_add_epi32 +#define MS_MOV256_F32 _mm256_set1_ps +#define MS_MOV256_EPI32 _mm256_set1_epi32 +#define MS_DUP256_F32 _mm256_load_ps1 // It is recommended to replace with MS_MOV256_F32. 
+#define MS_MLA256_F32(src1, src2, src3) _mm256_add_ps(src1, _mm256_mul_ps(src2, src3)) +#define MS_ST256_F32 _mm256_storeu_ps +#define MS_ST256_EPI32(src1, src2) _mm256_storeu_si256((__m256i *)(src1), src2) +#define MS_SUB256_F32 _mm256_sub_ps +#define MS_MAX256_F32 _mm256_max_ps +#define MS_MAX256_EPI32 _mm256_max_epi32 +#define MS_MIN256_F32 _mm256_min_ps +#define MS_MIN256_EPI32 _mm256_min_epi32 +#define MS_MUL256_F32(src1, src2) _mm256_mul_ps(src1, src2) +#define MS_MUL256_EPI32(src1, src2) _mm256_mul_epi32(src1, src2) +#define MS_DIV256_F32(src1, src2) _mm256_div_ps(src1, src2) +#define MS_MUL256_N_F32(src1, src2) _mm256_mul_ps(src1, _mm256_set1_ps(src2)) +#define MS_MUL256_N_EPI32(src1, src2) _mm256_mul_epi32(src1, _mm256_set1_epi32(src2)) +#define MS_DIV256_N_F32(src1, src2) _mm256_div_ps(src1, _mm256_set1_ps(src2)) +#define MS_SLLI256_EPI32(src1, src2) _mm256_slli_epi32(src1, src2) +#define MS_CVT256PS_EPI32(src) _mm256_cvttps_epi32(src)  // truncate float to int +#define MS_CVT256EPI32_PS(src) _mm256_cvtepi32_ps(src) +#define MS_CMP256_PS(src1, src2, src3) _mm256_cmp_ps(src1, src2, src3) +#define MS_CMPGT256_EPI32(src1, src2) _mm256_cmpgt_epi32(src1, src2) +#define MS_BLEND256_PS(src1, src2, src3) _mm256_blendv_ps(src1, src2, src3) +#define MS_BLEND256_EPI32(src1, src2, src3) _mm256_blendv_epi8(src1, src2, src3) +#endif + +#if defined(ENABLE_SSE) #define MS_FLOAT32X4 __m128 +#define MS_INT32X4 __m128i #define MS_LDQ_F32 _mm_loadu_ps +#define MS_LDQ_EPI32(src) _mm_loadu_si128((__m128i const *)(src)) #define MS_ADDQ_F32 _mm_add_ps -#define MS_MOVQ_F32 _mm_set_ps1 +#define MS_ADDQ_EPI32 _mm_add_epi32 +#define MS_MOVQ_F32 _mm_set1_ps +#define MS_MOVQ_EPI32 _mm_set1_epi32 #define MS_DUPQ_F32 _mm_load_ps1 // It is recommended to replace with MS_MOVQ_F32. 
#define MS_MLAQ_F32(src1, src2, src3) _mm_add_ps(src1, _mm_mul_ps(src2, src3)) #define MS_STQ_F32 _mm_storeu_ps +#define MS_STQ_EPI32(src1, src2) _mm_storeu_si128((__m128i *)(src1), src2) #define MS_SUBQ_F32 _mm_sub_ps #define MS_MAXQ_F32 _mm_max_ps +#define MS_MAXQ_EPI32 _mm_max_epi32 #define MS_MINQ_F32 _mm_min_ps -#define MS_MULQ_F32(src1, src2) _mm_mul_ps(src1, _mm_set_ps1(src2)) +#define MS_MINQ_EPI32 _mm_min_epi32 +#define MS_MULQ_F32(src1, src2) _mm_mul_ps(src1, src2) +#define MS_MULQ_EPI32(src1, src2) _mm_mul_epi32(src1, src2) +#define MS_DIVQ_F32(src1, src2) _mm_div_ps(src1, src2) +#define MS_MULQ_N_F32(src1, src2) _mm_mul_ps(src1, _mm_set1_ps(src2)) +#define MS_MULQ_N_EPI32(src1, src2) _mm_mul_epi32(src1, _mm_set1_epi32(src2)) +#define MS_DIVQ_N_F32(src1, src2) _mm_div_ps(src1, _mm_set1_ps(src2)) +#define MS_SLLIQ_EPI32(src1, src2) _mm_slli_epi32(src1, src2) +#define MS_CVTQPS_EPI32(src) _mm_cvttps_epi32(src) // truncate float to int +#define MS_CVTQEPI32_PS(src) _mm_cvtepi32_ps(src) +#define MS_CMPGTQ_PS(src1, src2) _mm_cmpgt_ps(src1, src2) +#define MS_CMPGTQ_EPI32(src1, src2) _mm_cmpgt_epi32(src1, src2) +#define MS_BLENDQ_PS(src1, src2, src3) _mm_blendv_ps(src1, src2, src3) +#define MS_BLENDQ_EPI32(src1, src2, src3) _mm_blendv_epi8(src1, src2, src3) #endif #endif // MINDSPORE_LITE_NNACL_OP_BASE_H_ diff --git a/mindspore/lite/src/ops/tensorlist_getitem.cc b/mindspore/lite/src/ops/tensorlist_getitem.cc index 6299e3c713..01cb7e1592 100644 --- a/mindspore/lite/src/ops/tensorlist_getitem.cc +++ b/mindspore/lite/src/ops/tensorlist_getitem.cc @@ -179,7 +179,7 @@ int TensorListGetItem::InferShape(std::vector inputs_, std::vect MS_LOG(ERROR) << "element_shape_ is not fullyDefined!"; return RET_ERROR; } - output->set_data_type(input0->data_type()); + output->set_data_type(input0->tensors_data_type()); output->set_shape(element_shape_); } output->set_format(input0->GetTensor(index_)->format());