depthwise avx optimize

pull/12771/head
lzk 4 years ago
parent 4e1931481a
commit e3976405b3

@ -0,0 +1,177 @@
#ifdef ENABLE_AVX
.text
.align 4
.global ConvDwFp32Border
#ifndef __APPLE__
#ifndef WIN32
.type ConvDwFp32Border, %function
#endif
#endif
// void ConvDwFp32Border(ConvDwFp32BorderParam *param);
//
// Syntax: AT&T (GAS). ABI: System V AMD64, or Microsoft x64 when WIN32 is defined.
// Leaf function: makes no calls and keeps no stack locals.
//
// Computes one 4-float output vector:
//   acc  = sum over height rows and width taps of src_tap * weight_tap   (4 floats per tap)
//   acc += bias[0..3]
//   relu6 == 1: acc = max(min(acc, 6.0f), 0);  else relu == 1: acc = max(acc, 0)
//   dst[0..3] = acc
// If height == 0 the function returns without writing dst.
//
// Fields read from *param (8-byte slots):
//    0: dst         8: src         16: weight      24: bias
//   32: height     40: width       48: in_kh_step  56: in_kw_step
//   64: kernel_w   72: relu        80: relu6
// in_kh_step / in_kw_step / kernel_w are added directly to byte addresses.
//
// Register roles after the argument loads:
//   rdx dst        r12 src row base      r13 weight row base   rbp bias
//   r11 rows left  r14 width             r15 in_kh_step        r10 in_kw_step
//   rax kernel_w   rcx relu              rbx relu6
//   rsi src cursor rdi weight cursor     r8 taps left          r9 scratch
//   xmm0-xmm2 loaded taps / constants    xmm3, xmm4 partial sums   xmm5 running total
// Only xmm0-xmm5 are used, so no callee-saved vector register is touched under
// either ABI (xmm6-xmm15 are callee-saved on Win64).
ConvDwFp32Border:
    pushq %r15
    pushq %r14
    pushq %r13
    pushq %r12
    pushq %rbx
    pushq %rbp
    pushq %rsi                          // rsi/rdi are callee-saved on Win64
    pushq %rdi
#ifdef WIN32
    movq %rcx, %rdx                     // Win64: first argument arrives in rcx
#else
    movq %rdi, %rdx                     // SysV: first argument arrives in rdi
#endif
    movq 8(%rdx), %r12                  // src
    movq 16(%rdx), %r13                 // weight
    movq 24(%rdx), %rbp                 // bias
    movq 32(%rdx), %r11                 // height
    movq 40(%rdx), %r14                 // width
    movq 48(%rdx), %r15                 // in_kh_step
    movq 56(%rdx), %r10                 // in_kw_step
    movq 64(%rdx), %rax                 // kernel_w (weight row stride)
    movq 72(%rdx), %rcx                 // relu
    movq 80(%rdx), %rbx                 // relu6
    movq (%rdx), %rdx                   // dst (loaded last: rdx was the param pointer)
    testq %r11, %r11
    je End                              // no rows: leave dst untouched
    vxorps %xmm5, %xmm5, %xmm5          // running total = 0
LoopHeight:
    movq %r12, %rsi                     // src cursor = start of this row
    movq %r13, %rdi                     // weight cursor = start of this row
    movq %r14, %r8                      // taps left in this row
    cmpq $6, %r8
    jae LoopWidth6
    cmpq $4, %r8
    jae LoopWidth4
    cmpq $1, %r8
    jae LoopWidth1
    jmp LoopWidthEnd
LoopWidth6:
    // Six taps per iteration, spread over three accumulators.
    vxorps %xmm3, %xmm3, %xmm3
    vxorps %xmm4, %xmm4, %xmm4
    leaq (%r10, %r10, 2), %r9
    addq %rsi, %r9                      // r9 = src cursor + 3 * in_kw_step
    vmovups (%rsi), %xmm0               // tap 0
    vmovups (%rsi, %r10), %xmm1         // tap 1
    vmovups (%rsi, %r10, 2), %xmm2      // tap 2
    vfmadd231ps (%rdi), %xmm0, %xmm3
    vfmadd231ps 16(%rdi), %xmm1, %xmm4
    vfmadd231ps 32(%rdi), %xmm2, %xmm5
    vmovups (%r9), %xmm0                // tap 3
    vmovups (%rsi, %r10, 4), %xmm1      // tap 4
    vmovups (%r9, %r10, 2), %xmm2       // tap 5
    vfmadd231ps 48(%rdi), %xmm0, %xmm3
    vfmadd231ps 64(%rdi), %xmm1, %xmm4
    vfmadd231ps 80(%rdi), %xmm2, %xmm5
    vaddps %xmm3, %xmm4, %xmm4
    vaddps %xmm4, %xmm5, %xmm5          // fold partial sums into the running total
    leaq (%r9, %r10, 2), %rsi           // src cursor + 5 * in_kw_step
    addq %r10, %rsi                     // src cursor + 6 * in_kw_step
    addq $96, %rdi                      // 6 taps * 16 bytes of weights
    subq $6, %r8
    cmpq $6, %r8
    jae LoopWidth6
    cmpq $4, %r8
    jae LoopWidth4
    testq %r8, %r8
    je LoopWidthEnd
    jmp LoopWidth1
LoopWidth4:
    // Four taps per iteration.
    vxorps %xmm3, %xmm3, %xmm3
    vxorps %xmm4, %xmm4, %xmm4
    leaq (%r10, %r10, 2), %r9
    addq %rsi, %r9                      // r9 = src cursor + 3 * in_kw_step
    vmovups (%rsi), %xmm0               // tap 0
    vmovups (%rsi, %r10), %xmm1         // tap 1
    vmovups (%rsi, %r10, 2), %xmm2      // tap 2
    vfmadd231ps (%rdi), %xmm0, %xmm3
    vfmadd231ps 16(%rdi), %xmm1, %xmm4
    vfmadd231ps 32(%rdi), %xmm2, %xmm5
    vmovups (%r9), %xmm0                // tap 3
    vfmadd231ps 48(%rdi), %xmm0, %xmm3
    vaddps %xmm3, %xmm4, %xmm4
    vaddps %xmm4, %xmm5, %xmm5
    leaq (%r9, %r10), %rsi              // src cursor + 4 * in_kw_step
    addq $64, %rdi                      // 4 taps * 16 bytes of weights
    subq $4, %r8
    cmpq $4, %r8
    jae LoopWidth4
    testq %r8, %r8
    je LoopWidthEnd
LoopWidth1:
    // Remaining taps, one at a time, straight into the running total.
    vmovups (%rsi), %xmm0
    addq %r10, %rsi
    vfmadd231ps (%rdi), %xmm0, %xmm5
    addq $16, %rdi
    subq $1, %r8
    jnz LoopWidth1
LoopWidthEnd:
    subq $1, %r11
    je LoopHeightEnd
    addq %r15, %r12                     // next src row
    addq %rax, %r13                     // next weight row
    jmp LoopHeight
LoopHeightEnd:
    vaddps (%rbp), %xmm5, %xmm5         // VEX form: bias need not be 16-byte aligned
    cmpq $1, %rbx
    je Relu6
    cmpq $1, %rcx
    je Relu
    jmp Write
Relu6:
    // Upper clamp must be the float 6.0f (0x40C00000); storing the integer 6 and
    // reading it back as a float would give a denormal close to zero.
    movl $0x40C00000, %r9d
    vmovd %r9d, %xmm0
    vshufps $0, %xmm0, %xmm0, %xmm0     // splat 6.0f to all four lanes
    vminps %xmm0, %xmm5, %xmm5
Relu:
    vxorps %xmm1, %xmm1, %xmm1
    vmaxps %xmm1, %xmm5, %xmm5
Write:
    vmovups %xmm5, (%rdx)
End:
    popq %rdi
    popq %rsi
    popq %rbp
    popq %rbx
    popq %r12
    popq %r13
    popq %r14
    popq %r15
    retq
#endif

@ -0,0 +1,178 @@
#ifdef ENABLE_AVX
.text
.align 4
.global ConvDwFp32Row
#ifndef __APPLE__
#ifndef WIN32
.type ConvDwFp32Row, %function
#endif
#endif
// void ConvDwFp32Row(float *output_ptr, const float *input_tmp, const float *weight_ptr, size_t num_pixels,
//                    size_t output_channel, size_t input_step);
//
// Syntax: AT&T (GAS). ABI: System V AMD64, or Microsoft x64 when WIN32 is defined.
// Leaf function: makes no calls and uses no stack locals.
//
// For each of num_pixels pixels:
//   for c in [0, output_channel): *output_ptr++ += input[c] * weight_ptr[c]
//   input += input_step            (input_step is in floats; scaled by 4 to bytes here)
// The output pointer advances continuously across pixels; the weight pointer
// restarts at weight_ptr for every pixel.
//
// in linux x64 platform:
//   rdi: output_ptr
//   rsi: input_ptr
//   rdx: weight_ptr
//   rcx: num_pixels
//   r8:  output_channel
//   r9:  input_step
// in win x64 platform: the caller reserves 32 bytes of "shadow space" for the
// first four parameters, so at function entry:
//   rcx: output_ptr
//   rdx: input_ptr
//   r8:  weight_ptr
//   r9:  num_pixels
//   40(%rsp): output_channel
//   48(%rsp): input_step
//
// Register roles inside the loops:
//   rdi output cursor   rsi input cursor   rdx weight cursor   r8 channels left
//   rcx pixels left     r9 input_step in bytes
//   rax input base of the current pixel   r10 weight base   r11 output_channel
//   ymm0-ymm3 output accumulators         ymm4, ymm5 loaded inputs
// Only volatile GPRs and ymm0-ymm5 are used; on Win64 the callee-saved rsi/rdi
// are saved and restored, and xmm6-xmm15 are never touched.
ConvDwFp32Row:
#ifdef WIN32
    pushq %rsi
    pushq %rdi
    movq %rcx, %rdi                     // output_ptr
    movq %rdx, %rsi                     // input_ptr
    movq %r8, %rdx                      // weight_ptr
    movq %r9, %rcx                      // num_pixels
    movq 56(%rsp), %r8                  // output_channel: entry 40(%rsp) + 16 bytes pushed
    movq 64(%rsp), %r9                  // input_step:     entry 48(%rsp) + 16 bytes pushed
#endif
    shlq $2, %r9                        // input_step: floats -> bytes
    movq %rsi, %rax                     // input base of the first pixel
    movq %rdx, %r10                     // weight base, reused for every pixel
    movq %r8, %r11                      // channel count, reused for every pixel
    testq %rcx, %rcx
    je End
LoopPixel:
    movq %rax, %rsi                     // input_tmp
    movq %r10, %rdx                     // weight_tmp
    movq %r11, %r8                      // channel_tmp
    cmpq $32, %r8
    jb CheckC16
LoopC32:
    // 32 channels per iteration.
    vmovups (%rdi), %ymm0               // output_tmp
    vmovups 32(%rdi), %ymm1
    vmovups 64(%rdi), %ymm2
    vmovups 96(%rdi), %ymm3
    vmovups (%rsi), %ymm4               // input_tmp
    vmovups 32(%rsi), %ymm5
    vfmadd231ps (%rdx), %ymm4, %ymm0
    vfmadd231ps 32(%rdx), %ymm5, %ymm1
    vmovups 64(%rsi), %ymm4
    vmovups 96(%rsi), %ymm5
    vfmadd231ps 64(%rdx), %ymm4, %ymm2
    vfmadd231ps 96(%rdx), %ymm5, %ymm3
    vmovups %ymm0, (%rdi)               // output_ptr
    vmovups %ymm1, 32(%rdi)
    vmovups %ymm2, 64(%rdi)
    vmovups %ymm3, 96(%rdi)
    addq $128, %rsi
    addq $128, %rdx
    addq $128, %rdi
    subq $32, %r8
    cmpq $32, %r8
    jae LoopC32
CheckC16:
    cmpq $16, %r8
    jb CheckC8
LoopC16:
    // 16 channels per iteration.
    vmovups (%rdi), %ymm0               // output_tmp
    vmovups 32(%rdi), %ymm1
    vmovups (%rsi), %ymm4               // input_tmp
    vmovups 32(%rsi), %ymm5
    vfmadd231ps (%rdx), %ymm4, %ymm0
    vfmadd231ps 32(%rdx), %ymm5, %ymm1
    vmovups %ymm0, (%rdi)               // output_ptr
    vmovups %ymm1, 32(%rdi)
    addq $64, %rsi
    addq $64, %rdx
    addq $64, %rdi
    subq $16, %r8
    cmpq $16, %r8
    jae LoopC16
CheckC8:
    cmpq $8, %r8
    jb CheckC1
LoopC8:
    // 8 channels per iteration.
    vmovups (%rdi), %ymm0               // output_tmp
    vmovups (%rsi), %ymm4               // input_tmp
    vfmadd231ps (%rdx), %ymm4, %ymm0
    vmovups %ymm0, (%rdi)
    addq $32, %rsi
    addq $32, %rdx
    addq $32, %rdi
    subq $8, %r8
    cmpq $8, %r8
    jae LoopC8
CheckC1:
    testq %r8, %r8
    je LoopCEnd
LoopC:
    // Scalar tail: one channel per iteration.
    vmovss (%rsi), %xmm4                // input_tmp
    vmovss (%rdi), %xmm0                // output_ptr
    vfmadd231ss (%rdx), %xmm4, %xmm0
    vmovss %xmm0, (%rdi)
    addq $4, %rsi
    addq $4, %rdx
    addq $4, %rdi
    subq $1, %r8
    jnz LoopC
LoopCEnd:
    subq $1, %rcx                       // num_pixel -= 1
    je End
    addq %r9, %rax                      // input base of the next pixel
    jmp LoopPixel
End:
    vzeroupper                          // leave upper YMM state clean for SSE code in the caller
#ifdef WIN32
    popq %rdi
    popq %rsi
#endif
    retq
#endif

@ -21,6 +21,20 @@
#include "nnacl/op_base.h"
#include "nnacl/conv_parameter.h"
/*
 * Argument block for the ENABLE_AVX build of ConvDwFp32Border, which takes a
 * single pointer to this struct instead of eleven separate arguments.
 *
 * The AVX assembly implementation reads the fields at fixed 8-byte offsets in
 * declaration order (dst at 0, src at 8, ... relu6 at 80), so fields must not
 * be reordered, inserted or resized without updating those offsets.
 */
typedef struct ConvDwFp32BorderParam {
/* destination of the 4-float result vector */
float *dst;
/* first input tap */
const float *src;
/* first weight tap */
const float *weight;
/* 4 floats added to the accumulated sum before the activation */
const float *bias;
/* number of kernel rows to accumulate */
size_t height;
/* number of kernel taps per row */
size_t width;
/* step between input rows; ConvDwBorder fills it in bytes (elements * sizeof(float)) */
size_t in_kh_step;
/* step between input taps; ConvDwBorder fills it in bytes (elements * sizeof(float)) */
size_t in_kw_step;
/* step between weight rows; ConvDwBorder fills it as kernel_w_ * C4NUM * sizeof(float) */
size_t kernel_w;
/* the assembly applies max(x, 0) when this equals 1 */
size_t relu;
/* the assembly applies the ReLU6 clamp when this equals 1; checked before relu */
size_t relu6;
} ConvDwFp32BorderParam;
#ifdef __cplusplus
extern "C" {
#endif
@ -37,8 +51,12 @@ void WinogradTransRight(const float *S, const float *B, float *M, size_t w, size
void ConvDwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width,
size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step,
size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, size_t relu, size_t relu6);
/*
 * ConvDwFp32Border has two build-dependent signatures: with ENABLE_AVX the
 * arguments are passed through one ConvDwFp32BorderParam block, otherwise they
 * are passed as eleven separate arguments. Callers must select the matching
 * form under the same preprocessor condition.
 */
#ifdef ENABLE_AVX
void ConvDwFp32Border(ConvDwFp32BorderParam *param);
#else
void ConvDwFp32Border(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width,
size_t in_kh_step, size_t in_kw_step, size_t kernel_w, size_t relu, size_t relu6);
#endif
void DeconvDwFp32Center(float *dst, const float *src, const float *weight, size_t height, size_t width, size_t kernel_h,
size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step,
size_t in_kh_step, size_t in_kw_step);

@ -202,8 +202,21 @@ void ConvDwBorder(float *dst, const float *src, const float *weight, const float
const float *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_;
const float *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C4NUM;
#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
#ifdef ENABLE_AVX
ConvDwFp32BorderParam *param = (ConvDwFp32BorderParam *)malloc(sizeof(ConvDwFp32BorderParam));
param->dst = dst_kernel;
param->src = src_kernel;
param->weight = weight_kernel;
param->bias = bias;
param->height = end_kh - start_kh;
param->width = end_kw - start_kw;
param->in_kh_step = sliding->in_kh_step_ * sizeof(float);
param->in_kw_step = sliding->in_kw_step_ * sizeof(float);
param->kernel_w = conv_param->kernel_w_ * C4NUM * sizeof(float);
param->relu = relu;
param->relu6 = relu6;
ConvDwFp32Border(param);
#elif defined(ENABLE_ARM) || defined(ENABLE_SSE)
ConvDwFp32Border(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw,
sliding->in_kh_step_ * sizeof(float), sliding->in_kw_step_ * sizeof(float),
conv_param->kernel_w_ * C4NUM * sizeof(float), relu, relu6);

@ -14,7 +14,7 @@
* limitations under the License.
*/
#ifdef ENABLE_SSE
#if defined(ENABLE_SSE) && !defined(ENABLE_AVX)
#include <x86intrin.h>
#include "nnacl/fp32/common_func_fp32.h"

@ -19,6 +19,7 @@
#include "nnacl/fp32/conv_depthwise_fp32.h"
#include "nnacl/intrinsics/sse/sse_common.h"
#ifndef ENABLE_AVX
void ConvDwFp32Border(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width,
size_t in_kh_step, size_t in_kw_step, size_t kernel_w_step, size_t relu, size_t relu6) {
in_kh_step /= sizeof(float);
@ -104,6 +105,7 @@ void ConvDwFp32Border(float *dst, const float *src, const float *weight, const f
}
_mm_storeu_ps(dst, dst_ma);
}
#endif
void ConvDwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width,
size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step,

Loading…
Cancel
Save