From 3f57ef07ee167d30fe4c1bc6a0ecd4d3df803e31 Mon Sep 17 00:00:00 2001
From: lixian
Date: Fri, 18 Dec 2020 14:37:54 +0800
Subject: [PATCH] add avx fp32 depth-wise conv kernel

---
 .../nnacl/assembly/avx/ConvDwFp32Avx3x3.S     | 273 ++++++++++++++++++
 mindspore/lite/nnacl/common_func.c            |  49 ++++
 mindspore/lite/nnacl/common_func.h            |   6 +
 .../lite/nnacl/fp32/conv_depthwise_fp32.c     |  42 +++
 .../lite/nnacl/fp32/conv_depthwise_fp32.h     |   2 +
 .../arm/fp32/convolution_depthwise_fp32.cc    |   7 +-
 6 files changed, 378 insertions(+), 1 deletion(-)
 create mode 100644 mindspore/lite/nnacl/assembly/avx/ConvDwFp32Avx3x3.S

diff --git a/mindspore/lite/nnacl/assembly/avx/ConvDwFp32Avx3x3.S b/mindspore/lite/nnacl/assembly/avx/ConvDwFp32Avx3x3.S
new file mode 100644
index 0000000000..832df26aa9
--- /dev/null
+++ b/mindspore/lite/nnacl/assembly/avx/ConvDwFp32Avx3x3.S
@@ -0,0 +1,273 @@
+#ifdef ENABLE_AVX
+#ifndef WIN32
+
+.text
+.align 4
+.global ConvDwFp32Avx3x3
+#ifndef __APPLE__
+.type ConvDwFp32Avx3x3, %function
+#endif
+
+// void ConvDwFp32Avx3x3(float *output, float **input, const float *weights, const float *bias, int channels, int output_width,
+//                       size_t input_stride, size_t relu, size_t relu6)
+// rdi: output       - destination; advanced by 4 bytes per channel written, pixel after pixel
+// rsi: input        - table of 9 row pointers per pixel (offsets 0..64); advanced by input_stride bytes per pixel
+// rdx: weights      - 32 bytes are read per tap, 9 taps per block of 8 channels; rewound for every pixel
+// rcx: bias         - 32 bytes per block of 8 channels seed the accumulator; rewound for every pixel
+// r8: channels      - int argument, so only the low 32 bits are used
+// r9: output_width  - int argument; the pixel loop tests its counter at the bottom, so it runs at least once
+// 8: input_stride   - at 8(%rsp) once the prologue has put %rsp back to its entry value
+// 16: relu          - at 16(%rsp); non-zero selects max(x, 0)
+// 24: relu6         - at 24(%rsp); non-zero selects min(x, 6) followed by max(x, 0)
+
+ConvDwFp32Avx3x3:
+    pushq %r15
+    pushq %r14
+    pushq %r13
+    pushq %r12
+    pushq %rbx
+    pushq %rbp
+    pushq %r9
+    pushq %r8
+    pushq %rcx
+    pushq %rdx
+    pushq %rsi
+    pushq %rdi
+    addq $96, %rsp    // back to the entry %rsp: saved values are at -96..-8(%rsp), stack arguments at 8/16/24(%rsp)
+
+    movq $6, %rax
+    vcvtsi2ss %rax, %xmm15, %xmm15
+    vshufps $0, %xmm15, %xmm15, %xmm15
+    vinsertf128 $1, %xmm15, %ymm15, %ymm15    // ymm15 = 6.0f in all 8 lanes (relu6 upper bound)
+    vxorps %ymm14, %ymm14, %ymm14             // ymm14 = 0.0f in all 8 lanes (relu lower bound)
+
+  LoopPixel:
+    movq -80(%rsp), %rdx    // saved weights
+    movq -72(%rsp), %rcx    // saved bias
+    movl -64(%rsp), %r8d    // saved channels: 32-bit load zero-extends, the upper half of an int register is undefined
+    movq (%rsi), %r9
+    movq 8(%rsi), %r10
+    movq 16(%rsi), %r11
+    movq 24(%rsi), %r12
+    movq 32(%rsi), %r13
+    movq 40(%rsi), %r14
+    movq 48(%rsi), %r15
+    movq 56(%rsi), %rbp
+    movq 64(%rsi), %rbx
+
+    vmovups (%r9), %ymm0
+    addq $32, %r9
+    vmovups (%r10), %ymm1
+    addq $32, %r10
+    vmovups (%r11), %ymm2
+    addq $32, %r11
+
+    vmovups (%rdx), %ymm11
+    addq $32, %rdx
+    vmovups (%rdx), %ymm12
+    addq $32, %rdx
+    vmovups (%rdx), %ymm13
+    addq $32, %rdx
+
+    vmovups (%rcx), %ymm10    // accumulator starts from the bias
+    addq $32, %rcx
+
+    cmpq $8, %r8
+    jbe LeftLoop              // 8 channels or fewer: only the final block remains
+  LoopC8:
+    vfmadd231ps %ymm11, %ymm0, %ymm10
+    vmovups (%r12), %ymm3
+    addq $32, %r12
+    vmovups (%rdx), %ymm11
+    addq $32, %rdx
+    vfmadd231ps %ymm12, %ymm1, %ymm10
+    vmovups (%r13), %ymm4
+    addq $32, %r13
+    vmovups (%rdx), %ymm12
+    addq $32, %rdx
+    vfmadd231ps %ymm13, %ymm2, %ymm10
+    vmovups (%r14), %ymm5
+    addq $32, %r14
+    vmovups (%rdx), %ymm13
+    addq $32, %rdx
+    vfmadd231ps %ymm11, %ymm3, %ymm10
+    vmovups (%r15), %ymm6
+    addq $32, %r15
+    vmovups (%rdx), %ymm11
+    addq $32, %rdx
+    vfmadd231ps %ymm12, %ymm4, %ymm10
+    vmovups (%rbp), %ymm7
+    addq $32, %rbp
+    vmovups (%rdx), %ymm12
+    addq $32, %rdx
+    vfmadd231ps %ymm13, %ymm5, %ymm10
+    vmovups (%rbx), %ymm8
+    addq $32, %rbx
+    vmovups (%rdx), %ymm13
+    addq $32, %rdx
+    vfmadd231ps %ymm11, %ymm6, %ymm10
+    vmovups (%r9), %ymm0      // from here on: preload inputs and weights of the next channel block
+    addq $32, %r9
+    vmovups (%rdx), %ymm11
+    addq $32, %rdx
+    vfmadd231ps %ymm12, %ymm7, %ymm10
+    vmovups (%r10), %ymm1
+    addq $32, %r10
+    vmovups (%rdx), %ymm12
+    addq $32, %rdx
+    vfmadd231ps %ymm13, %ymm8, %ymm10
+    vmovups (%r11), %ymm2
+    addq $32, %r11
+    vmovups (%rdx), %ymm13
+    addq $32, %rdx
+
+    movq 24(%rsp), %rax       // relu6 flag
+    cmpq $0, %rax
+    jne Relu6
+    movq 16(%rsp), %rax       // relu flag
+    cmpq $0, %rax
+    jne Relu
+    jmp Write
+  Relu6:
+    vminps %ymm15, %ymm10, %ymm10    // falls through into Relu for the lower bound
+  Relu:
+    vmaxps %ymm14, %ymm10, %ymm10
+  Write:
+    vmovups %ymm10, (%rdi)
+    addq $32, %rdi
+
+    vmovups (%rcx), %ymm10    // bias of the next channel block
+    addq $32, %rcx
+    subq $8, %r8
+    cmpq $8, %r8
+    ja LoopC8
+
+  LeftLoop:                   // final block: full 32-byte loads are still issued, only the store width varies
+    vfmadd231ps %ymm11, %ymm0, %ymm10
+    vmovups (%r12), %ymm3
+    addq $32, %r12
+    vmovups (%rdx), %ymm11
+    addq $32, %rdx
+    vfmadd231ps %ymm12, %ymm1, %ymm10
+    vmovups (%r13), %ymm4
+    addq $32, %r13
+    vmovups (%rdx), %ymm12
+    addq $32, %rdx
+    vfmadd231ps %ymm13, %ymm2, %ymm10
+    vmovups (%r14), %ymm5
+    addq $32, %r14
+    vmovups (%rdx), %ymm13
+    addq $32, %rdx
+    vfmadd231ps %ymm11, %ymm3, %ymm10
+    vmovups (%r15), %ymm6
+    addq $32, %r15
+    vmovups (%rdx), %ymm11
+    addq $32, %rdx
+    vfmadd231ps %ymm12, %ymm4, %ymm10
+    vmovups (%rbp), %ymm7
+    addq $32, %rbp
+    vmovups (%rdx), %ymm12
+    addq $32, %rdx
+    vfmadd231ps %ymm13, %ymm5, %ymm10
+    vmovups (%rbx), %ymm8
+    addq $32, %rbx
+    vmovups (%rdx), %ymm13
+    addq $32, %rdx
+    vfmadd231ps %ymm11, %ymm6, %ymm10
+    vfmadd231ps %ymm12, %ymm7, %ymm10
+    vfmadd231ps %ymm13, %ymm8, %ymm10
+
+    movq 24(%rsp), %rax       // relu6 flag
+    cmpq $0, %rax
+    jne LeftRelu6
+    movq 16(%rsp), %rax       // relu flag
+    cmpq $0, %rax
+    jne LeftRelu
+    jmp LeftWrite
+  LeftRelu6:
+    vminps %ymm15, %ymm10, %ymm10
+  LeftRelu:
+    vmaxps %ymm14, %ymm10, %ymm10
+  LeftWrite:                  // r8 = channels left in the final block; store exactly that many floats
+    cmpq $1, %r8
+    je Write1
+    cmpq $2, %r8
+    je Write2
+    cmpq $3, %r8
+    je Write3
+    cmpq $4, %r8
+    je Write4
+    cmpq $5, %r8
+    je Write5
+    cmpq $6, %r8
+    je Write6
+    cmpq $7, %r8
+    je Write7
+    jmp Write8
+  Write1:
+    vmovss %xmm10, (%rdi)
+    addq $4, %rdi
+    jmp NextPixel
+  Write2:
+    vmovsd %xmm10, (%rdi)
+    addq $8, %rdi
+    jmp NextPixel
+  Write3:
+    vmovsd %xmm10, (%rdi)
+    vmovhlps %xmm10, %xmm10, %xmm10    // lane 2 -> lane 0; VEX form avoids mixing legacy SSE with dirty YMM state
+    vmovss %xmm10, 8(%rdi)
+    addq $12, %rdi
+    jmp NextPixel
+  Write4:
+    vmovups %xmm10, (%rdi)
+    addq $16, %rdi
+    jmp NextPixel
+  Write5:
+    vmovups %xmm10, (%rdi)
+    vextractf128 $1, %ymm10, %xmm9
+    vmovss %xmm9, 16(%rdi)
+    addq $20, %rdi
+    jmp NextPixel
+  Write6:
+    vmovups %xmm10, (%rdi)
+    vextractf128 $1, %ymm10, %xmm9
+    vmovsd %xmm9, 16(%rdi)
+    addq $24, %rdi
+    jmp NextPixel
+  Write7:
+    vmovups %xmm10, (%rdi)
+    vextractf128 $1, %ymm10, %xmm9
+    vmovsd %xmm9, 16(%rdi)
+    vmovhlps %xmm9, %xmm9, %xmm9       // lane 6 -> lane 0 of the extracted upper half, VEX form as in Write3
+    vmovss %xmm9, 24(%rdi)
+    addq $28, %rdi
+    jmp NextPixel
+  Write8:
+    vmovups %ymm10, (%rdi)
+    add $32, %rdi
+
+  NextPixel:
+    movq 8(%rsp), %rbp        // input_stride
+    addq %rbp, %rsi           // step the pointer table to the next output pixel
+    movl -56(%rsp), %eax      // remaining output_width: int, so a zero-extending 32-bit load
+    subq $1, %rax
+    movq %rax, -56(%rsp)
+    cmpq $0, %rax
+    ja LoopPixel
+End:
+    subq $96, %rsp            // step back down to the saved registers before popping them
+    popq %rdi
+    popq %rsi
+    popq %rdx
+    popq %rcx
+    popq %r8
+    popq %r9
+    popq %rbp
+    popq %rbx
+    popq %r12
+    popq %r13
+    popq %r14
+    popq %r15
+    retq                      // NOTE(review): no vzeroupper here - confirm that all callers are built with AVX enabled
+#endif
+#endif
diff --git a/mindspore/lite/nnacl/common_func.c b/mindspore/lite/nnacl/common_func.c
index 10edd92619..6dfbb1c480 100644
--- a/mindspore/lite/nnacl/common_func.c
+++ b/mindspore/lite/nnacl/common_func.c
@@ -78,3 +78,52 @@ void Relu6Fp32(float *data, float *dst, int ele_num) {
     data[j] = data[j] > 6 ? 6 : data[j];
   }
 }
+
+#ifdef ENABLE_AVX
+#ifdef WIN32
+void ReluFp32C8(float *data, float *dst, int ele_num) {  // in-place max(x, 0) over data[0, ele_num); dst is not used
+  int c8_block = ele_num / C8NUM;  // whole blocks only, so ele_num <= 0 never touches memory
+  for (int i = 0; i < c8_block; i++) {
+    int index = i * C8NUM;
+    data[index] = data[index] < 0 ? 0 : data[index];
+    data[index + 1] = data[index + 1] < 0 ? 0 : data[index + 1];
+    data[index + 2] = data[index + 2] < 0 ? 0 : data[index + 2];
+    data[index + 3] = data[index + 3] < 0 ? 0 : data[index + 3];
+    data[index + 4] = data[index + 4] < 0 ? 0 : data[index + 4];
+    data[index + 5] = data[index + 5] < 0 ? 0 : data[index + 5];
+    data[index + 6] = data[index + 6] < 0 ? 0 : data[index + 6];
+    data[index + 7] = data[index + 7] < 0 ? 0 : data[index + 7];
+  }
+  for (int j = c8_block * C8NUM; j < ele_num; ++j) {  // tail of fewer than C8NUM elements
+    data[j] = data[j] < 0 ? 0 : data[j];
+  }
+}
+
+void Relu6Fp32C8(float *data, float *dst, int ele_num) {  // in-place clamp to [0, 6] over data[0, ele_num); dst unused
+  int c8_block = ele_num / C8NUM;  // whole blocks only, so ele_num <= 0 never touches memory
+  for (int i = 0; i < c8_block; i++) {
+    int index = i * C8NUM;
+    data[index] = data[index] < 0 ? 0 : data[index];
+    data[index] = data[index] > 6 ? 6 : data[index];
+    data[index + 1] = data[index + 1] < 0 ? 0 : data[index + 1];
+    data[index + 1] = data[index + 1] > 6 ? 6 : data[index + 1];
+    data[index + 2] = data[index + 2] < 0 ? 0 : data[index + 2];
+    data[index + 2] = data[index + 2] > 6 ? 6 : data[index + 2];
+    data[index + 3] = data[index + 3] < 0 ? 0 : data[index + 3];
+    data[index + 3] = data[index + 3] > 6 ? 6 : data[index + 3];
+    data[index + 4] = data[index + 4] < 0 ? 0 : data[index + 4];
+    data[index + 4] = data[index + 4] > 6 ? 6 : data[index + 4];
+    data[index + 5] = data[index + 5] < 0 ? 0 : data[index + 5];
+    data[index + 5] = data[index + 5] > 6 ? 6 : data[index + 5];
+    data[index + 6] = data[index + 6] < 0 ? 0 : data[index + 6];
+    data[index + 6] = data[index + 6] > 6 ? 6 : data[index + 6];
+    data[index + 7] = data[index + 7] < 0 ? 0 : data[index + 7];
+    data[index + 7] = data[index + 7] > 6 ? 6 : data[index + 7];
+  }
+  for (int j = c8_block * C8NUM; j < ele_num; ++j) {  // tail of fewer than C8NUM elements
+    data[j] = data[j] < 0 ? 0 : data[j];
+    data[j] = data[j] > 6 ? 6 : data[j];
+  }
+}
+#endif
+#endif
diff --git a/mindspore/lite/nnacl/common_func.h b/mindspore/lite/nnacl/common_func.h
index 31f78e7ef4..2173d11fbd 100644
--- a/mindspore/lite/nnacl/common_func.h
+++ b/mindspore/lite/nnacl/common_func.h
@@ -31,6 +31,12 @@ int8_t MinInt8(int8_t a, int8_t b);
 int8_t MaxInt8(int8_t a, int8_t b);
 void ReluFp32(float *data, float *dst, int ele_num);
 void Relu6Fp32(float *data, float *dst, int ele_num);
+#ifdef ENABLE_AVX
+#ifdef WIN32
+void ReluFp32C8(float *data, float *dst, int ele_num);
+void Relu6Fp32C8(float *data, float *dst, int ele_num);
+#endif
+#endif
 int offset(const int *shape, const int dim0, const int dim1, const int dim2, const int dim3);
 int offsetComm(const int *shape, const int dim0, const int dim1, const int dim2);
 int offset4d(const int *shape, const int *dims);
diff --git a/mindspore/lite/nnacl/fp32/conv_depthwise_fp32.c b/mindspore/lite/nnacl/fp32/conv_depthwise_fp32.c
index a5f7847ee3..77ad4b9140 100644
--- a/mindspore/lite/nnacl/fp32/conv_depthwise_fp32.c
+++ b/mindspore/lite/nnacl/fp32/conv_depthwise_fp32.c
@@ -681,6 +681,47 @@ void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, c
 #endif
 
 #ifdef ENABLE_AVX
+#ifdef WIN32
+void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels,
+                           int output_width, int input_stride, bool relu, bool relu6, int kernel) {
+  for (; output_width > 0; --output_width) {  // one output pixel per pass; a non-positive width does nothing
+    float *in[kernel];  // private copies of this pixel's tap pointers, advanced as channels are consumed
+    for (int k = 0; k < kernel; k++) {
+      in[k] = input[k];
+    }
+    input = input + input_stride;  // stride counts pointer-table entries here
+
+    size_t c = channels;
+    const float *w = weights;
+    float *out = output;
+    memcpy(out, bias, channels * sizeof(float));  // accumulate on top of the bias
+    for (; c >= C8NUM; c -= C8NUM) {
+      for (int i = 0; i < C8NUM; i++) {
+        for (int k = 0; k < kernel; k++) {
+          out[i] += in[k][i] * w[i + k * C8NUM];  // weights are read as kernel rows of C8NUM lanes per block
+        }
+      }
+      w += kernel * C8NUM;
+      out += C8NUM;
+      for (int k = 0; k < kernel; k++) {
+        in[k] += C8NUM;
+      }
+    }
+    for (size_t i = 0; i < c; i++) {  // leftover channels, same weight layout as a full block
+      for (int k = 0; k < kernel; k++) {
+        out[i] += in[k][i] * w[i + k * C8NUM];
+      }
+    }
+    if (relu) {
+      ReluFp32C8(output, output, channels);
+    }
+    if (relu6) {
+      Relu6Fp32C8(output, output, channels);
+    }
+    output += channels;
+  }
+}
+#else
 void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels,
                            int output_width, int input_stride, bool relu, bool relu6, int kernel) {
   if (kernel == 9) {
@@ -688,6 +729,7 @@ void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, c
   }
 }
 #endif
+#endif
 
 void ConvDwIndirection(float *output_data, float **indirect_buffer, const float *weight_data, const float *bias_data,
                        float *zero_ptr, const ConvParameter *conv_param, int task_id) {
diff --git a/mindspore/lite/nnacl/fp32/conv_depthwise_fp32.h b/mindspore/lite/nnacl/fp32/conv_depthwise_fp32.h
index f81077fce8..3ea0e52dfb 100644
--- a/mindspore/lite/nnacl/fp32/conv_depthwise_fp32.h
+++ b/mindspore/lite/nnacl/fp32/conv_depthwise_fp32.h
@@ -67,9 +67,11 @@ void ConvDwFp32Indirect5x5(float *output, float **input, const float *weights, c
 #endif
 
 #ifdef ENABLE_AVX
+#ifndef WIN32
 void ConvDwFp32Avx3x3(float *output, float **input, const float *weights, const float *bias, int channels,
                       int output_width, size_t input_stride, size_t relu, size_t relu6);
 #endif
+#endif
 
 void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels,
                            int output_width, int input_stride, bool relu, bool relu6, int kernel);
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_fp32.cc
index 305f2716c0..a88b43cd3d 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_fp32.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_fp32.cc
@@ -147,7 +147,12 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector
   conv_param->input_channel_ = inputs[kInputIndex]->Channel();
   conv_param->output_h_ = outputs[kOutputIndex]->Height();
   conv_param->output_w_ = outputs[kOutputIndex]->Width();
-#if defined(ENABLE_ARM64) || defined(ENABLE_AVX)
+#ifdef ENABLE_AVX
+  if (conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3) {
+    kernel =
+      new (std::nothrow) kernel::ConvolutionDepthwiseIndirectCPUKernel(opParameter, inputs, outputs, ctx, primitive);
+  }
+#elif defined(ENABLE_ARM64)
   if (CheckConvDwUseIndirectBuffer(conv_param)) {
     kernel =
       new (std::nothrow) kernel::ConvolutionDepthwiseIndirectCPUKernel(opParameter, inputs, outputs, ctx, primitive);