From e4fa3c5b855fe2052b35728af39a28b14c66b22a Mon Sep 17 00:00:00 2001 From: liuzhongkai Date: Tue, 3 Nov 2020 14:06:36 +0800 Subject: [PATCH] add x86_64 sse optimize --- build.sh | 18 +- mindspore/lite/CMakeLists.txt | 7 + mindspore/lite/nnacl/CMakeLists.txt | 5 + .../lite/nnacl/minimal_filtering_generator.c | 2 +- mindspore/lite/nnacl/x86_64_sse/MatMul_Sse.c | 747 ++++++++++++++++++ mindspore/lite/test/CMakeLists.txt | 10 + 6 files changed, 785 insertions(+), 4 deletions(-) create mode 100644 mindspore/lite/nnacl/x86_64_sse/MatMul_Sse.c diff --git a/build.sh b/build.sh index a2b3a5ca92..65b74a0eaf 100755 --- a/build.sh +++ b/build.sh @@ -26,7 +26,7 @@ usage() echo " [-a on|off] [-p on|off] [-i] [-L] [-R] [-D on|off] [-j[n]] [-e gpu|d|cpu] \\" echo " [-P on|off] [-z [on|off]] [-M on|off] [-V 9.2|10.1] [-I arm64|arm32|x86_64] [-K] \\" echo " [-B on|off] [-w on|off] [-E] [-l on|off] [-n full|lite|off] [-T on|off] \\" - echo " [-A [cpp|java|object-c] [-C on|off] [-o on|off] [-S on|off] [-k on|off] \\" + echo " [-A [cpp|java|object-c] [-C on|off] [-o on|off] [-S on|off] [-k on|off] [-W sse|neon|avx|off] \\" echo "" echo "Options:" echo " -d Debug mode" @@ -65,6 +65,7 @@ usage() echo " -o Enable mindspore lite tools compilation, enabled when -I is specified, default on" echo " -S Enable enable download cmake compile dependency from gitee , default off" echo " -k Enable make clean, clean up compilation generated cache " + echo " -W Enable x86_64 SSE or AVX instruction set, use [sse|avx|neon|off], default off" } # check value of input is 'on' or 'off' @@ -118,9 +119,10 @@ checkopts() ENABLE_GITEE="off" ANDROID_STL="c++_shared" ENABLE_MAKE_CLEAN="off" + X86_64_SIMD="off" # Process the options - while getopts 'drvj:c:t:hsb:a:g:p:ie:m:l:I:LRP:D:zM:V:K:swB:En:T:A:C:o:S:k:' opt + while getopts 'drvj:c:t:hsb:a:g:p:ie:m:l:I:LRP:D:zM:V:K:swB:En:T:A:C:o:S:k:W:' opt do OPTARG=$(echo ${OPTARG} | tr '[A-Z]' '[a-z]') case "${opt}" in @@ -341,6 +343,16 @@ checkopts() check_on_off $OPTARG o ENABLE_TOOLS="$OPTARG" ;; + W) + if [[ "$OPTARG" != "sse" && "$OPTARG" != "off" && "$OPTARG" != "avx" && "$OPTARG" != "neon" ]]; then + echo "Invalid value ${OPTARG} for option -W, -W parameter must be sse|neon|avx|off" + usage + exit 1 + fi + if [[ "$OPTARG" == "sse" || "$OPTARG" == "avx" ]]; then + X86_64_SIMD="$OPTARG" + fi + ;; *) echo "Unknown option ${opt}!" usage @@ -702,7 +714,7 @@ build_lite() -DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DSUPPORT_GPU=${ENABLE_GPU} -DBUILD_MINDDATA=${COMPILE_MINDDATA_LITE} \ -DOFFLINE_COMPILE=${OPENCL_OFFLINE_COMPILE} -DCMAKE_INSTALL_PREFIX=${BASEPATH}/output/tmp \ -DMS_VERSION_MAJOR=${VERSION_MAJOR} -DMS_VERSION_MINOR=${VERSION_MINOR} -DMS_VERSION_REVISION=${VERSION_REVISION} \ - -DENABLE_VERBOSE=${ENABLE_VERBOSE} "${BASEPATH}/mindspore/lite" + -DENABLE_VERBOSE=${ENABLE_VERBOSE} -DX86_64_SIMD=${X86_64_SIMD} "${BASEPATH}/mindspore/lite" fi make -j$THREAD_NUM && make install && make package COMPILE_RET=$? diff --git a/mindspore/lite/CMakeLists.txt b/mindspore/lite/CMakeLists.txt index 67cc5c0ae3..4290b43741 100644 --- a/mindspore/lite/CMakeLists.txt +++ b/mindspore/lite/CMakeLists.txt @@ -20,6 +20,7 @@ option(SUPPORT_GPU "if support gpu" off) option(OFFLINE_COMPILE "if offline compile OpenCL kernel" off) option(BUILD_MINDDATA_EXAMPLE "" on) option(ENABLE_VERBOSE "" off) +option(ENABLE_X86_64_SSE "if x86_64 support SSE instruction set" off) set(DIR_PREFIX mindspore-lite) set(MS_VERSION ${MS_VERSION_MAJOR}.${MS_VERSION_MINOR}.${MS_VERSION_REVISION}) @@ -174,6 +175,12 @@ if (PLATFORM_ARM32 OR PLATFORM_ARM64) endif() endif() +if (NOT PLATFORM_ARM32 AND NOT PLATFORM_ARM64) + if ("${X86_64_SIMD}" STREQUAL "sse") + add_compile_definitions(ENABLE_X86_64_SSE) + endif () +endif () + if (BUILD_MINDDATA STREQUAL "lite" OR BUILD_MINDDATA STREQUAL "full") # add sentencepiece dependency # include(${TOP_DIR}/cmake/external_libs/sentencepiece.cmake) diff --git a/mindspore/lite/nnacl/CMakeLists.txt b/mindspore/lite/nnacl/CMakeLists.txt index 9036cf98b3..df24b90f8e 100644 --- a/mindspore/lite/nnacl/CMakeLists.txt +++ b/mindspore/lite/nnacl/CMakeLists.txt @@ -32,6 +32,11 @@ if (PLATFORM_ARM32) set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C) endif() +if ("${X86_64_SIMD}" STREQUAL "sse") + file(GLOB ASSEMBLY_SRC ${NNACL_DIR}/x86_64_sse/*.c) + set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C) +endif() + ########################### build nnacl static library ######################## string(REPLACE "-fvisibility=hidden" "-fvisibility=default" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}") add_library(nnacl STATIC ${KERNEL_SRC} ${TRAIN_SRC} ${ASSEMBLY_SRC}) diff --git a/mindspore/lite/nnacl/minimal_filtering_generator.c b/mindspore/lite/nnacl/minimal_filtering_generator.c index ddc0901380..d9a4b3094c 100644 --- a/mindspore/lite/nnacl/minimal_filtering_generator.c +++ b/mindspore/lite/nnacl/minimal_filtering_generator.c @@ -121,7 +121,7 @@ int B(const float *poly_array, float *matrix_b, int in_unit) { return NNACL_OK; } -#ifndef ENABLE_ARM +#if !defined(ENABLE_ARM) && !defined(ENABLE_X86_64_SSE) void MatrixMultiplyWinograd(const float *matix_a, const float *matrix_b, float *matrix_c, int m, int k, int n, int in_channel, int c4_channel) { int cnt = 0; diff --git a/mindspore/lite/nnacl/x86_64_sse/MatMul_Sse.c b/mindspore/lite/nnacl/x86_64_sse/MatMul_Sse.c new file mode 100644 index 0000000000..0dff319548 --- /dev/null +++ b/mindspore/lite/nnacl/x86_64_sse/MatMul_Sse.c @@ -0,0 +1,747 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef ENABLE_X86_64_SSE +#include +#include "nnacl/minimal_filtering_generator.h" +#include "nnacl/op_base.h" + +void MatrixMultiplyWinograd(const float *matix_a, const float *matrix_b, float *matrix_c, int m, int k, int n, + int in_channel, int c4_channel) { + const float *src1 = matix_a; + int c16 = DOWN_DIV(in_channel, C16NUM) * C16NUM; + int c8 = DOWN_DIV(in_channel, C8NUM) * C8NUM; + for (int i = 0; i < m; ++i) { + const float *src1_n = src1; + const float *src2_n = matrix_b; + for (int j = 0; j < n; ++j) { + const float *src1_j = src1_n; + int y = 0; + // 16 channel + for (; y < c16; y += C16NUM) { + __m128 dst1 = _mm_setzero_ps(); + __m128 dst2 = _mm_setzero_ps(); + __m128 dst3 = _mm_setzero_ps(); + __m128 dst4 = _mm_setzero_ps(); + const float *src2_y = src2_n; + for (int z = 0; z < k; ++z) { + __m128 ma1 = _mm_loadu_ps(src1_j); + __m128 ma2 = _mm_loadu_ps(src1_j + 4); + __m128 ma3 = _mm_loadu_ps(src1_j + 8); + __m128 ma4 = _mm_loadu_ps(src1_j + 12); + + __m128 mb = _mm_load_ps1(src2_y); + __m128 tmp1 = _mm_mul_ps(ma1, mb); + __m128 tmp2 = _mm_mul_ps(ma2, mb); + __m128 tmp3 = _mm_mul_ps(ma3, mb); + __m128 tmp4 = _mm_mul_ps(ma4, mb); + dst1 = _mm_add_ps(dst1, tmp1); + dst2 = _mm_add_ps(dst2, tmp2); + dst3 = _mm_add_ps(dst3, tmp3); + dst4 = _mm_add_ps(dst4, tmp4); + src1_j += in_channel; + src2_y += n; + } + _mm_store_ps(matrix_c, dst1); + _mm_store_ps(matrix_c + 4, dst2); + _mm_store_ps(matrix_c + 8, dst3); + _mm_store_ps(matrix_c + 12, dst4); + src1_j -= in_channel * k; + src1_j += C16NUM; + matrix_c += C16NUM; + } + // 8 channel + for (; y < c8; y += C8NUM) { + __m128 dst1 = _mm_setzero_ps(); + __m128 dst2 = _mm_setzero_ps(); + const float *src2_y = src2_n; + for (int z = 0; z < k; ++z) { + __m128 ma1 = _mm_loadu_ps(src1_j); + __m128 ma2 = _mm_loadu_ps(src1_j + 4); + + __m128 mb = _mm_load_ps1(src2_y); + __m128 tmp1 = _mm_mul_ps(ma1, mb); + __m128 tmp2 = _mm_mul_ps(ma2, mb); + dst1 = _mm_add_ps(dst1, tmp1); + dst2 = _mm_add_ps(dst2, tmp2); + src1_j += in_channel; + src2_y += n; + } + _mm_store_ps(matrix_c, dst1); + _mm_store_ps(matrix_c + 4, dst2); + src1_j -= in_channel * k; + src1_j += C8NUM; + matrix_c += C8NUM; + } + // remain chann + for (; y < in_channel; ++y) { + float tmp = 0; + for (int z = 0; z < k; ++z) { + tmp += matix_a[z * in_channel + y + i * in_channel * k] * matrix_b[j + z * n]; + } + *matrix_c++ = tmp; + } + src2_n += 1; + } + src1 += k * in_channel; + } +} + +void MatmulFloatSse64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, + int col, int stride, int write_mode) { + int C8Steps = row * C8NUM; + int WinoSteps1 = stride * col; + int WinoSteps2 = stride * C8NUM; + for (int r = row; r > 0; r -= C4NUM) { + const float *srcb_d = b; + const float *bias_d = bias; + float *dst = NULL; + for (int cc = col; cc > 0; cc -= C8NUM) { + if (write_mode != 0) { // writec8 + dst = c; + } + const float *srca_d = a; + __m128 dst1 = _mm_setzero_ps(); + __m128 dst2 = _mm_setzero_ps(); + __m128 dst3 = _mm_setzero_ps(); + __m128 dst4 = _mm_setzero_ps(); + __m128 dst5 = _mm_setzero_ps(); + __m128 dst6 = _mm_setzero_ps(); + __m128 dst7 = _mm_setzero_ps(); + __m128 dst8 = _mm_setzero_ps(); + for (int d = depth; d > 0; --d) { + __m128 b1 = _mm_loadu_ps(srcb_d); + __m128 b2 = _mm_loadu_ps(srcb_d + 4); + __m128 a1 = _mm_load_ps1(srca_d); + __m128 a2 = _mm_load_ps1(srca_d + 1); + __m128 tmp1 = _mm_mul_ps(b1, a1); + __m128 tmp2 = _mm_mul_ps(b2, a1); + __m128 tmp3 = _mm_mul_ps(b1, a2); + __m128 tmp4 = _mm_mul_ps(b2, a2); + a1 = _mm_load_ps1(srca_d + 2); + dst1 = _mm_add_ps(dst1, tmp1); + dst2 = _mm_add_ps(dst2, tmp2); + a2 = _mm_load_ps1(srca_d + 3); + dst3 = _mm_add_ps(dst3, tmp3); + dst4 = _mm_add_ps(dst4, tmp4); + tmp1 = _mm_mul_ps(b1, a1); + tmp2 = _mm_mul_ps(b2, a1); + tmp3 = _mm_mul_ps(b1, a2); + tmp4 = _mm_mul_ps(b2, a2); + dst5 = _mm_add_ps(dst5, tmp1); + dst6 = _mm_add_ps(dst6, tmp2); + dst7 = _mm_add_ps(dst7, tmp3); + dst8 = _mm_add_ps(dst8, tmp4); + srcb_d += C8NUM; + srca_d += C4NUM; + } + if (bias != NULL) { + __m128 bias1 = _mm_loadu_ps(bias_d); + __m128 bias2 = _mm_loadu_ps(bias_d + C4NUM); + dst1 = _mm_add_ps(dst1, bias1); + dst2 = _mm_add_ps(dst2, bias2); + dst3 = _mm_add_ps(dst3, bias1); + dst4 = _mm_add_ps(dst4, bias2); + dst5 = _mm_add_ps(dst5, bias1); + dst6 = _mm_add_ps(dst6, bias2); + dst7 = _mm_add_ps(dst7, bias1); + dst8 = _mm_add_ps(dst8, bias2); + bias_d += C8NUM; + } + if (act_type == 3) { + __m128 relu6 = _mm_set_ps(6.0, 6.0, 6.0, 6.0); + dst1 = _mm_min_ps(dst1, relu6); + dst2 = _mm_min_ps(dst2, relu6); + dst3 = _mm_min_ps(dst3, relu6); + dst4 = _mm_min_ps(dst4, relu6); + dst5 = _mm_min_ps(dst5, relu6); + dst6 = _mm_min_ps(dst6, relu6); + dst7 = _mm_min_ps(dst7, relu6); + dst8 = _mm_min_ps(dst8, relu6); + } + if (act_type == 1 || act_type == 3) { + __m128 zero = _mm_setzero_ps(); + dst1 = _mm_max_ps(dst1, zero); + dst2 = _mm_max_ps(dst2, zero); + dst3 = _mm_max_ps(dst3, zero); + dst4 = _mm_max_ps(dst4, zero); + dst5 = _mm_max_ps(dst5, zero); + dst6 = _mm_max_ps(dst6, zero); + dst7 = _mm_max_ps(dst7, zero); + dst8 = _mm_max_ps(dst8, zero); + } + if (write_mode == 2) { // WriteWino + c = dst + WinoSteps2; + _mm_store_ps(dst, dst1); + _mm_store_ps(dst + 4, dst2); + dst += WinoSteps1; + _mm_store_ps(dst, dst3); + _mm_store_ps(dst + 4, dst4); + dst += WinoSteps1; + _mm_store_ps(dst, dst5); + _mm_store_ps(dst + 4, dst6); + dst += WinoSteps1; + _mm_store_ps(dst, dst7); + _mm_store_ps(dst + 4, dst8); + } else if (write_mode == 0) { // WriteC8 + _mm_store_ps(c, dst1); + _mm_store_ps(c + 4, dst2); + _mm_store_ps(c + 8, dst3); + _mm_store_ps(c + 12, dst4); + _mm_store_ps(c + 16, dst5); + _mm_store_ps(c + 20, dst6); + _mm_store_ps(c + 24, dst7); + _mm_store_ps(c + 28, dst8); + c += C8Steps; + } else { + switch (cc) { + case 1: // write1 + c = dst + 1; + _mm_store_ss(dst, dst1); + if (r > 1) { + dst += stride; + _mm_store_ss(dst, dst3); + } + if (r > 2) { + dst += stride; + _mm_store_ss(dst, dst5); + } + if (r > 3) { + dst += stride; + _mm_store_ss(dst, dst7); + dst += stride; + dst += 1; + } + break; + case 2: // write2 + c = dst + 2; + _mm_store_ss(dst, dst1); + dst1 = _mm_shuffle_ps(dst1, dst1, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 1, dst1); + if (r > 1) { + dst += stride; + _mm_store_ss(dst, dst3); + dst3 = _mm_shuffle_ps(dst3, dst3, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 1, dst3); + } + if (r > 2) { + dst += stride; + _mm_store_ss(dst, dst5); + dst5 = _mm_shuffle_ps(dst5, dst5, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 1, dst5); + } + if (r > 3) { + dst += stride; + _mm_store_ss(dst, dst7); + dst7 = _mm_shuffle_ps(dst7, dst7, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 1, dst7); + dst += stride; + dst += 2; + } + break; + case 3: // write3 + c = dst + 3; + _mm_store_ss(dst, dst1); + dst1 = _mm_shuffle_ps(dst1, dst1, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 1, dst1); + dst1 = _mm_shuffle_ps(dst1, dst1, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 2, dst1); + if (r > 1) { + dst += stride; + _mm_store_ss(dst, dst3); + dst3 = _mm_shuffle_ps(dst3, dst3, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 1, dst3); + dst3 = _mm_shuffle_ps(dst3, dst3, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 2, dst3); + } + if (r > 2) { + dst += stride; + _mm_store_ss(dst, dst5); + dst5 = _mm_shuffle_ps(dst5, dst5, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 1, dst5); + dst5 = _mm_shuffle_ps(dst5, dst5, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 2, dst5); + } + if (r > 3) { + dst += stride; + _mm_store_ss(dst, dst7); + dst7 = _mm_shuffle_ps(dst7, dst7, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 1, dst7); + dst7 = _mm_shuffle_ps(dst7, dst7, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 2, dst7); + dst += stride; + dst += 3; + } + break; + case 4: // write4 + c = dst + 4; + _mm_store_ps(dst, dst1); + if (r > 1) { + dst += stride; + _mm_store_ps(dst, dst3); + } + if (r > 2) { + dst += stride; + _mm_store_ps(dst, dst5); + } + if (r > 3) { + dst += stride; + _mm_store_ps(dst, dst7); + dst += stride; + dst += 4; + } + break; + case 5: // write5 + c = dst + 5; + _mm_store_ps(dst, dst1); + _mm_store_ss(dst + 4, dst2); + if (r > 1) { + dst += stride; + _mm_store_ps(dst, dst3); + _mm_store_ss(dst + 4, dst4); + } + if (r > 2) { + dst += stride; + _mm_store_ps(dst, dst5); + _mm_store_ss(dst + 4, dst6); + } + if (r > 3) { + dst += stride; + _mm_store_ps(dst, dst7); + _mm_store_ss(dst + 4, dst8); + dst += stride; + dst += 5; + } + break; + case 6: // write6 + c = dst + 6; + _mm_store_ps(dst, dst1); + _mm_store_ss(dst + 4, dst2); + dst2 = _mm_shuffle_ps(dst2, dst2, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 5, dst2); + if (r > 1) { + dst += stride; + _mm_store_ps(dst, dst3); + _mm_store_ss(dst + 4, dst4); + dst4 = _mm_shuffle_ps(dst4, dst4, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 5, dst4); + } + if (r > 2) { + dst += stride; + _mm_store_ps(dst, dst5); + _mm_store_ss(dst + 4, dst6); + dst6 = _mm_shuffle_ps(dst6, dst6, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 5, dst6); + } + if (r > 3) { + dst += stride; + _mm_store_ps(dst, dst7); + _mm_store_ss(dst + 4, dst8); + dst8 = _mm_shuffle_ps(dst8, dst8, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 5, dst8); + dst += stride; + dst += 6; + } + break; + case 7: // write7 + c = dst + 7; + _mm_store_ps(dst, dst1); + _mm_store_ss(dst + 4, dst2); + dst2 = _mm_shuffle_ps(dst2, dst2, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 5, dst2); + dst2 = _mm_shuffle_ps(dst2, dst2, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 6, dst2); + if (r > 1) { + dst += stride; + _mm_store_ps(dst, dst3); + _mm_store_ss(dst + 4, dst4); + dst4 = _mm_shuffle_ps(dst4, dst4, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 5, dst4); + dst4 = _mm_shuffle_ps(dst4, dst4, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 6, dst4); + } + if (r > 2) { + dst += stride; + _mm_store_ps(dst, dst5); + _mm_store_ss(dst + 4, dst6); + dst6 = _mm_shuffle_ps(dst6, dst6, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 5, dst6); + dst6 = _mm_shuffle_ps(dst6, dst6, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 6, dst6); + } + if (r > 3) { + dst += stride; + _mm_store_ps(dst, dst7); + _mm_store_ss(dst + 4, dst8); + dst8 = _mm_shuffle_ps(dst8, dst8, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 5, dst8); + dst8 = _mm_shuffle_ps(dst8, dst8, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 6, dst8); + dst += stride; + dst += 7; + } + break; + default: // write8 + c = dst + C8NUM; + _mm_store_ps(dst, dst1); + _mm_store_ps(dst + 4, dst2); + if (r > 1) { + dst += stride; + _mm_store_ps(dst, dst3); + _mm_store_ps(dst + 4, dst4); + } + if (r > 2) { + dst += stride; + _mm_store_ps(dst, dst5); + _mm_store_ps(dst + 4, dst6); + } + if (r > 3) { + dst += stride; + _mm_store_ps(dst, dst7); + _mm_store_ps(dst + 4, dst8); + dst += stride; + dst += C8NUM; + } + break; + } + } + if (cc <= C8NUM) { // write end + break; + } + } // col end + a += C4NUM * depth; + switch (write_mode) { + case 0: // C8DstStep + c += 32; + break; + case 2: + c = dst + WinoSteps2; + break; + default: + c = dst - col; + break; + } + if (r <= C4NUM) { + break; + } + } +} + +void MatmulFloatSse64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, + int col, int stride, size_t writeNhwc, size_t WriteWino) { + size_t DstWinoSteps = stride * C8NUM; + size_t WriteWinoSteps = stride * col; + for (int col_tmp = col; col_tmp > 0; col_tmp -= C8NUM) { + const float *srca_d = a; + float *dst = c; + for (int r = row; r > 0; r -= C4NUM) { + const float *srcb_d = b; + __m128 dst1 = _mm_setzero_ps(); + __m128 dst2 = _mm_setzero_ps(); + __m128 dst3 = _mm_setzero_ps(); + __m128 dst4 = _mm_setzero_ps(); + __m128 dst5 = _mm_setzero_ps(); + __m128 dst6 = _mm_setzero_ps(); + __m128 dst7 = _mm_setzero_ps(); + __m128 dst8 = _mm_setzero_ps(); + for (int d = 0; d < depth; d++) { + __m128 b1 = _mm_loadu_ps(srcb_d); + __m128 b2 = _mm_loadu_ps(srcb_d + 4); + __m128 a1 = _mm_load_ps1(srca_d); + __m128 a2 = _mm_load_ps1(srca_d + 1); + __m128 tmp1 = _mm_mul_ps(b1, a1); + __m128 tmp2 = _mm_mul_ps(b2, a1); + __m128 tmp3 = _mm_mul_ps(b1, a2); + __m128 tmp4 = _mm_mul_ps(b2, a2); + a1 = _mm_load_ps1(srca_d + 2); + dst1 = _mm_add_ps(dst1, tmp1); + dst2 = _mm_add_ps(dst2, tmp2); + a2 = _mm_load_ps1(srca_d + 3); + dst3 = _mm_add_ps(dst3, tmp3); + dst4 = _mm_add_ps(dst4, tmp4); + tmp1 = _mm_mul_ps(b1, a1); + tmp2 = _mm_mul_ps(b2, a1); + tmp3 = _mm_mul_ps(b1, a2); + tmp4 = _mm_mul_ps(b2, a2); + dst5 = _mm_add_ps(dst5, tmp1); + dst6 = _mm_add_ps(dst6, tmp2); + dst7 = _mm_add_ps(dst7, tmp3); + dst8 = _mm_add_ps(dst8, tmp4); + srcb_d += C8NUM; + srca_d += C4NUM; + } + if (bias != NULL) { + __m128 bias1 = _mm_loadu_ps(bias); + __m128 bias2 = _mm_loadu_ps(bias + C4NUM); + dst1 = _mm_add_ps(dst1, bias1); + dst2 = _mm_add_ps(dst2, bias2); + dst3 = _mm_add_ps(dst3, bias1); + dst4 = _mm_add_ps(dst4, bias2); + dst5 = _mm_add_ps(dst5, bias1); + dst6 = _mm_add_ps(dst6, bias2); + dst7 = _mm_add_ps(dst7, bias1); + dst8 = _mm_add_ps(dst8, bias2); + } + if (act_type == 3) { + __m128 relu6 = _mm_set_ps(6.0, 6.0, 6.0, 6.0); + dst1 = _mm_min_ps(dst1, relu6); + dst2 = _mm_min_ps(dst2, relu6); + dst3 = _mm_min_ps(dst3, relu6); + dst4 = _mm_min_ps(dst4, relu6); + dst5 = _mm_min_ps(dst5, relu6); + dst6 = _mm_min_ps(dst6, relu6); + dst7 = _mm_min_ps(dst7, relu6); + dst8 = _mm_min_ps(dst8, relu6); + } + if (act_type == 1 || act_type == 3) { + __m128 zero = _mm_setzero_ps(); + dst1 = _mm_max_ps(dst1, zero); + dst2 = _mm_max_ps(dst2, zero); + dst3 = _mm_max_ps(dst3, zero); + dst4 = _mm_max_ps(dst4, zero); + dst5 = _mm_max_ps(dst5, zero); + dst6 = _mm_max_ps(dst6, zero); + dst7 = _mm_max_ps(dst7, zero); + dst8 = _mm_max_ps(dst8, zero); + } + if (WriteWino != 0) { // WriteWino + _mm_store_ps(dst, dst1); + _mm_store_ps(dst + 4, dst2); + dst += WriteWinoSteps; + _mm_store_ps(dst, dst3); + _mm_store_ps(dst + 4, dst4); + dst += WriteWinoSteps; + _mm_store_ps(dst, dst5); + _mm_store_ps(dst + 4, dst6); + dst += WriteWinoSteps; + _mm_store_ps(dst, dst7); + _mm_store_ps(dst + 4, dst8); + dst += WriteWinoSteps; + } else if (writeNhwc == 0) { // WriteC8 + _mm_store_ps(dst, dst1); + _mm_store_ps(dst + 4, dst2); + _mm_store_ps(dst + 8, dst3); + _mm_store_ps(dst + 12, dst4); + _mm_store_ps(dst + 16, dst5); + _mm_store_ps(dst + 20, dst6); + _mm_store_ps(dst + 24, dst7); + _mm_store_ps(dst + 28, dst8); + dst += 32; + c = dst; + } else { + switch (col) { + case 1: // write1 + _mm_store_ss(dst, dst1); + if (r > 1) { + dst += stride; + _mm_store_ss(dst, dst3); + } + if (r > 2) { + dst += stride; + _mm_store_ss(dst, dst5); + } + if (r > 3) { + dst += stride; + _mm_store_ss(dst, dst7); + dst += stride; + } + case 2: // write2 + _mm_store_ss(dst, dst1); + dst1 = _mm_shuffle_ps(dst1, dst1, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst, dst1); + if (r > 1) { + dst += stride; + _mm_store_ss(dst, dst3); + dst3 = _mm_shuffle_ps(dst3, dst3, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst, dst3); + } + if (r > 2) { + dst += stride; + _mm_store_ss(dst, dst5); + dst5 = _mm_shuffle_ps(dst5, dst5, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst, dst5); + } + if (r > 3) { + dst += stride; + _mm_store_ss(dst, dst7); + dst7 = _mm_shuffle_ps(dst7, dst7, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst, dst7); + } + case 3: // write3 + _mm_store_ss(dst, dst1); + dst1 = _mm_shuffle_ps(dst1, dst1, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 1, dst1); + dst1 = _mm_shuffle_ps(dst1 + 2, dst1, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst, dst1); + if (r > 1) { + dst += stride; + _mm_store_ss(dst, dst3); + dst3 = _mm_shuffle_ps(dst3, dst3, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 1, dst3); + dst3 = _mm_shuffle_ps(dst3, dst3, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 2, dst3); + } + if (r > 2) { + dst += stride; + _mm_store_ss(dst, dst5); + dst5 = _mm_shuffle_ps(dst5, dst5, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 1, dst5); + dst5 = _mm_shuffle_ps(dst5, dst5, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 2, dst5); + } + if (r > 3) { + dst += stride; + _mm_store_ss(dst, dst7); + dst7 = _mm_shuffle_ps(dst7, dst7, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 1, dst7); + dst7 = _mm_shuffle_ps(dst7, dst7, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 2, dst7); + dst += stride; + } + case 4: // write4 + _mm_store_ps(dst, dst1); + if (r > 1) { + dst += stride; + _mm_store_ps(dst, dst3); + } + if (r > 2) { + dst += stride; + _mm_store_ps(dst, dst5); + } + if (r > 3) { + dst += stride; + _mm_store_ps(dst, dst7); + dst += stride; + } + case 5: // // write5 + _mm_store_ps(dst, dst1); + _mm_store_ss(dst + 4, dst2); + if (r > 1) { + dst += stride; + _mm_store_ps(dst, dst3); + _mm_store_ss(dst + 4, dst4); + } + if (r > 2) { + dst += stride; + _mm_store_ps(dst, dst5); + _mm_store_ss(dst + 4, dst6); + } + if (r > 3) { + dst += stride; + _mm_store_ps(dst, dst7); + _mm_store_ss(dst + 4, dst8); + dst += stride; + } + case 6: // write6 + _mm_store_ps(dst, dst1); + _mm_store_ss(dst + 4, dst2); + dst2 = _mm_shuffle_ps(dst2, dst2, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 5, dst2); + if (r > 1) { + dst += stride; + _mm_store_ps(dst, dst3); + _mm_store_ss(dst + 4, dst4); + dst4 = _mm_shuffle_ps(dst4, dst4, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 5, dst4); + } + if (r > 2) { + dst += stride; + _mm_store_ps(dst, dst5); + _mm_store_ss(dst + 4, dst6); + dst6 = _mm_shuffle_ps(dst6, dst6, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 5, dst6); + } + if (r > 3) { + dst += stride; + _mm_store_ps(dst, dst7); + _mm_store_ss(dst + 4, dst8); + dst8 = _mm_shuffle_ps(dst8, dst8, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 5, dst8); + dst += stride; + } + case 7: // write7 + _mm_store_ps(dst, dst1); + _mm_store_ss(dst + 4, dst2); + dst2 = _mm_shuffle_ps(dst2, dst2, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 5, dst2); + dst2 = _mm_shuffle_ps(dst2, dst2, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 6, dst2); + if (r > 1) { + dst += stride; + _mm_store_ps(dst, dst3); + _mm_store_ss(dst + 4, dst4); + dst4 = _mm_shuffle_ps(dst4, dst4, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 5, dst4); + dst4 = _mm_shuffle_ps(dst4, dst4, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 6, dst4); + } + if (r > 2) { + dst += stride; + _mm_store_ps(dst, dst5); + _mm_store_ss(dst + 4, dst6); + dst6 = _mm_shuffle_ps(dst6, dst6, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 5, dst6); + dst6 = _mm_shuffle_ps(dst6, dst6, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 6, dst6); + } + if (r > 3) { + dst += stride; + _mm_store_ps(dst, dst7); + _mm_store_ss(dst + 4, dst8); + dst8 = _mm_shuffle_ps(dst8, dst8, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 5, dst8); + dst8 = _mm_shuffle_ps(dst8, dst8, _MM_SHUFFLE(0, 3, 2, 1)); + _mm_store_ss(dst + 6, dst8); + dst += stride; + } + default: // write8 + _mm_store_ps(dst, dst1); + _mm_store_ps(dst + 4, dst2); + if (r > 1) { + dst += stride; + _mm_store_ps(dst, dst3); + _mm_store_ps(dst + 4, dst4); + } + if (r > 2) { + dst += stride; + _mm_store_ps(dst, dst5); + _mm_store_ps(dst + 4, dst6); + } + if (r > 3) { + dst += stride; + _mm_store_ps(dst, dst7); + _mm_store_ps(dst + 4, dst8); + dst += stride; + } + } + } + if (r <= C4NUM) { // WriteEnd + break; + } + } + b += depth * C8NUM; + bias += (bias != NULL) ? C8NUM : 0; + if (WriteWino != 0) { + c += DstWinoSteps; + } else if (writeNhwc != 0) { + c += C8NUM; + } + if (col_tmp <= C8NUM) { + break; + } + } +} +#endif diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index 4e692ae1d8..b3ce4be0d7 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -76,6 +76,16 @@ if (ENABLE_FP16) ${KERNEL_OP_FP16_SRC} ) endif () + +if ("${X86_64_SIMD}" STREQUAL "sse") + file(GLOB TEST_ASSEMBLY_SRC ${LITE_DIR}/nnacl/x86_64_sse/*.c) + set_property(SOURCE ${TEST_ASSEMBLY_SRC} PROPERTY LANGUAGE C) + set(KERNEL_OP_SRC + ${KERNEL_OP_SRC} + ${TEST_ASSEMBLY_SRC} + ) +endif() + ### gpu kernel if (SUPPORT_GPU) file(GLOB GPU_KERNEL_OP_SRC