diff --git a/mindspore/lite/nnacl/assembly/avx/ConvDwFp32BorderAvx.S b/mindspore/lite/nnacl/assembly/avx/ConvDwFp32BorderAvx.S new file mode 100644 index 0000000000..5d3bd03c06 --- /dev/null +++ b/mindspore/lite/nnacl/assembly/avx/ConvDwFp32BorderAvx.S @@ -0,0 +1,177 @@ +#ifdef ENABLE_AVX + .text + .align 4 + .global ConvDwFp32Border +#ifndef __APPLE__ +#ifndef WIN32 + .type ConvDwFp32Border, %function +#endif +#endif + +// void ConvDwFp32Border(float *dst, const float *src, const float *weight, const float *bias, size_t height, +// size_t width, size_t in_kh_step, size_t in_kw_step, size_t kernel_w, size_t relu, +// size_t relu6); + +ConvDwFp32Border: + pushq %r15 + pushq %r14 + pushq %r13 + pushq %r12 + pushq %rbx + pushq %rbp + pushq %r9 + pushq %r8 // -64 + pushq %rcx // -72 + pushq %rdx // -80 + pushq %rsi + pushq %rdi + addq $96, %rsp + + movq %rdi, %rdx +#ifdef WIN32 + movq %rcx, %rdx +#endif + movq 8(%rdx), %r12 // src + movq 16(%rdx), %r13 // weight + movq 24(%rdx), %rbp // bias + movq 32(%rdx), %r11 // height + movq 40(%rdx), %r10 + movq %r10, -72(%rsp) // width + movq 48(%rdx), %r10 + movq %r10, -80(%rsp) // in_kh_step + movq 56(%rdx), %r10 // in_kw_step + movq 64(%rdx), %rax // kernel_w + movq 72(%rdx), %rcx // relu + movq 80(%rdx), %rbx // reul6 + movq $6, -64(%rsp) + movq (%rdx), %rdx + cmpq $0, %r11 + je End + + xorps %xmm8, %xmm8 + LoopHeight: + movq %r12, %rsi // src_kh, src_kw + movq %r13, %rdi // weight_kh, weight_kw + movq -72(%rsp), %r8 // width + + cmpq $6, %r8 + jae LoopWidth6 + cmpq $4, %r8 + jae LoopWidth4 + cmpq $1, %r8 + jae LoopWidth1 + jmp LoopWidthEnd + + LoopWidth6: + xorps %xmm6, %xmm6 + xorps %xmm7, %xmm7 + imul $3, %r10, %r9 + addq %rsi, %r9 + vmovups (%rsi), %xmm0 // src_kw + vmovups (%rsi, %r10), %xmm1 + vmovups (%rsi, %r10, 2), %xmm2 + vmovups (%r9), %xmm3 + vmovups (%rsi, %r10, 4), %xmm4 + vmovups (%r9, %r10, 2), %xmm5 + + vfmadd231ps (%rdi), %xmm0, %xmm6 + vfmadd231ps 16(%rdi), %xmm1, %xmm7 + vfmadd231ps 32(%rdi), %xmm2, %xmm8 + vfmadd231ps 48(%rdi), %xmm3, %xmm6 + vfmadd231ps 64(%rdi), %xmm4, %xmm7 + vfmadd231ps 80(%rdi), %xmm5, %xmm8 + + addps %xmm6, %xmm7 + imul $6, %r10, %r15 + addq $96, %rdi + addps %xmm7, %xmm8 + addq %r15, %rsi + + subq $6, %r8 + cmpq $6, %r8 + jae LoopWidth6 + cmpq $4, %r8 + jae LoopWidth4 + cmpq $0, %r8 + je LoopWidthEnd + jmp LoopWidth1 + + LoopWidth4: + xorps %xmm6, %xmm6 + xorps %xmm7, %xmm7 + imul $3, %r10, %r9 + addq %rsi, %r9 + vmovups (%rsi), %xmm0 // src_kw + vmovups (%rsi, %r10, 1), %xmm1 + vmovups (%rsi, %r10, 2), %xmm2 + vmovups (%r9), %xmm3 + + vfmadd231ps (%rdi), %xmm0, %xmm6 + vfmadd231ps 16(%rdi), %xmm1, %xmm7 + vfmadd231ps 32(%rdi), %xmm2, %xmm8 + vfmadd231ps 48(%rdi), %xmm3, %xmm6 + + addps %xmm6, %xmm7 + imul $4, %r10, %r15 + addq $64, %rdi + addps %xmm7, %xmm8 + addq %r15, %rsi + + subq $4, %r8 + cmpq $4, %r8 + jae LoopWidth4 + cmpq $0, %r8 + je LoopWidthEnd + jmp LoopWidth1 + + LoopWidth1: + vmovups (%rsi), %xmm0 // input_tmp + addq %r10, %rsi + vfmadd231ps (%rdi), %xmm0, %xmm8 + addq $16, %rdi + + subq $1, %r8 + cmpq $0, %r8 + ja LoopWidth1 + jmp LoopWidthEnd + + LoopWidthEnd: + subq $1, %r11 + cmpq $0, %r11 + je LoopHeightEnd + addq -80(%rsp), %r12 // in_kh_step + addq %rax, %r13 // kernel_w_step + jmp LoopHeight + + LoopHeightEnd: + xorps %xmm10, %xmm10 + vbroadcastss -64(%rsp), %xmm9 + + addps (%rbp), %xmm8 + cmpq $1, %rbx + je Relu6 + cmpq $1, %rcx + je Relu + jmp Write + Relu6: + minps %xmm9, %xmm8 + Relu: + maxps %xmm10, %xmm8 + Write: + movups %xmm8, (%rdx) +End: + subq $96, %rsp + popq %rdi + popq %rsi + popq %rdx + popq %rcx + popq %r8 + popq %r9 + popq %rbp + popq %rbx + popq %r12 + popq %r13 + popq %r14 + popq %r15 + retq +#endif diff --git a/mindspore/lite/nnacl/assembly/avx/ConvDwFp32RowAvx.S b/mindspore/lite/nnacl/assembly/avx/ConvDwFp32RowAvx.S new file mode 100644 index 0000000000..6896b78a68 --- /dev/null +++ b/mindspore/lite/nnacl/assembly/avx/ConvDwFp32RowAvx.S @@ -0,0 +1,178 @@ +#ifdef ENABLE_AVX + .text + .align 4 + .global ConvDwFp32Row +#ifndef __APPLE__ +#ifndef WIN32 + .type ConvDwFp32Row, %function +#endif +#endif + +// void ConvDwFp32Row(float *output_ptr, const float *input_tmp, const float *weight_ptr, size_t num_pixels, +// size_t output_channel, size_t input_step); +// in linux x64 platform: +// rdi: output_ptr +// rsi: input_ptr +// rdx: weight_ptr +// rcx: num_pixels +// r8: output_channel +// r9: input_step + +// in win x64 platform: "shadow space" needs to be opened up for first four parameters ==> 32 bites +// rcx: output_ptr +// rdx: input_ptr +// r8: weight_ptr +// r9: num_pixels +// 40: output_channel +// 48: input_step + +ConvDwFp32Row: + pushq %r15 + pushq %r14 + pushq %r13 + pushq %r12 + pushq %rsi + pushq %rdi + addq $48, %rsp + +#ifdef WIN32 + movq %rcx, %rdi // output_ptr + movq %rdx, %rsi // input_ptr + movq %r8, %rdx // weight_ptr + movq %r9, %rcx // num_pixels + movq 40(%rsp), %r8 // output_channel + movq 48(%rsp), %r9 // input_step +#endif + + movq $4, %r13 + imul %r13, %r9 + movq %rsi, %r13 // input_ptr + movq %rdx, %r14 // weight_ptr + movq %r8, %r15 // output_channel + cmpq $0, %rcx + je End + + LoopPixel: + movq %r13, %rsi // input_tmp + movq %r14, %rdx // weight_tmp + movq %r15, %r8 // channel_tmp + + cmpq $32, %r8 + jae LoopC32 + cmpq $16, %r8 + jae LoopC16 + cmpq $8, %r8 + jae LoopC8 + cmpq $0, %r8 + ja LoopC + jmp LoopCEnd + + LoopC32: + vmovups (%rsi), %ymm0 // input_tmp + vmovups 32(%rsi), %ymm1 + vmovups 64(%rsi), %ymm2 + vmovups 96(%rsi), %ymm3 + + vmovups (%rdi), %ymm8 // output_tmp + vmovups 32(%rdi), %ymm9 + vmovups 64(%rdi), %ymm10 + vmovups 96(%rdi), %ymm11 + + addq $128, %rsi + vfmadd231ps (%rdx), %ymm0, %ymm8 + vfmadd231ps 32(%rdx), %ymm1, %ymm9 + vfmadd231ps 64(%rdx), %ymm2, %ymm10 + vfmadd231ps 96(%rdx), %ymm3, %ymm11 + + vmovups %ymm8, (%rdi) // output_ptr + vmovups %ymm9, 32(%rdi) + vmovups %ymm10, 64(%rdi) + vmovups %ymm11, 96(%rdi) + addq $128, %rdi + addq $128, %rdx + + subq $32, %r8 + cmpq $32, %r8 + jae LoopC32 + cmpq $16, %r8 + jae LoopC16 + cmpq $8, %r8 + jae LoopC8 + cmpq $0, %r8 + ja LoopC + jmp LoopCEnd + + LoopC16: + vmovups (%rsi), %ymm0 // input_tmp + vmovups (%rdi), %ymm8 // output_tmp + vmovups 32(%rsi), %ymm1 + vmovups 32(%rdi), %ymm9 + addq $64, %rsi + + vfmadd231ps (%rdx), %ymm0, %ymm8 + vfmadd231ps 32(%rdx), %ymm1, %ymm9 + + vmovups %ymm8, (%rdi) // output_ptr + addq $64, %rdx + vmovups %ymm9, 32(%rdi) + addq $64, %rdi + + subq $16, %r8 + cmpq $16, %r8 + jae LoopC16 + cmpq $8, %r8 + jae LoopC8 + cmpq $0, %r8 + ja LoopC + jmp LoopCEnd + + LoopC8: + vmovups (%rsi), %ymm0 // input_tmp + vmovups (%rdi), %ymm8 // output_tmp + addq $32, %rsi + + vfmadd231ps (%rdx), %ymm0, %ymm8 + + addq $32, %rdx + vmovups %ymm8, (%rdi) + addq $32, %rdi + + subq $8, %r8 + cmpq $8, %r8 + jae LoopC8 + cmpq $0, %r8 + ja LoopC + jmp LoopCEnd + + LoopC: + vmovss (%rsi), %xmm0 // input_tmp + vmovss (%rdi), %xmm8 // output_ptr + + vfmadd231ss (%rdx), %xmm0, %xmm8 + + addq $4, %rsi + addq $4, %rdx + vmovss %xmm8, (%rdi) + addq $4, %rdi + + subq $1, %r8 + cmpq $0, %r8 + ja LoopC + jmp LoopCEnd + + LoopCEnd: + subq $1, %rcx // num_pixel -= 1 + cmpq $0, %rcx + je End + addq %r9, %r13 + jmp LoopPixel +End: + subq $48, %rsp + popq %rdi + popq %rsi + popq %r12 + popq %r13 + popq %r14 + popq %r15 + retq +#endif diff --git a/mindspore/lite/nnacl/fp32/common_func_fp32.h b/mindspore/lite/nnacl/fp32/common_func_fp32.h index 70a3f8a928..fc6f75ebe9 100644 --- a/mindspore/lite/nnacl/fp32/common_func_fp32.h +++ b/mindspore/lite/nnacl/fp32/common_func_fp32.h @@ -21,6 +21,20 @@ #include "nnacl/op_base.h" #include "nnacl/conv_parameter.h" +typedef struct ConvDwFp32BorderParam { + float *dst; + const float *src; + const float *weight; + const float *bias; + size_t height; + size_t width; + size_t in_kh_step; + size_t in_kw_step; + size_t kernel_w; + size_t relu; + size_t relu6; +} ConvDwFp32BorderParam; + #ifdef __cplusplus extern "C" { #endif @@ -37,8 +51,12 @@ void WinogradTransRight(const float *S, const float *B, float *M, size_t w, size void ConvDwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, size_t relu, size_t relu6); +#ifdef ENABLE_AVX +void ConvDwFp32Border(ConvDwFp32BorderParam *param); +#else void ConvDwFp32Border(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, size_t in_kh_step, size_t in_kw_step, size_t kernel_w, size_t relu, size_t relu6); +#endif void DeconvDwFp32Center(float *dst, const float *src, const float *weight, size_t height, size_t width, size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step); diff --git a/mindspore/lite/nnacl/fp32/conv_depthwise_fp32.c b/mindspore/lite/nnacl/fp32/conv_depthwise_fp32.c index 7b9ff25553..1fee1230f7 100644 --- a/mindspore/lite/nnacl/fp32/conv_depthwise_fp32.c +++ b/mindspore/lite/nnacl/fp32/conv_depthwise_fp32.c @@ -202,8 +202,21 @@ void ConvDwBorder(float *dst, const float *src, const float *weight, const float const float *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_; const float *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C4NUM; - -#if defined(ENABLE_ARM) || defined(ENABLE_SSE) +#ifdef ENABLE_AVX + ConvDwFp32BorderParam *param = (ConvDwFp32BorderParam *)malloc(sizeof(ConvDwFp32BorderParam)); + param->dst = dst_kernel; + param->src = src_kernel; + param->weight = weight_kernel; + param->bias = bias; + param->height = end_kh - start_kh; + param->width = end_kw - start_kw; + param->in_kh_step = sliding->in_kh_step_ * sizeof(float); + param->in_kw_step = sliding->in_kw_step_ * sizeof(float); + param->kernel_w = conv_param->kernel_w_ * C4NUM * sizeof(float); + param->relu = relu; + param->relu6 = relu6; + ConvDwFp32Border(param); +#elif defined(ENABLE_ARM) || defined(ENABLE_SSE) ConvDwFp32Border(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, sliding->in_kh_step_ * sizeof(float), sliding->in_kw_step_ * sizeof(float), conv_param->kernel_w_ * C4NUM * sizeof(float), relu, relu6); diff --git a/mindspore/lite/nnacl/intrinsics/sse/ConvDwFp32Row_sse.c b/mindspore/lite/nnacl/intrinsics/sse/ConvDwFp32Row_sse.c index ccf1a72395..e254e95a09 100644 --- a/mindspore/lite/nnacl/intrinsics/sse/ConvDwFp32Row_sse.c +++ b/mindspore/lite/nnacl/intrinsics/sse/ConvDwFp32Row_sse.c @@ -14,7 +14,7 @@ * limitations under the License. */ -#ifdef ENABLE_SSE +#if defined(ENABLE_SSE) && !defined(ENABLE_AVX) #include #include "nnacl/fp32/common_func_fp32.h" diff --git a/mindspore/lite/nnacl/intrinsics/sse/DepthwiseFp32_Sse.c b/mindspore/lite/nnacl/intrinsics/sse/DepthwiseFp32_Sse.c index 587bd676fa..2c6893bf9d 100644 --- a/mindspore/lite/nnacl/intrinsics/sse/DepthwiseFp32_Sse.c +++ b/mindspore/lite/nnacl/intrinsics/sse/DepthwiseFp32_Sse.c @@ -19,6 +19,7 @@ #include "nnacl/fp32/conv_depthwise_fp32.h" #include "nnacl/intrinsics/sse/sse_common.h" +#ifndef ENABLE_AVX void ConvDwFp32Border(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, size_t in_kh_step, size_t in_kw_step, size_t kernel_w_step, size_t relu, size_t relu6) { in_kh_step /= sizeof(float); @@ -104,6 +105,7 @@ void ConvDwFp32Border(float *dst, const float *src, const float *weight, const f } _mm_storeu_ps(dst, dst_ma); } +#endif void ConvDwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width, size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step,